├── InGAN.py ├── LICENSE ├── README.md ├── SceneScripts.py ├── configs.py ├── examples └── fruit │ └── fruit.png ├── figs ├── fruits.gif └── monitor_60000.png ├── fruits.gif ├── monitor_60000.png ├── networks.py ├── non_rect.py ├── supp_video.py ├── test.py ├── train.py ├── train_supp_mat.py └── util.py /InGAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import networks 4 | from util import random_size, get_scale_weights 5 | import os 6 | import warnings 7 | import numpy as np 8 | 9 | 10 | class LRPolicy(object): 11 | def __init__(self, start, end): 12 | self.start = start 13 | self.end = end 14 | 15 | def __call__(self, citer): 16 | return 1. - max(0., float(citer - self.start) / float(self.end - self.start)) 17 | 18 | 19 | # noinspection PyAttributeOutsideInit 20 | class InGAN: 21 | def __init__(self, conf): 22 | # Acquire configuration 23 | self.conf = conf 24 | self.cur_iter = 0 25 | self.max_iters = conf.max_iters 26 | 27 | # Define input tensor 28 | self.input_tensor = torch.FloatTensor(1, 3, conf.input_crop_size, conf.input_crop_size).cuda() 29 | self.real_example = torch.FloatTensor(1, 3, conf.output_crop_size, conf.output_crop_size).cuda() 30 | 31 | # Define networks 32 | self.G = networks.Generator(conf.G_base_channels, conf.G_num_resblocks, conf.G_num_downscales, conf.G_use_bias, 33 | conf.G_skip) 34 | self.D = networks.MultiScaleDiscriminator(conf.output_crop_size, self.conf.D_max_num_scales, 35 | self.conf.D_scale_factor, self.conf.D_base_channels) 36 | self.GAN_loss_layer = networks.GANLoss() 37 | self.Reconstruct_loss = networks.WeightedMSELoss(use_L1=conf.use_L1) 38 | self.RandCrop = networks.RandomCrop([conf.input_crop_size, conf.input_crop_size], must_divide=conf.must_divide) 39 | self.SwapCrops = networks.SwapCrops(conf.crop_swap_min_size, conf.crop_swap_max_size) 40 | 41 | # Make all networks run on GPU 42 | self.G.cuda() 43 | self.D.cuda() 44 | self.GAN_loss_layer.cuda() 45 | self.Reconstruct_loss.cuda() 46 | self.RandCrop.cuda() 47 | self.SwapCrops.cuda() 48 | 49 | # Define loss function 50 | self.criterionGAN = self.GAN_loss_layer.forward 51 | self.criterionReconstruction = self.Reconstruct_loss.forward 52 | 53 | # Keeping track of losses- prepare tensors 54 | self.losses_G_gan = torch.FloatTensor(conf.print_freq).cuda() 55 | self.losses_D_real = torch.FloatTensor(conf.print_freq).cuda() 56 | self.losses_D_fake = torch.FloatTensor(conf.print_freq).cuda() 57 | self.losses_G_reconstruct = torch.FloatTensor(conf.print_freq).cuda() 58 | if self.conf.reconstruct_loss_stop_iter > 0: 59 | self.losses_D_reconstruct = torch.FloatTensor(conf.print_freq).cuda() 60 | 61 | # Initialize networks 62 | self.G.apply(networks.weights_init) 63 | self.D.apply(networks.weights_init) 64 | 65 | # Initialize optimizers 66 | self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999)) 67 | self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999)) 68 | 69 | # Learning rate scheduler 70 | # First define linearly decaying functions (decay starts at a special iter) 71 | start_decay = conf.lr_start_decay_iter 72 | end_decay = conf.max_iters 73 | # def lr_function(n_iter): 74 | # return 1 - max(0, 1.0 * (n_iter - start_decay) / (conf.max_iters - start_decay)) 75 | lr_function = LRPolicy(start_decay, end_decay) 76 | # Define learning rate schedulers 77 | self.lr_scheduler_G = 
torch.optim.lr_scheduler.LambdaLR(self.optimizer_G, lr_function) 78 | self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D, lr_function) 79 | 80 | # # do we resume from checkpoint? 81 | # if self.conf.resume: 82 | # print('resuming checkpoint {}'.format(self.conf.resume)) 83 | # self.resume(self.conf.resume) 84 | 85 | def save(self, citer=None): 86 | if citer is None: 87 | filename = 'snapshot.pth.tar' 88 | elif isinstance(citer, str): 89 | filename = citer 90 | else: 91 | filename = 'snapshot-{:05d}.pth.tar'.format(citer) 92 | torch.save({'G': self.G.state_dict(), 93 | 'D': self.D.state_dict(), 94 | 'optim_G': self.optimizer_G.state_dict(), 95 | 'optim_D': self.optimizer_D.state_dict(), 96 | 'sched_G': self.lr_scheduler_G.state_dict(), 97 | 'sched_D': self.lr_scheduler_D.state_dict(), 98 | 'loss': self.GAN_loss_layer.state_dict(), 99 | 'iter': citer if citer else self.cur_iter}, 100 | os.path.join(self.conf.output_dir_path, filename)) 101 | 102 | def resume(self, resume_path, test_flag=False): 103 | resume = torch.load(resume_path, map_location={'cuda:5': 'cuda:0'}) 104 | missing = [] 105 | if 'G' in resume: 106 | self.G.load_state_dict(resume['G']) 107 | else: 108 | missing.append('G') 109 | if 'D' in resume: 110 | self.D.load_state_dict(resume['D']) 111 | else: 112 | missing.append('D') 113 | if not test_flag: 114 | if 'optim_G' in resume: 115 | self.optimizer_G.load_state_dict(resume['optim_G']) 116 | else: 117 | missing.append('optimizer G') 118 | if 'optim_D' in resume: 119 | self.optimizer_D.load_state_dict(resume['optim_D']) 120 | else: 121 | missing.append('optimizer D') 122 | if 'sched_G' in resume: 123 | self.lr_scheduler_G.load_state_dict(resume['sched_G']) 124 | else: 125 | missing.append('lr scheduler G') 126 | if 'sched_D' in resume: 127 | self.lr_scheduler_D.load_state_dict(resume['sched_D']) 128 | else: 129 | missing.append('lr scheduler G') 130 | if 'loss' in resume: 131 | self.GAN_loss_layer.load_state_dict(resume['loss']) 132 | else: 133 | missing.append('GAN loss') 134 | if len(missing): 135 | warnings.warn('Missing the following state dicts from checkpoint: {}'.format(', '.join(missing))) 136 | 137 | print('resuming checkpoint {}'.format(self.conf.resume)) 138 | 139 | def test(self, input_tensor, output_size, rand_affine, input_size, run_d_pred=True, run_reconstruct=True): 140 | with torch.no_grad(): 141 | self.G_pred = self.G.forward(Variable(input_tensor.detach()), output_size=output_size, random_affine=rand_affine) 142 | if run_d_pred: 143 | scale_weights_for_output = get_scale_weights(i=self.cur_iter, 144 | max_i=self.conf.D_scale_weights_iter_for_even_scales, 145 | start_factor=self.conf.D_scale_weights_sigma, 146 | input_shape=self.G_pred.shape[2:], 147 | min_size=self.conf.D_min_input_size, 148 | num_scales_limit=self.conf.D_max_num_scales, 149 | scale_factor=self.conf.D_scale_factor) 150 | scale_weights_for_input = get_scale_weights(i=self.cur_iter, 151 | max_i=self.conf.D_scale_weights_iter_for_even_scales, 152 | start_factor=self.conf.D_scale_weights_sigma, 153 | input_shape=input_tensor.shape[2:], 154 | min_size=self.conf.D_min_input_size, 155 | num_scales_limit=self.conf.D_max_num_scales, 156 | scale_factor=self.conf.D_scale_factor) 157 | self.D_preds = [self.D.forward(Variable(input_tensor.detach()), scale_weights_for_input), 158 | self.D.forward(Variable(self.G_pred.detach()), scale_weights_for_output)] 159 | else: 160 | self.D_preds = None 161 | 162 | self.G_preds = [input_tensor, self.G_pred] 163 | 164 | self.reconstruct = 
self.G.forward(self.G_pred, output_size=input_size, random_affine=-rand_affine) if run_reconstruct else None 165 | 166 | return self.G_preds, self.D_preds, self.reconstruct 167 | 168 | def train_g(self): 169 | # Zeroize gradients 170 | self.optimizer_G.zero_grad() 171 | self.optimizer_D.zero_grad() 172 | 173 | # Determine output size of G (dynamic change) 174 | output_size, random_affine = random_size(orig_size=self.input_tensor.shape[2:], 175 | curriculum=self.conf.curriculum, 176 | i=self.cur_iter, 177 | iter_for_max_range=self.conf.iter_for_max_range, 178 | must_divide=self.conf.must_divide, 179 | min_scale=self.conf.min_scale, 180 | max_scale=self.conf.max_scale, 181 | max_transform_magniutude=self.conf.max_transform_magnitude) 182 | 183 | # Add noise to G input for better generalization (make it ignore the 1/255 binning) 184 | self.input_tensor_noised = self.input_tensor + (torch.rand_like(self.input_tensor) - 0.5) * 2.0 / 255 185 | 186 | # Generator forward pass 187 | self.G_pred = self.G.forward(self.input_tensor_noised, output_size=output_size, random_affine=random_affine) 188 | 189 | # Run generator result through discriminator forward pass 190 | self.scale_weights = get_scale_weights(i=self.cur_iter, 191 | max_i=self.conf.D_scale_weights_iter_for_even_scales, 192 | start_factor=self.conf.D_scale_weights_sigma, 193 | input_shape=self.G_pred.shape[2:], 194 | min_size=self.conf.D_min_input_size, 195 | num_scales_limit=self.conf.D_max_num_scales, 196 | scale_factor=self.conf.D_scale_factor) 197 | d_pred_fake = self.D.forward(self.G_pred, self.scale_weights) 198 | 199 | # If reconstruction-loss is used, run through decoder to reconstruct, then calculate reconstruction loss 200 | if self.conf.reconstruct_loss_stop_iter > self.cur_iter: 201 | self.reconstruct = self.G.forward(self.G_pred, output_size=self.input_tensor.shape[2:], random_affine=-random_affine) 202 | self.loss_G_reconstruct = self.criterionReconstruction(self.reconstruct, self.input_tensor, self.loss_mask) 203 | 204 | # Calculate generator loss, based on discriminator prediction on generator result 205 | self.loss_G_GAN = self.criterionGAN(d_pred_fake, is_d_input_real=True) 206 | 207 | # Generator final loss 208 | # Weighted average of the two losses (if indicated to use reconstruction loss) 209 | if self.conf.reconstruct_loss_stop_iter < self.cur_iter: 210 | self.loss_G = self.loss_G_GAN 211 | else: 212 | self.loss_G = (self.conf.reconstruct_loss_proportion * self.loss_G_reconstruct + self.loss_G_GAN) 213 | 214 | # Calculate gradients 215 | # Note that the gradients are propagated from the loss through discriminator and then through generator 216 | self.loss_G.backward() 217 | 218 | # Update weights 219 | # Note that only generator weights are updated (by definition of the G optimizer) 220 | self.optimizer_G.step() 221 | 222 | # Extra training for the inverse G. The difference between this and the reconstruction is the .detach() which 223 | # makes the training only for the inverse G and not for regular G. 
224 | if self.cur_iter > self.conf.G_extra_inverse_train_start_iter: 225 | for _ in range(self.conf.G_extra_inverse_train): 226 | self.optimizer_G.zero_grad() 227 | self.inverse = self.G.forward(self.G_pred.detach(), output_size=self.input_tensor.shape[2:], random_affine=-random_affine) 228 | self.loss_G_inverse = (self.criterionReconstruction(self.inverse, self.input_tensor, self.loss_mask) * 229 | self.conf.G_extra_inverse_train_ratio) 230 | self.loss_G_inverse.backward() 231 | self.optimizer_G.step() 232 | 233 | # Update learning rate scheduler 234 | self.lr_scheduler_G.step() 235 | 236 | def train_d(self): 237 | # Zeroize gradients 238 | self.optimizer_D.zero_grad() 239 | 240 | # Adding noise to D input to prevent overfitting to 1/255 bins 241 | real_example_with_noise = self.real_example + (torch.rand_like(self.real_example[-1]) - 0.5) * 2.0 / 255.0 242 | 243 | # Discriminator forward pass over real example 244 | self.d_pred_real = self.D.forward(real_example_with_noise, self.scale_weights) 245 | 246 | # Adding noise to D input to prevent overfitting to 1/255 bins 247 | # Note that generator result is detached so that gradients are not propagating back through generator 248 | g_pred_with_noise = self.G_pred.detach() + (torch.rand_like(self.G_pred) - 0.5) * 2.0 / 255 249 | 250 | # Discriminator forward pass over generated example example 251 | self.d_pred_fake = self.D.forward(g_pred_with_noise, self.scale_weights) 252 | 253 | # Calculate discriminator loss 254 | self.loss_D_fake = self.criterionGAN(self.d_pred_fake, is_d_input_real=False) 255 | self.loss_D_real = self.criterionGAN(self.d_pred_real, is_d_input_real=True) 256 | self.loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5 257 | 258 | # Calculate gradients 259 | # Note that gradients are not propagating back through generator 260 | # noinspection PyUnresolvedReferences 261 | self.loss_D.backward() 262 | 263 | # Update weights 264 | # Note that only discriminator weights are updated (by definition of the D optimizer) 265 | self.optimizer_D.step() 266 | 267 | # Update learning rate scheduler 268 | self.lr_scheduler_D.step() 269 | 270 | def train_one_iter(self, cur_iter, input_tensors): 271 | # Set inputs as random crops 272 | input_crops = [] 273 | mask_crops = [] 274 | real_example_crops = [] 275 | mask_flag = False 276 | for input_tensor in input_tensors: 277 | real_example_crops += self.RandCrop.forward([input_tensor]) 278 | 279 | if np.random.rand() < self.conf.crop_swap_probability: 280 | swapped_input_tensor, loss_mask = self.SwapCrops.forward(input_tensor) 281 | [input_crop, mask_crop] = self.RandCrop.forward([swapped_input_tensor, loss_mask]) 282 | input_crops.append(input_crop) 283 | mask_crops.append(mask_crop) 284 | mask_flag = True 285 | else: 286 | input_crops.append(real_example_crops[-1]) 287 | 288 | self.input_tensor = torch.cat(input_crops) 289 | self.real_example = torch.cat(real_example_crops) 290 | self.loss_mask = torch.cat(mask_crops) if mask_flag else None 291 | 292 | # Update current iteration 293 | self.cur_iter = cur_iter 294 | 295 | # Run a single forward-backward pass on the model and update weights 296 | # One global iteration includes several iterations of generator and several of discriminator 297 | # (not necessarily equal) 298 | # noinspection PyRedeclaration 299 | for _ in range(self.conf.G_iters): 300 | self.train_g() 301 | 302 | # noinspection PyRedeclaration 303 | for _ in range(self.conf.D_iters): 304 | self.train_d() 305 | 306 | # Accumulate stats 307 | # Accumulating as cuda tensors is 
much more efficient than passing info from GPU to CPU at every iteration 308 | self.losses_G_gan[cur_iter % self.conf.print_freq] = self.loss_G_GAN.item() 309 | self.losses_D_fake[cur_iter % self.conf.print_freq] = self.loss_D_fake.item() 310 | self.losses_D_real[cur_iter % self.conf.print_freq] = self.loss_D_real.item() 311 | if self.conf.reconstruct_loss_stop_iter > self.cur_iter: 312 | self.losses_G_reconstruct[cur_iter % self.conf.print_freq] = self.loss_G_reconstruct.item() 313 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The Weizmann Institute of Science 2 | Academic Non Commercial Software Code License 3 | [InGAN- "InGAN: Capturing and Retargeting the DNA of a Natural Image"] (the "Work") 4 | © 2018 The Weizmann Institute of Science ("WIS") and Yeda Research and Development Company Ltd. ("Yeda") All Rights Reserved 5 | 6 | 1. YEDA, the commercial arm of WIS, hereby grants you, an individual or a legal entity exercising rights under, and complying with all of the provisions, of this License (“You”) a royalty-free, non-exclusive, sublicensable, worldwide license to: use, copy, modify, create derivative works (including without limiting to: adapt, alter, transform), integrate with other works, distribute, enable access (including without limiting to: communicate copies), publicly display and perform the Work in binary form or in source code, for academic and noncommercial use only and subject to all provisions of this License: 7 | 2. YEDA hereby grants You a royalty-free, non-exclusive, sublicensable, worldwide license under patents claimed or owned by YEDA that are embodied in the Work, to make, have made and use the Work under the License, for avoidance of doubt for academic and noncommercial use only. 8 | 3. Distribution or provision of access to the Work and to derivative works of the Work ("Derivative Works") may be made only under this License, accompanied with a copy of the source code or a reference to an online repository where such source code can be accessed. 9 | 4. Neither the names of WIS or Yeda, nor any of their trademarks or service marks, may be used to endorse or promote Derivative Works or for any other purpose except as expressly permitted hereunder. 10 | 5. Except as expressly stated in this License, nothing in this License grants any license to trademarks, copyrights, patents, trade secrets or any other intellectual property of WIS or Yeda. No license is granted to the trademarks of WIS or Yeda's even if such marks are included in the Work. 11 | 6. Nothing in this License shall be interpreted to prohibit WIS or Yeda from licensing the Work under terms different from this License. For commercial use please e-mail Yeda at: info.yeda@weizmann.ac.il 12 | 7. You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Work, as well as a notice to inform recipients that You have modified the Work with a description of such modifications. 13 | 8. THE WORK IS PROVIDED "AS IS" AND WITHOUT ANY WARRANTIES WHATSOEVER, EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION ANY WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 14 | 9. 
IN NO EVENT WILL WIS, YEDA OR ANY OF THEIR RELATED ENTITIES, SCIENTISTS, EMPLOYEES, MANAGERS OR ANY OTHER PERSON ACTING ON THEIR BEHALF, BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY OR CAUSE OF ACTION, WHETHER IN CONTRACT, TORT, STRICT LIABILITY, UNJUST ENRICHMENT OR ANY OTHER, ARISING IN ANY WAY OUT OF THE USE OF THE WORK OR THIS LICENSE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15 | 10. This License will terminate automatically if any of its conditions is not met, or in case You commence an action, including a cross-claim or counterclaim, against WIS or YEDA or any licensee alleging that the Work (except due to combination with other software or hardware) infringes a patent.
16 | 11. This License shall be exclusively governed by the laws of the State of Israel, without giving effect to conflict of laws principles, and the competent courts in Tel Aviv will have exclusive jurisdiction and venue over any matter between You and WIS or YEDA or any of their related entities relating to this License or the Work.
17 | 12. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable.
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # InGAN
2 | ### Official code for the paper "InGAN: Capturing and Retargeting the DNA of a Natural Image"
3 | 
4 | Project page: http://www.wisdom.weizmann.ac.il/~vision/ingan/ (See our results and visual comparison to other methods)
5 | 
6 | **Accepted ICCV'19 (Oral)**
7 | ----------
8 | ![](/figs/fruits.gif)
9 | ----------
10 | If you find our work useful in your research or publication, please cite our work:
11 | 
12 | ```
13 | @InProceedings{InGAN,
14 | author = {Assaf Shocher and Shai Bagon and Phillip Isola and Michal Irani},
15 | title = {InGAN: Capturing and Retargeting the "DNA" of a Natural Image},
16 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
17 | year = {2019}
18 | }
19 | ```
20 | ----------
21 | 
22 | # Usage:
23 | ## Test
24 | ### Quick example
25 | First you have to [download the example checkpoint file](http://www.wisdom.weizmann.ac.il/~vision/ingan/resources/checkpoint_0075000.pth.tar) and put it in ``` InGAN/examples/fruit/ ```.
26 | By default, this runs on the fruits image, using an existing checkpoint.
27 | ```
28 | python test.py
29 | ```
30 | 
31 | ### General testing
32 | By default, testing produces a collage of various sizes and a smooth video of the transforms. You can also choose to test specific sizes, non-rectangular transforms and more.
33 | 
34 | See configs.py for all the options. You can either edit this file or modify the configuration from the command line.
35 | Examples:
36 | ```
37 | python test.py --input_image_path /path/to/some/image.png  # choose input image
38 | python test.py --test_non_rect  # also output non-rectangular transformation results
39 | python test.py --test_vid_scales 2.0 0.5 2.5 0.2  # boundary scales for output video: [max_v, min_v, max_h, min_h]
40 | ```
41 | Please see configs.py for many more options
42 | 
43 | 
44 | ## Train
45 | ### Quick example
46 | By default, this runs on the fruits image.
47 | ```
48 | python train.py
49 | ```
50 | ### General training
51 | See configs.py for all the options. You can either edit this file or modify the configuration from the command line.
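Several flags can be combined in one command; for instance, the run length, the learning rates and the experiment name (all defined in configs.py, with defaults of 75,000 iterations and 5e-5 learning rates) can be overridden together. The values below are illustrative only:
```
python train.py --max_iters 50000 --g_lr 0.0001 --d_lr 0.0001 --name my_experiment
```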
52 | Examples: 53 | ``` 54 | python train.py --input_image_path /path/to/some/image.png # choose input image 55 | python train.py --G_num_resblocks 3 # change number of residual block in the generator 56 | ``` 57 | Please see configs.py for many more options 58 | ### monitoring 59 | In you results folder, monitor files will be periodically created, example: 60 | ![](/figs/monitor_60000.png) 61 | 62 | ## Produce complex animations by scripts: 63 | Please see the file supp_video.py 64 | 65 | ## Parallel training for many images 66 | Please see the file train_supp_mat.py 67 | -------------------------------------------------------------------------------- /SceneScripts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def make_scene_script(script_name, min_v, max_v, min_h, max_h, max_t, repeat, show_input=True, frames_per_resize=10): 5 | l = np.linspace 6 | 7 | if script_name == 'vertical_grow_shrink': 8 | size_v = np.concatenate([ 9 | l(1, max_v, frames_per_resize), 10 | l(max_v, min_v, 2 * frames_per_resize), 11 | l(min_v, 1, frames_per_resize)]) 12 | size_h = np.concatenate([ 13 | l(1, 1, frames_per_resize), 14 | l(1, 1, 2 * frames_per_resize), 15 | l(1, 1, frames_per_resize)]) 16 | shift_l = [0 for _ in size_v] 17 | shift_r = [0 for _ in size_v] 18 | 19 | elif script_name == 'horizontal_grow_shrink': 20 | size_v = np.concatenate([ 21 | l(1, 1, frames_per_resize), 22 | l(1, 1, 2 * frames_per_resize), 23 | l(1, 1, frames_per_resize)]) 24 | size_h = np.concatenate([ 25 | l(1, max_h, frames_per_resize), 26 | l(max_h, min_h, 2 * frames_per_resize), 27 | l(min_h, 1, frames_per_resize)]) 28 | shift_l = [0 for _ in size_v] 29 | shift_r = [0 for _ in size_v] 30 | 31 | elif script_name == 'horizontal_grow_shrink_slow': 32 | size_v = np.concatenate([ 33 | l(1, 1, 2 *frames_per_resize), 34 | l(1, 1, 2 * frames_per_resize), 35 | l(1, 1, frames_per_resize)]) 36 | size_h = np.concatenate([ 37 | l(1, max_h, 2 * frames_per_resize), 38 | l(max_h, min_h, 2 * frames_per_resize), 39 | l(min_h, 1, frames_per_resize)]) 40 | shift_l = [0 for _ in size_v] 41 | shift_r = [0 for _ in size_v] 42 | 43 | elif script_name == '2d_grow_shrink': 44 | size_v = np.concatenate([ 45 | l(1, max_v, frames_per_resize), 46 | l(max_v, min_v, 2 * frames_per_resize), 47 | l(min_v, 1, frames_per_resize)]) 48 | size_h = np.concatenate([ 49 | l(1, max_h, frames_per_resize), 50 | l(max_h, min_h, 2 * frames_per_resize), 51 | l(min_h, 1, frames_per_resize)]) 52 | shift_l = [0 for _ in size_v] 53 | shift_r = [0 for _ in size_v] 54 | 55 | elif script_name == 'resize_round': 56 | size_v = np.concatenate([ 57 | l(1, 1, frames_per_resize), 58 | l(1, max_v, frames_per_resize), 59 | l(max_v, max_v, 2 * frames_per_resize), 60 | l(max_v, min_v, 2 * frames_per_resize), 61 | l(min_v, 1, frames_per_resize)]) 62 | size_h = np.concatenate([ 63 | l(1, max_h, frames_per_resize), 64 | l(max_h, max_h, frames_per_resize), 65 | l(max_h, min_h, 2 * frames_per_resize), 66 | l(min_h, min_h, 2 * frames_per_resize), 67 | l(min_h, 1, frames_per_resize)]) 68 | shift_l = [0 for _ in size_v] 69 | shift_r = [0 for _ in size_v] 70 | 71 | elif script_name == 'special_resize_round': 72 | size_v = np.concatenate([ 73 | l(1, 1, frames_per_resize/2), 74 | l(1, max_v, frames_per_resize), 75 | l(max_v, max_v, frames_per_resize), 76 | l(max_v, max_v, 2 * frames_per_resize), 77 | l(max_v, min_v, 2 * frames_per_resize), 78 | l(min_v, 1, frames_per_resize)]) 79 | 80 | size_h = np.concatenate([ 81 | l(1, 
max_h/2, frames_per_resize/2), 82 | l(max_h/2, max_h/2, frames_per_resize), 83 | l(max_h/2, max_h, frames_per_resize), 84 | l(max_h, min_h, 2 * frames_per_resize), 85 | l(min_h, min_h, 2 * frames_per_resize), 86 | l(min_h, 1, frames_per_resize)]) 87 | shift_l = [0 for _ in size_v] 88 | shift_r = [0 for _ in size_v] 89 | 90 | elif script_name == 'special_zoom': 91 | size_v = np.concatenate([ 92 | l(1, max_v, frames_per_resize), 93 | l(max_v, min_v, frames_per_resize), 94 | l(min_v, 1, frames_per_resize)]) 95 | size_h = np.concatenate([ 96 | l(1, max_v, frames_per_resize), 97 | l(max_v, min_v, frames_per_resize), 98 | l(min_v, 1, frames_per_resize)]) 99 | shift_l = [0 for _ in size_v] 100 | shift_r = [0 for _ in size_v] 101 | 102 | elif script_name == 'affine_dance': 103 | shift_l = np.concatenate([ 104 | l(0, max_t, frames_per_resize), 105 | l(max_t, - max_t, 2 * frames_per_resize), 106 | l(- max_t, 0, frames_per_resize)]) 107 | shift_r = np.concatenate([ 108 | l(0, - max_t, frames_per_resize), 109 | l(- max_t, max_t, 2 * frames_per_resize), 110 | l(max_t, 0, frames_per_resize)]) 111 | size_v = [1for _ in shift_l] 112 | size_h = [1 for _ in shift_l] 113 | 114 | elif script_name == 'trapezoids': 115 | shift_l = np.concatenate([ 116 | l(0, max_t, frames_per_resize), 117 | l(max_t, - max_t, 2 * frames_per_resize), 118 | l(- max_t, max_t, 2 * frames_per_resize), 119 | l(max_t, 0, frames_per_resize)]) 120 | shift_r = np.concatenate([ 121 | l(0, max_t, frames_per_resize), 122 | l(max_t, - max_t, 2 * frames_per_resize), 123 | l(- max_t, max_t, 2 * frames_per_resize), 124 | l(max_t, 0, frames_per_resize)]) 125 | size_v = [1for _ in shift_l] 126 | size_h = [1 for _ in shift_l] 127 | 128 | elif script_name == 'trapezoids_vresize': 129 | shift_l = np.concatenate([ 130 | l(0, max_t, frames_per_resize), 131 | l(max_t, - max_t, 2 * frames_per_resize), 132 | l(- max_t, max_t, 2 * frames_per_resize), 133 | l(max_t, 0, frames_per_resize)]) 134 | shift_r = np.concatenate([ 135 | l(0, max_t, frames_per_resize), 136 | l(max_t, - max_t, 2 * frames_per_resize), 137 | l(- max_t, max_t, 2 * frames_per_resize), 138 | l(max_t, 0, frames_per_resize)]) 139 | size_v = np.concatenate([ 140 | l(1, max_v, frames_per_resize), 141 | l(max_v, 1, frames_per_resize), 142 | l(1, max_v, frames_per_resize), 143 | l(max_v, 1, frames_per_resize), 144 | l(1, max_v, frames_per_resize), 145 | l(max_v, 1, frames_per_resize), 146 | ]) 147 | size_h = np.concatenate([ 148 | l(1, 1, 6*frames_per_resize)]) 149 | 150 | elif script_name == 'flicker': 151 | size_h = np.concatenate([ 152 | l(1, 1, 6 * frames_per_resize)]) 153 | size_v = size_h 154 | shift_l = np.concatenate([ 155 | l(max_t, max_t, frames_per_resize), 156 | l(-max_t, -max_t, frames_per_resize), 157 | l(max_t, max_t, frames_per_resize), 158 | l(-max_t, -max_t, frames_per_resize), 159 | l(max_t, max_t, frames_per_resize), 160 | l(-max_t, -max_t, frames_per_resize),]) 161 | shift_r = np.concatenate([ 162 | l(-max_t, -max_t, frames_per_resize), 163 | l(max_t, max_t, frames_per_resize), 164 | l(-max_t, -max_t, frames_per_resize), 165 | l(max_t, max_t, frames_per_resize), 166 | l(-max_t, -max_t, frames_per_resize), 167 | l(max_t, max_t, frames_per_resize)]) 168 | 169 | elif script_name == 'homography': 170 | size_h = np.concatenate([ 171 | l(1, 1, 6 * frames_per_resize)]) 172 | size_v = size_h 173 | shift_l = np.concatenate([ 174 | l(0, max_t, frames_per_resize), 175 | l(max_t, max_t, frames_per_resize), 176 | l(max_t, - max_t, 2 * frames_per_resize), 177 | l(- max_t, - max_t, 2 * 
frames_per_resize), 178 | l(- max_t, 0, frames_per_resize)]) 179 | shift_r = np.concatenate([ 180 | l(0, 0, frames_per_resize), 181 | l(0, max_t, frames_per_resize), 182 | l(max_t, max_t, 2 * frames_per_resize), 183 | l(max_t, - max_t, 2 * frames_per_resize), 184 | l(- max_t, 0, frames_per_resize)]) 185 | 186 | 187 | 188 | elif script_name == 'random': 189 | stops = np.random.rand(10, 4) * np.array([max_v-min_v, max_h-min_h, 2*max_t, 2*max_t])[None, :] + np.array([min_v, min_h, -max_t, -max_t])[None, :] 190 | stops = np.vstack([stops, [1, 1, 0, 0]]) 191 | print stops 192 | 193 | size_v = np.concatenate([l(stop_0[0], stop_1[0], frames_per_resize) 194 | for stop_0, stop_1 in zip(np.vstack(([1, 1, 0, 0], stops)), stops)]) 195 | 196 | size_h = np.concatenate([l(stop_0[1], stop_1[1], frames_per_resize) 197 | for stop_0, stop_1 in zip(np.vstack(([1, 1, 0, 0], stops)), stops)]) 198 | 199 | shift_l = np.concatenate([l(stop_0[2], stop_1[2], frames_per_resize) 200 | for stop_0, stop_1 in zip(np.vstack(([1, 1, 0, 0], stops)), stops)]) 201 | 202 | shift_r = np.concatenate([l(stop_0[3], stop_1[3], frames_per_resize) 203 | for stop_0, stop_1 in zip(np.vstack(([1, 1, 0, 0], stops)), stops)]) 204 | 205 | elif script_name == 'random_trapezoids': 206 | stops_l = np.random.rand(11) * 2 * max_t - max_t 207 | stops_l[-1] = 0 208 | stops_r = np.random.rand(11) * max_t * (stops_l / np.abs(stops_l)) 209 | stops = zip(stops_l, stops_r) 210 | print stops 211 | 212 | size_h = np.concatenate([ 213 | l(1, 1, 20 * frames_per_resize)]) 214 | size_v = size_h 215 | 216 | shift_l = np.concatenate([l(stop_0[0], stop_1[0], frames_per_resize) 217 | for stop_0, stop_1 in zip(np.vstack(([0, 0], stops)), stops)]) 218 | 219 | shift_r = np.concatenate([l(stop_0[1], stop_1[1], frames_per_resize) 220 | for stop_0, stop_1 in zip(np.vstack(([0, 0], stops)), stops)]) 221 | 222 | 223 | return [[-1, -1, -1, -1]] * 20 + zip(size_v, size_h, shift_l, shift_r) * repeat if show_input else zip(size_v, size_h, shift_l, shift_r) * repeat 224 | 225 | 226 | INPUT_DICT = { 227 | 'fruits': ['fruits_ss.png', '/experiment_old_code_with_homo_2/results/fruits_ss_geo_new_pad_Mar_16_18_00_17/checkpoint_0075000.pth.tar'], 228 | 'farm_house': ['farm_house_s.png', '/results/farm_house_s_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_Nov_03_16_07_59/checkpoint_0050000.pth.tar'], 229 | 'cab_building': ['cab_building_s.png', '/results/cab_building_s_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_Nov_03_18_10_25/checkpoint_0065000.pth.tar'], 230 | 'capitol': ['capitol.png', '/results/capitol_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_Nov_03_18_13_22/checkpoint_0055000.pth.tar'], 231 | 'rome': ['rome_s.png', '/results/rome_s_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_Nov_03_18_09_19/checkpoint_0045000.pth.tar'], 232 | 'soldiers': ['china_soldiers.png', '/results/china_soldiers_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_NOISE2G_Nov_05_09_46_09/checkpoint_0075000.pth.tar'], 233 | 'corn': ['corn.png', '/results/corn_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_NOISE2G_Nov_05_10_29_00/checkpoint_0075000.pth.tar'], 234 | 'sushi': ['sushi.png', '/results/sushi_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_NOISE2G_Nov_05_07_47_39/checkpoint_0075000.pth.tar'], 235 | 'penguins': ['penguins.png', '/results/penguins_Nov_13_16_26_14/checkpoint_0075000.pth.tar'], 236 | 'emojis': ['emojis3.png', '/results/emojis3_Nov_23_09_59_59/checkpoint_0075000.pth.tar'], 237 | 'fish': ['input/fish.png', 
'/results/fish_plethora_75_Mar_18_03_36_25/checkpoint_0075000.pth.tar'], 238 | 'ny': ['textures/ny.png', '/results/ny_texture_synth_Mar_19_04_51_14/checkpoint_0075000.pth.tar'], 239 | 'metal_circles': ['metal_circles.jpg', '/results/metal_circles_Mar_26_20_04_11/checkpoint_0075000.pth.tar'], 240 | 'quilt': ['quilt.png', '/results/quilt/checkpoint_0075000.pth.tar'], 241 | 'sapa': ['sapa.png', '/results/sapa_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_NOISE2G_Nov_05_09_44_59/checkpoint_0075000.pth.tar'], 242 | 'nkorea': ['nkorea.png', '/results/nkorea_L1_Dfactor_14_WeightsEqualeThenFine_25_LRdecay_20_curric_NOISE2G_Nov_05_07_48_00/checkpoint_0075000.pth.tar'], 243 | 'wood': ['wood.png', '/results/wood/checkpoint_0075000.pth.tar'], 244 | 'starry': ['starry.png', '/results/starry/checkpoint_0075000.pth.tar'], 245 | 'umbrella': ['umbrella.png', '/results/umbrella/checkpoint_0075000.pth.tar'], 246 | 'fruits_old': ['fruits_ss.png', '/results/fruits_ss_256_COARSE2FINE_extraInv_2_30_until60_killReconstruct_20_Oct_24_12_35_33/checkpoint_0040000.pth.tar'], 247 | 'peacock': ['scaled_nird/ours_1_scaled.jpg', '/results/ours_1/checkpoint_0050000.pth.tar'], 248 | 'windows': ['scaled_nird/ours_2_scaled.jpg', '/results/ours_2/checkpoint_0050000.pth.tar'], 249 | 'light_house': ['scaled_nird/ours_23_scaled.jpg', '/results/ours_23/checkpoint_0050000.pth.tar'], 250 | 'hats': ['scaled_nird/ours_26_scaled.jpg', '/results/ours_26/checkpoint_0050000.pth.tar'], 251 | 'nature': ['scaled_nird/ours_32_scaled.jpg', '/results/ours_32/checkpoint_0050000.pth.tar'], 252 | 253 | } 254 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | from util import prepare_result_dir 5 | 6 | 7 | # noinspection PyPep8 8 | class Config: 9 | def __init__(self): 10 | self.parser = argparse.ArgumentParser() 11 | self.conf = None 12 | 13 | # Paths 14 | self.parser.add_argument('--input_image_path', default=[os.path.dirname(os.path.abspath(__file__)) + '/examples/fruit/fruit.png'], nargs='+', help='path to one specific image file') 15 | self.parser.add_argument('--output_dir_path', default=os.path.dirname(os.path.abspath(__file__)) + '/results', help='path to a directory to save results to') 16 | self.parser.add_argument('--name', default='fruit', help='name of current experiment, to be used for saving the results') 17 | self.parser.add_argument('--resume', type=str, default=None, help='checkpoint to resume from') 18 | self.parser.add_argument('--test_params_path', type=str, default=os.path.dirname(os.path.abspath(__file__)) + '/examples/fruit/checkpoint_0075000.pth.tar', help='checkpoint for testing') 19 | 20 | # Test 21 | self.parser.add_argument('--test_collage', default=True, action='store_true', help='Create collage in test?') 22 | self.parser.add_argument('--test_video', default=True, action='store_true', help='Create retarget-video in test?') 23 | self.parser.add_argument('--test_non_rect', default=False, action='store_true', help='Produce non-rectangular transformations in test?') 24 | self.parser.add_argument('--test_vid_scales', type=float, default=[2.2, 0.1, 2.2, 0.1], nargs='+', help='boundary scales for output video: [max_v, min_v, max_h, min_h]') 25 | self.parser.add_argument('--collage_scales', type=float, default=[2.0, 1.25, 1.0, 0.66, 0.33], nargs='+', help='scales for collage (h=w, only one number)') 26 | 
self.parser.add_argument('--collage_input_spot', type=float, default=[2, 2], nargs='+', help='replaces one spot in the collage with original input. must match a spot with scale 1.0')
27 | self.parser.add_argument('--non_rect_shift_range', type=float, default=[-0.8, 1.0, 0.2], nargs='+', help='range for homography shifts for non rect transforms [min, max, step]')
28 | self.parser.add_argument('--non_rect_scales', type=float, default=[0.7, 1.0], nargs='+', help='list of scales for non_rect outputs')
29 | 
30 | # Architecture (Generator)
31 | self.parser.add_argument('--G_base_channels', type=int, default=64, help='# of base channels in G')
32 | self.parser.add_argument('--G_num_resblocks', type=int, default=6, help='# of resblocks in G\'s bottleneck')
33 | self.parser.add_argument('--G_num_downscales', type=int, default=3, help='# of downscaling layers in G')
34 | self.parser.add_argument('--G_use_bias', type=bool, default=True, help='Determines whether bias is used in G\'s conv layers')
35 | self.parser.add_argument('--G_skip', type=bool, default=True, help='Determines whether G uses skip connections (U-net)')
36 | 
37 | # Architecture (Discriminator)
38 | self.parser.add_argument('--D_base_channels', type=int, default=64, help='# of base channels in D')
39 | self.parser.add_argument('--D_max_num_scales', type=int, default=99, help='Limits the # of scales for the multiscale D')
40 | self.parser.add_argument('--D_scale_factor', type=float, default=1.4, help='Determines the downscaling factor for multiscale D')
41 | self.parser.add_argument('--D_scale_weights_sigma', type=float, default=1.4, help='Initial sigma for the scale-weighting of the multiscale D')
42 | self.parser.add_argument('--D_min_input_size', type=int, default=13, help='Minimal input size for the smallest scale of the multiscale D')
43 | self.parser.add_argument('--D_scale_weights_iter_for_even_scales', type=int, default=25000, help='Iteration by which the multiscale D scale-weights become even')
44 | 
45 | # Optimization hyper-parameters
46 | self.parser.add_argument('--g_lr', type=float, default=0.00005, help='initial learning rate for generator')
47 | self.parser.add_argument('--d_lr', type=float, default=0.00005, help='initial learning rate for discriminator')
48 | self.parser.add_argument('--lr_start_decay_iter', type=float, default=20000, help='iteration from which linear decay of lr starts until max_iter')
49 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
50 | self.parser.add_argument('--curriculum', type=bool, default=True, help='Enable curriculum learning')
51 | self.parser.add_argument('--iter_for_max_range', type=int, default=10000, help='In curriculum learning, when getting to this iteration all range is covered')
52 | 
53 | # Sizes
54 | self.parser.add_argument('--input_crop_size', type=int, default=256, help='input is cropped to this size')
55 | self.parser.add_argument('--output_crop_size', type=int, default=256, help='output is cropped to this size')
56 | self.parser.add_argument('--max_scale', type=float, default=2.25, help='max retargeting scale')
57 | self.parser.add_argument('--min_scale', type=float, default=0.15, help='min retargeting scale')
58 | self.parser.add_argument('--must_divide', type=int, default=8, help='all generated output sizes must be divisible by this number')
59 | self.parser.add_argument('--max_transform_magnitude', type=float, default=0.0, help='max magnitude of geometric transformation')
60 | 
61 | # Crop Swap
62 | self.parser.add_argument('--crop_swap_min_size', type=int, default=32, help='min crop size for the crop-swapping augmentation')
63 | self.parser.add_argument('--crop_swap_max_size', type=int, default=256, help='max crop size for the crop-swapping augmentation')
64 | self.parser.add_argument('--crop_swap_probability', type=float, default=0.0, help='probability for crop swapping to occur')
65 | 
66 | # GPU
67 | self.parser.add_argument('--gpu_id', type=int, default=0, help='gpu id number')
68 | 
69 | # Monitoring display frequencies
70 | self.parser.add_argument('--display_freq', type=int, default=200, help='frequency of showing training results on screen')
71 | self.parser.add_argument('--print_freq', type=int, default=20, help='frequency of showing training results on console')
72 | self.parser.add_argument('--save_snapshot_freq', type=int, default=5000, help='frequency of saving the latest results')
73 | 
74 | # Iterations
75 | self.parser.add_argument('--max_iters', type=int, default=75000, help='max # of iters')
76 | self.parser.add_argument('--G_iters', type=int, default=1, help='# of sub-iters for the generator per each global iteration')
77 | self.parser.add_argument('--D_iters', type=int, default=1, help='# of sub-iters for the discriminator per each global iteration')
78 | 
79 | # Losses
80 | self.parser.add_argument('--reconstruct_loss_proportion', type=float, default=0.1, help='relative part of reconstruct-loss (out of 1)')
81 | self.parser.add_argument('--reconstruct_loss_stop_iter', type=int, default=200000, help='from this iter and on, reconstruct loss is deactivated')
82 | self.parser.add_argument('--G_extra_inverse_train', type=int, default=1, help='number of extra training iters for G on inverse direction')
83 | self.parser.add_argument('--G_extra_inverse_train_start_iter', type=int, default=10000, help='iteration from which the extra inverse-direction training of G starts')
84 | self.parser.add_argument('--G_extra_inverse_train_ratio', type=float, default=1.0, help='weight of the extra inverse-direction reconstruction loss')
85 | self.parser.add_argument('--use_L1', type=bool, default=True, help='Determines whether to use L1 or L2 for reconstruction')
86 | 
87 | # Misc
88 | self.parser.add_argument('--create_code_copy', type=bool, default=True, help='when set to true, all .py files are saved to the results directory to keep track')
89 | 
90 | def parse(self, create_dir_flag=True):
91 | # Parse arguments
92 | self.conf = self.parser.parse_args()
93 | 
94 | # set gpu ids
95 | torch.cuda.set_device(self.conf.gpu_id)
96 | 
97 | # Create results dir if it does not exist
98 | if create_dir_flag:
99 | self.conf.output_dir_path = prepare_result_dir(self.conf)
100 | 
101 | return self.conf
102 | 
--------------------------------------------------------------------------------
/examples/fruit/fruit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/assafshocher/InGAN/86b5b682b02aac24b49344a8ae69b1262596cccf/examples/fruit/fruit.png
--------------------------------------------------------------------------------
/figs/fruits.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/assafshocher/InGAN/86b5b682b02aac24b49344a8ae69b1262596cccf/figs/fruits.gif
--------------------------------------------------------------------------------
/figs/monitor_60000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/assafshocher/InGAN/86b5b682b02aac24b49344a8ae69b1262596cccf/figs/monitor_60000.png -------------------------------------------------------------------------------- /fruits.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/InGAN/86b5b682b02aac24b49344a8ae69b1262596cccf/fruits.gif -------------------------------------------------------------------------------- /monitor_60000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/InGAN/86b5b682b02aac24b49344a8ae69b1262596cccf/monitor_60000.png -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import non_rect 7 | from util import homography_based_on_top_corners_x_shift, homography_grid 8 | 9 | 10 | def weights_init(m): 11 | """ This is used to initialize weights of any network """ 12 | class_name = m.__class__.__name__ 13 | if class_name.find('Conv') != -1: 14 | nn.init.xavier_normal_(m.weight, 0.01) 15 | if hasattr(m.bias, 'data'): 16 | m.bias.data.fill_(0) 17 | elif class_name.find('nn.BatchNorm2d') != -1: 18 | m.weight.data.normal_(1.0, 0.02) 19 | m.bias.data.fill_(0) 20 | 21 | elif class_name.find('LocalNorm') != -1: 22 | m.weight.data.normal_(1.0, 0.02) 23 | m.bias.data.fill_(0) 24 | 25 | 26 | class LocalNorm(nn.Module): 27 | def __init__(self, num_features): 28 | super(LocalNorm, self).__init__() 29 | self.weight = nn.Parameter(torch.Tensor(num_features)) 30 | self.bias = nn.Parameter(torch.Tensor(num_features)) 31 | self.get_local_mean = nn.AvgPool2d(33, 1, 16, count_include_pad=False) 32 | 33 | self.get_var = nn.AvgPool2d(33, 1, 16, count_include_pad=False) 34 | 35 | def forward(self, input_tensor): 36 | local_mean = self.get_local_mean(input_tensor) 37 | print local_mean 38 | centered_input_tensor = input_tensor - local_mean 39 | print centered_input_tensor 40 | squared_diff = centered_input_tensor ** 2 41 | print squared_diff 42 | local_std = self.get_var(squared_diff) ** 0.5 43 | print local_std 44 | normalized_tensor = centered_input_tensor / (local_std + 1e-8) 45 | 46 | return normalized_tensor # * self.weight[None, :, None, None] + self.bias[None, :, None, None] 47 | 48 | 49 | normalization_layer = nn.BatchNorm2d # BatchReNorm2d # LocalNorm 50 | 51 | 52 | class GANLoss(nn.Module): 53 | """ Receiving the final layer form the discriminator and a boolean indicating whether the input to the 54 | discriminator is real or fake (generated by generator), this returns a patch""" 55 | 56 | def __init__(self): 57 | super(GANLoss, self).__init__() 58 | 59 | # Initialize label tensor 60 | self.label_tensor = None 61 | 62 | # Loss tensor is prepared in network initialization. 63 | # Note: When activated as a loss between to feature-maps, then a loss-map is created. 
However, using defaults 64 | # for BCEloss, this map is averaged and reduced to a single scalar 65 | self.loss = nn.MSELoss() 66 | 67 | def forward(self, d_last_layer, is_d_input_real): 68 | # Determine label map according to whether current input to discriminator is real or fake 69 | self.label_tensor = Variable(torch.ones_like(d_last_layer).cuda(), requires_grad=False) * is_d_input_real 70 | 71 | # Finally return the loss 72 | return self.loss(d_last_layer, self.label_tensor) 73 | 74 | 75 | class WeightedMSELoss(nn.Module): 76 | def __init__(self, use_L1=False): 77 | super(WeightedMSELoss, self).__init__() 78 | 79 | self.unweighted_loss = nn.L1Loss() if use_L1 else nn.MSELoss() 80 | 81 | def forward(self, input_tensor, target_tensor, loss_mask): 82 | if loss_mask is not None: 83 | e = (target_tensor.detach() - input_tensor) ** 2 84 | e *= loss_mask 85 | return torch.sum(e) / torch.sum(loss_mask) 86 | else: 87 | return self.unweighted_loss(input_tensor, target_tensor) 88 | 89 | 90 | class MultiScaleLoss(nn.Module): 91 | def __init__(self): 92 | super(MultiScaleLoss, self).__init__() 93 | 94 | self.mse = nn.MSELoss() 95 | 96 | def forward(self, input_tensor, target_tensor, scale_weights): 97 | 98 | # Run all nets over all scales and aggregate the interpolated results 99 | loss = 0 100 | for i, scale_weight in enumerate(scale_weights): 101 | input_tensor = f.interpolate(input_tensor, scale_factor=self.scale_factor**(-i), mode='bilinear') 102 | loss += scale_weight * self.mse(input_tensor, target_tensor) 103 | return loss 104 | 105 | 106 | class Generator(nn.Module): 107 | """ Architecture of the Generator, uses res-blocks """ 108 | 109 | def __init__(self, base_channels=64, n_blocks=6, n_downsampling=3, use_bias=True, skip_flag=True): 110 | super(Generator, self).__init__() 111 | 112 | # Determine whether to use skip connections 113 | self.skip = skip_flag 114 | 115 | # Entry block 116 | # First conv-block, no stride so image dims are kept and channels dim is expanded (pad-conv-norm-relu) 117 | self.entry_block = nn.Sequential(nn.ReflectionPad2d(3), 118 | nn.utils.spectral_norm(nn.Conv2d(3, base_channels, kernel_size=7, bias=use_bias)), 119 | normalization_layer(base_channels), 120 | nn.LeakyReLU(0.2, True)) 121 | 122 | # Geometric transformation 123 | self.geo_transform = GeoTransform() 124 | 125 | # Downscaling 126 | # A sequence of strided conv-blocks. 
Image dims shrink by 2, channels dim expands by 2 at each block 127 | self.downscale_block = RescaleBlock(n_downsampling, 0.5, base_channels, True) 128 | 129 | # Bottleneck 130 | # A sequence of res-blocks 131 | bottleneck_block = [] 132 | for _ in range(n_blocks): 133 | # noinspection PyUnboundLocalVariable 134 | bottleneck_block += [ResnetBlock(base_channels * 2 ** n_downsampling, use_bias=use_bias)] 135 | self.bottleneck_block = nn.Sequential(*bottleneck_block) 136 | 137 | # Upscaling 138 | # A sequence of transposed-conv-blocks, Image dims expand by 2, channels dim shrinks by 2 at each block\ 139 | self.upscale_block = RescaleBlock(n_downsampling, 2.0, base_channels, True) 140 | 141 | # Final block 142 | # No stride so image dims are kept and channels dim shrinks to 3 (output image channels) 143 | self.final_block = nn.Sequential(nn.ReflectionPad2d(3), 144 | nn.Conv2d(base_channels, 3, kernel_size=7), 145 | nn.Tanh()) 146 | 147 | def forward(self, input_tensor, output_size, random_affine): 148 | # A condition for having the output at same size as the scaled input is having even output_size 149 | 150 | # Entry block 151 | feature_map = self.entry_block(input_tensor) 152 | 153 | # Change scale to output scale by interpolation 154 | if random_affine is None: 155 | feature_map = f.interpolate(feature_map, size=output_size, mode='bilinear') 156 | else: 157 | feature_map = self.geo_transform.forward(feature_map, output_size, random_affine) 158 | 159 | # Downscale block 160 | feature_map, downscales = self.downscale_block.forward(feature_map, return_all_scales=self.skip) 161 | 162 | # Bottleneck (res-blocks) 163 | feature_map = self.bottleneck_block(feature_map) 164 | 165 | # Upscale block 166 | feature_map, _ = self.upscale_block.forward(feature_map, pyramid=downscales, skip=self.skip) 167 | 168 | # Final block 169 | output_tensor = self.final_block(feature_map) 170 | 171 | return output_tensor 172 | 173 | 174 | class ResnetBlock(nn.Module): 175 | """ A single Res-Block module """ 176 | 177 | def __init__(self, dim, use_bias): 178 | super(ResnetBlock, self).__init__() 179 | 180 | # A res-block without the skip-connection, pad-conv-norm-relu-pad-conv-norm 181 | self.conv_block = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(dim, dim // 4, kernel_size=1, bias=use_bias)), 182 | normalization_layer(dim // 4), 183 | nn.LeakyReLU(0.2, True), 184 | nn.ReflectionPad2d(1), 185 | nn.utils.spectral_norm(nn.Conv2d(dim // 4, dim // 4, kernel_size=3, bias=use_bias)), 186 | normalization_layer(dim // 4), 187 | nn.LeakyReLU(0.2, True), 188 | nn.utils.spectral_norm(nn.Conv2d(dim // 4, dim, kernel_size=1, bias=use_bias)), 189 | normalization_layer(dim)) 190 | 191 | def forward(self, input_tensor): 192 | # The skip connection is applied here 193 | return input_tensor + self.conv_block(input_tensor) 194 | 195 | 196 | class MultiScaleDiscriminator(nn.Module): 197 | def __init__(self, real_crop_size, max_n_scales=9, scale_factor=2, base_channels=128, extra_conv_layers=0): 198 | super(MultiScaleDiscriminator, self).__init__() 199 | self.base_channels = base_channels 200 | self.scale_factor = scale_factor 201 | self.min_size = 16 202 | self.extra_conv_layers = extra_conv_layers 203 | 204 | # We want the max num of scales to fit the size of the real examples. 
further scaling would create networks that 205 | # only train on fake examples 206 | self.max_n_scales = np.min([np.int(np.ceil(np.log(np.min(real_crop_size) * 1.0 / self.min_size) 207 | / np.log(self.scale_factor))), max_n_scales]) 208 | 209 | # Prepare a list of all the networks for all the wanted scales 210 | self.nets = nn.ModuleList() 211 | 212 | # Create a network for each scale 213 | for _ in range(self.max_n_scales): 214 | self.nets.append(self.make_net()) 215 | 216 | def make_net(self): 217 | base_channels = self.base_channels 218 | net = [] 219 | 220 | # Entry block 221 | net += [nn.utils.spectral_norm(nn.Conv2d(3, base_channels, kernel_size=3, stride=1)), 222 | nn.BatchNorm2d(base_channels), 223 | nn.LeakyReLU(0.2, True)] 224 | 225 | # Downscaling blocks 226 | # A sequence of strided conv-blocks. Image dims shrink by 2, channels dim expands by 2 at each block 227 | net += [nn.utils.spectral_norm(nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, stride=2)), 228 | nn.BatchNorm2d(base_channels * 2), 229 | nn.LeakyReLU(0.2, True)] 230 | 231 | # Regular conv-block 232 | net += [nn.utils.spectral_norm(nn.Conv2d(in_channels=base_channels * 2, 233 | out_channels=base_channels * 2, 234 | kernel_size=3, 235 | bias=True)), 236 | nn.BatchNorm2d(base_channels * 2), 237 | nn.LeakyReLU(0.2, True)] 238 | 239 | # Additional 1x1 conv-blocks 240 | for _ in range(self.extra_conv_layers): 241 | net += [nn.utils.spectral_norm(nn.Conv2d(in_channels=base_channels * 2, 242 | out_channels=base_channels * 2, 243 | kernel_size=3, 244 | bias=True)), 245 | nn.BatchNorm2d(base_channels * 2), 246 | nn.LeakyReLU(0.2, True)] 247 | 248 | # Final conv-block 249 | # Ends with a Sigmoid to get a range of 0-1 250 | net += nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(base_channels * 2, 1, kernel_size=1)), 251 | nn.Sigmoid()) 252 | 253 | # Make it a valid layers sequence and return 254 | return nn.Sequential(*net) 255 | 256 | def forward(self, input_tensor, scale_weights): 257 | aggregated_result_maps_from_all_scales = self.nets[0](input_tensor) * scale_weights[0] 258 | map_size = aggregated_result_maps_from_all_scales.shape[2:] 259 | 260 | # Run all nets over all scales and aggregate the interpolated results 261 | for net, scale_weight, i in zip(self.nets[1:], scale_weights[1:], range(1, len(scale_weights))): 262 | downscaled_image = f.interpolate(input_tensor, scale_factor=self.scale_factor**(-i), mode='bilinear') 263 | result_map_for_current_scale = net(downscaled_image) 264 | upscaled_result_map_for_current_scale = f.interpolate(result_map_for_current_scale, 265 | size=map_size, 266 | mode='bilinear') 267 | aggregated_result_maps_from_all_scales += upscaled_result_map_for_current_scale * scale_weight 268 | 269 | return aggregated_result_maps_from_all_scales 270 | 271 | 272 | class RescaleBlock(nn.Module): 273 | def __init__(self, n_layers, scale=0.5, base_channels=64, use_bias=True): 274 | super(RescaleBlock, self).__init__() 275 | 276 | self.scale = scale 277 | 278 | self.conv_layers = [None] * n_layers 279 | 280 | in_channel_power = scale > 1 281 | out_channel_power = scale < 1 282 | i_range = range(n_layers) if scale < 1 else range(n_layers-1, -1, -1) 283 | 284 | for i in i_range: 285 | self.conv_layers[i] = nn.Sequential(nn.ReflectionPad2d(1), 286 | nn.utils.spectral_norm(nn.Conv2d( 287 | in_channels=base_channels * 2 ** (i + in_channel_power), 288 | out_channels=base_channels * 2 ** (i + out_channel_power), 289 | kernel_size=3, 290 | stride=1, 291 | bias=use_bias)), 292 | 
normalization_layer(base_channels * 2 ** (i + out_channel_power)), 293 | nn.LeakyReLU(0.2, True)) 294 | self.add_module("conv_%d" % i, self.conv_layers[i]) 295 | 296 | if scale > 1: 297 | self.conv_layers = self.conv_layers[::-1] 298 | 299 | self.max_pool = nn.MaxPool2d(2, 2) 300 | 301 | def forward(self, input_tensor, pyramid=None, return_all_scales=False, skip=False): 302 | 303 | feature_map = input_tensor 304 | all_scales = [] 305 | if return_all_scales: 306 | all_scales.append(feature_map) 307 | 308 | for i, conv_layer in enumerate(self.conv_layers): 309 | 310 | if self.scale > 1.0: 311 | feature_map = f.interpolate(feature_map, scale_factor=self.scale, mode='nearest') 312 | 313 | feature_map = conv_layer(feature_map) 314 | 315 | if skip: 316 | feature_map = feature_map + pyramid[-i-2] 317 | 318 | if self.scale < 1.0: 319 | feature_map = self.max_pool(feature_map) 320 | 321 | if return_all_scales: 322 | all_scales.append(feature_map) 323 | 324 | return (feature_map, all_scales) if return_all_scales else (feature_map, None) 325 | 326 | 327 | class RandomCrop(nn.Module): 328 | def __init__(self, crop_size, return_pos=False, must_divide=4.0): 329 | super(RandomCrop, self).__init__() 330 | 331 | # Determine crop size 332 | self.crop_size = crop_size 333 | self.must_divide = must_divide 334 | self.return_pos = return_pos 335 | 336 | def forward(self, input_tensors, crop_size=None): 337 | im_v_sz, im_h_sz = input_tensors[0].shape[2:] 338 | if crop_size is None: 339 | cr_v_sz, cr_h_sz = np.clip(self.crop_size, [0, 0], [im_v_sz-1, im_h_sz-1]) 340 | cr_v_sz, cr_h_sz = np.uint32(np.floor(np.array([cr_v_sz, cr_h_sz]) 341 | * 1.0 / self.must_divide) * self.must_divide) 342 | else: 343 | cr_v_sz, cr_h_sz = crop_size 344 | 345 | top_left_v, top_left_h = [np.random.randint(0, im_v_sz - cr_v_sz), np.random.randint(0, im_h_sz - cr_h_sz)] 346 | 347 | out_tensors = [input_tensor[:, :, top_left_v:top_left_v + cr_v_sz, top_left_h:top_left_h + cr_h_sz] 348 | if input_tensor is not None else None for input_tensor in input_tensors] 349 | 350 | return (out_tensors, (top_left_v, top_left_h)) if self.return_pos else out_tensors 351 | 352 | 353 | class SwapCrops(nn.Module): 354 | def __init__(self, min_crop_size, max_crop_size, mask_width=5): 355 | super(SwapCrops, self).__init__() 356 | 357 | self.rand_crop_1 = RandomCrop(None, return_pos=True) 358 | self.rand_crop_2 = RandomCrop(None, return_pos=True) 359 | 360 | self.min_crop_size = min_crop_size 361 | self.max_crop_size = max_crop_size 362 | 363 | self.mask_width = mask_width 364 | 365 | def forward(self, input_tensor): 366 | cr_v_sz, cr_h_sz = np.uint32(np.random.rand(2) * (self.max_crop_size - self.min_crop_size) + self.min_crop_size) 367 | 368 | [crop_1], (top_left_v_1, top_left_h_1) = self.rand_crop_1.forward([input_tensor], (cr_v_sz, cr_h_sz)) 369 | [crop_2], (top_left_v_2, top_left_h_2) = self.rand_crop_1.forward([input_tensor], (cr_v_sz, cr_h_sz)) 370 | 371 | output_tensor = torch.zeros_like(input_tensor) 372 | output_tensor[:, :, :, :] = input_tensor 373 | 374 | output_tensor[:, :, top_left_v_1:top_left_v_1 + cr_v_sz, top_left_h_1:top_left_h_1 + cr_h_sz] = crop_2 375 | output_tensor[:, :, top_left_v_2:top_left_v_2 + cr_v_sz, top_left_h_2:top_left_h_2 + cr_h_sz] = crop_1 376 | 377 | # Creating a mask. 
this draws a band of width 2*mask_width along the boundaries of the two swapped crops and zeroes it in the loss mask 378 | loss_mask = torch.ones_like(input_tensor) 379 | mw = self.mask_width 380 | loss_mask[:, :, top_left_v_1:top_left_v_1+cr_v_sz, top_left_h_1-mw:top_left_h_1+mw] = 0 381 | loss_mask[:, :, top_left_v_1-mw:top_left_v_1+mw, top_left_h_1:top_left_h_1+cr_h_sz] = 0 382 | loss_mask[:, :, top_left_v_1:top_left_v_1+cr_v_sz, top_left_h_1+cr_h_sz-mw:top_left_h_1+cr_h_sz+mw] = 0 383 | loss_mask[:, :, top_left_v_1+cr_v_sz-mw:top_left_v_1+cr_v_sz+mw, top_left_h_1:top_left_h_1+cr_h_sz] = 0 384 | loss_mask[:, :, top_left_v_2:top_left_v_2+cr_v_sz, top_left_h_2-mw:top_left_h_2+mw] = 0 385 | loss_mask[:, :, top_left_v_2-mw:top_left_v_2+mw, top_left_h_2:top_left_h_2+cr_h_sz] = 0 386 | loss_mask[:, :, top_left_v_2:top_left_v_2+cr_v_sz, top_left_h_2+cr_h_sz-mw:top_left_h_2+cr_h_sz+mw] = 0 387 | loss_mask[:, :, top_left_v_2+cr_v_sz-mw:top_left_v_2+cr_v_sz+mw, top_left_h_2:top_left_h_2+cr_h_sz] = 0 388 | 389 | return output_tensor, loss_mask 390 | 391 | 392 | class GeoTransform(nn.Module): 393 | def __init__(self): 394 | super(GeoTransform, self).__init__() 395 | 396 | def forward(self, input_tensor, target_size, shifts): 397 | sz = input_tensor.shape 398 | theta = homography_based_on_top_corners_x_shift(shifts) 399 | 400 | pad = f.pad(input_tensor, (np.abs(np.int(np.ceil(sz[3] * shifts[0]))), np.abs(np.int(np.ceil(-sz[3] * shifts[1]))), 0, 0), 'reflect') 401 | target_size4d = torch.Size([pad.shape[0], pad.shape[1], target_size[0], target_size[1]]) 402 | 403 | grid = homography_grid(theta.expand(pad.shape[0], -1, -1), target_size4d) 404 | 405 | return f.grid_sample(pad, grid, mode='bilinear', padding_mode='border') 406 | -------------------------------------------------------------------------------- /non_rect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as f 3 | import numpy as np 4 | 5 | 6 | def affine_based_on_top_left_corner_x_shift(rand_affine): 7 | """ 8 | Random affine transformation that shifts only the top-left corner along the x direction 9 | :param rand_affine: amount of random x perturbation 10 | :return: the 2x3 affine transform as a CUDA tensor 11 | """ 12 | aff = np.array([[1., -0.5 * rand_affine, 0.5 * rand_affine], [0, 1., 0]], dtype=np.float32) 13 | 14 | return torch.from_numpy(aff).clone().cuda() 15 | 16 | 17 | def apply_resize_and_affine(x, target_size, rand_affine): 18 | aff = affine_based_on_top_left_corner_x_shift(rand_affine) 19 | target_size4d = torch.Size([x.shape[0], x.shape[1], target_size[0], target_size[1]]) 20 | grid = f.affine_grid(aff.expand(x.shape[0], -1, -1), target_size4d) 21 | out = f.grid_sample(x, grid, mode='bilinear', padding_mode='border') 22 | return out 23 | 24 | 25 | def homography_grid(theta, size): 26 | r"""Generates a 2d flow field, given a batch of homography matrices :attr:`theta` 27 | Generally used in conjunction with :func:`grid_sample` to 28 | implement Spatial Transformer Networks. 
29 | 30 | Args: 31 | theta (Tensor): input batch of homography matrices (:math:`N \times 3 \times 3`) 32 | size (torch.Size): the target output image size (:math:`N \times C \times H \times W`) 33 | Example: torch.Size((32, 3, 24, 24)) 34 | 35 | Returns: 36 | output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) 37 | """ 38 | y, x = torch.meshgrid((torch.linspace(-1., 1., size[-2]), torch.linspace(-1., 1., size[-1]))) 39 | n = size[-2] * size[-1] 40 | hxy = torch.ones(n, 3, dtype=torch.float) 41 | hxy[:, 0] = x.contiguous().view(-1) 42 | hxy[:, 1] = y.contiguous().view(-1) 43 | out = hxy[None, ...].cuda().matmul(theta.transpose(1, 2)) 44 | # normalize 45 | out = out[:, :, :2] / out[:, :, 2:] 46 | return out.view(theta.shape[0], size[-2], size[-1], 2) 47 | 48 | 49 | def apply_resize_and_homograhpy(x, target_size, rand_h): 50 | theta = homography_based_on_top_corners_x_shift(rand_h) 51 | target_size4d = torch.Size([x.shape[0], x.shape[1], target_size[0], target_size[1]]) 52 | grid = homography_grid(theta.expand(x.shape[0], -1, -1), target_size4d) 53 | out = f.grid_sample(x, grid, mode='bilinear', padding_mode='border') 54 | return out 55 | 56 | 57 | def homography_based_on_top_corners_x_shift(rand_h): 58 | # play with both top corners 59 | # p = np.array([[1., 1., -1, 0, 0, 0, -(-1. + rand_h[0]), -(-1. + rand_h[0]), -1. + rand_h[0]], 60 | # [0, 0, 0, 1., 1., -1., 1., 1., -1.], 61 | # [-1., 1., -1, 0, 0, 0, 1 + rand_h[1], -(1 + rand_h[1]), 1 + rand_h[1]], 62 | # [0, 0, 0, -1, 1, -1, -1, 1, -1], 63 | # [1, 0, -1, 0, 0, 0, 1, 0, -1], 64 | # [0, 0, 0, 1, 0, -1, 0, 0, 0], 65 | # [-1, 0, -1, 0, 0, 0, 1, 0, 1], 66 | # [0, 0, 0, -1, 0, -1, 0, 0, 0], 67 | # [0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.float32) 68 | # play with top left and bottom right 69 | p = np.array([[1., 1., -1, 0, 0, 0, -(-1. + rand_h[0]), -(-1. + rand_h[0]), -1. + rand_h[0]], 70 | [0, 0, 0, 1., 1., -1., 1., 1., -1.], 71 | [-1., -1., -1, 0, 0, 0, 1 + rand_h[1], 1 + rand_h[1], 1 + rand_h[1]], 72 | [0, 0, 0, -1, -1, -1, 1, 1, 1], 73 | [1, 0, -1, 0, 0, 0, 1, 0, -1], 74 | [0, 0, 0, 1, 0, -1, 0, 0, 0], 75 | [-1, 0, -1, 0, 0, 0, 1, 0, 1], 76 | [0, 0, 0, -1, 0, -1, 0, 0, 0], 77 | [0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.float32) 78 | b = np.zeros((9, 1), dtype=np.float32) 79 | b[8, 0] = 1. 80 | h = np.dot(np.linalg.inv(p), b) 81 | return torch.from_numpy(h).view(3, 3).clone().cuda() 82 | 83 | 84 | def apply_resize_and_radial(x, target_size, rand_r): 85 | target_size4d = torch.Size([x.shape[0], x.shape[1], target_size[0], target_size[1]]) 86 | grid = make_radial_scale_grid(rand_r, target_size4d) 87 | out = f.grid_sample(x, grid, mode='bilinear', padding_mode='border') 88 | return out 89 | 90 | def make_radial_scale_grid(rand_r, size4d): 91 | y, x = torch.meshgrid((torch.linspace(-1., 1., size4d[-2]), torch.linspace(-1., 1., size4d[-1]))) 92 | theta = torch.atan2(x, y) 93 | r = torch.sqrt() 94 | 95 | ''' 96 | def test_time(): 97 | def _make_pink_noise(sz_): 98 | with torch.no_grad(): 99 | n = 4 # number of scales 100 | pn_ = 0. 101 | sf = 0.375 102 | nsf = 0.5 103 | for sc in range(n): 104 | csz = [int(s_ * sf ** sc) for s_ in sz_[2:]] 105 | cn = torch.randn(sz_[0], sz_[1], csz[0], csz[1]).cuda() * nsf ** (n - sc - 1) 106 | pn_ += f.interpolate(cn, sz_[2:], mode='bilinear', align_corners=False) 107 | return torch.clamp(pn_, -1., 1.) 
108 | 109 | import torch 110 | from torch.nn import functional as f 111 | from PIL import Image 112 | import util 113 | from InGAN import InGAN 114 | from configs import Config 115 | from skvideo.io import FFmpegWriter 116 | from non_rect import affine_based_on_top_left_corner_x_shift 117 | import numpy as np 118 | from non_rect import * 119 | 120 | conf = Config().parse() 121 | gan = InGAN(conf) 122 | sd = torch.load('results/rome_s-aff_Mar_03_16_23_22/checkpoint_0080000.pth.tar') 123 | gan.G.load_state_dict(sd['G']) 124 | 125 | def _make_affine_mask(in_mask, target_size, rand_affine): 126 | aff = affine_based_on_top_left_corner_x_shift(rand_affine) 127 | target_size4d = torch.Size([in_mask.shape[0], in_mask.shape[1], target_size[0], target_size[1]]) 128 | grid = f.affine_grid(aff.expand(in_mask.shape[0], -1, -1), target_size4d) 129 | out_mask = f.grid_sample(in_mask, grid, mode='bilinear', padding_mode='zeros') 130 | return out_mask 131 | 132 | def _make_homography_mask(in_mask, target_size, rand_h): 133 | theta = homography_based_on_top_corners_x_shift(rand_h) 134 | target_size4d = torch.Size([in_mask.shape[0], in_mask.shape[1], target_size[0], target_size[1]]) 135 | grid = homography_grid(theta.expand(in_mask.shape[0], -1, -1), target_size4d) 136 | out = f.grid_sample(in_mask, grid, mode='bilinear', padding_mode='zeros') 137 | return out 138 | 139 | orig = util.read_shave_tensorize('/home/bagon/develop/waic/InGAN/rome_s.png', 8) 140 | pad = torch.zeros(1, 3, orig.shape[2], orig.shape[3] * 2, dtype=torch.float).cuda() 141 | hp = orig.shape[3] // 2 142 | pad[..., hp:-hp] = orig 143 | in_mask = torch.zeros_like(pad[:, :1, ...]) 144 | in_mask[..., hp:-hp] = 1. 145 | 146 | pinkn = _make_pink_noise(pad.shape) 147 | 148 | writer = FFmpegWriter('vid-h-fruits_ss.mp4', verbosity=1, outputdict={'-b': '30000000', '-r': '10.0'}) 149 | n = 400 150 | for i in range(n): 151 | rand_h = (.25 * np.sin(2*np.pi*float(i)/float(0.5*n)), .25 * np.sin(2*np.pi*float(i)/float(0.25*n))) 152 | # a = float(.3 * np.sin(2*np.pi*float(i)/float(0.5*n))) 153 | out = gan.G(pad + 0. 
* pinkn, pad.shape[2:], rand_h) 154 | # out_mask = _make_affine_mask(in_mask, pad.shape[2:], a) 155 | out_mask = _make_homography_mask(in_mask, pad.shape[2:], rand_h) 156 | frame = util.tensor2im(out*out_mask - 1 + out_mask) 157 | writer.writeFrame(frame) 158 | writer.close() 159 | ''' -------------------------------------------------------------------------------- /supp_video.py: -------------------------------------------------------------------------------- 1 | import util 2 | from InGAN import InGAN 3 | from configs import Config 4 | from skvideo.io import FFmpegWriter 5 | import os 6 | from non_rect import * 7 | from SceneScripts import * 8 | 9 | 10 | FRAME_SHAPE = [500, 1000] 11 | MUST_DIVIDE = 8 12 | VIDEO_SCRIPT = [ # [nameses, script_name, script_params=(min_v, max_v, min_h, max_h, max_t, repeat)] 13 | [[['fruits'], ['fruits_old'], ['fruits_old'], ['fruits'], ['fruits']], ['horizontal_grow_shrink_slow', 'vertical_grow_shrink', 'resize_round', 'affine_dance', 'random'], [[0.55, None, 0.55, None, None, 1], [0.3, 1.8, 0.3, 2.0, None, 1, False], [0.3, 1.8, 0.3, 2.0, None, 1, False], [None, None, None, None, 0.45, 1, False], [0.3, 1.3, 0.3, 1.6, 0.45, 1, False]]], 14 | ['farm_house', 'special_resize_round', [0.45, None, 0.45, None, None, 2]], 15 | ['cab_building', 'resize_round', [0.5, None, 0.3, 2.5, None, 2]], 16 | ['rome', 'horizontal_grow_shrink', [0.3, None, 0.3, None, None, 3]], 17 | [[['peacock', 'windows']], 'resize_round', [0.5, 2, 0.5, 1.75, None, 3]], 18 | [[['soldiers', 'penguins']], 'horizontal_grow_shrink', [0.3, None, 0.3, None, None, 3]], 19 | [[['nkorea', 'sapa']], 'horizontal_grow_shrink', [0.15, None, 0.15, None, None, 3]], 20 | [[['quilt']] * 5, ['horizontal_grow_shrink', 'vertical_grow_shrink', 'resize_round', 'affine_dance', 'random'], [[0.55, None, 0.55, None, None, 1], [0.3, None, 0.3, None, None, 1, False], [0.3, None, 0.3, None, None, 1, False], [None, None, None, None, 0.45, 1, False], [0.6, 1.6, 0.6, 1.75, 0.55, 1, False]]], 21 | [[['umbrella'], ['umbrella'], ['umbrella']], ['horizontal_grow_shrink', 'resize_round', 'trapezoids'], [[0.55, None, 0.55, None, None, 1], [0.55, None, 0.55, None, None, 1, False], [1, 1, 0.8, 1.2, 0.3, 1, False]]], 22 | [[['metal_circles']] * 5, ['vertical_grow_shrink', 'random'], [[0.15, None, 0.55, None, None, 2], [0.15, 1.8, 0.15, 1.45, 0.55, 1, False]]], 23 | [[['fish'], ['fish']], ['affine_dance', 'random'], [[1, 1, 1, 1, 0.4, 1], [1, 1, 1, 1, 0.5, 1, False]]], 24 | ['wood', 'special_zoom', [0.3, None, 0.3, None, None, 2]], 25 | ['ny', 'affine_dance', [None, None, None, None, 0.3, 2]], 26 | ['sushi', 'resize_round', [0.5, None, 0.3, None, None, 1]], 27 | ] 28 | 29 | 30 | def generate_one_frame(gan, input_tensor, frame_shape, scale, geo_shifts, center): 31 | with torch.no_grad(): 32 | base_sz = input_tensor.shape 33 | in_size = base_sz[2:] 34 | out_pad = np.uint8(np.zeros([frame_shape[0], frame_shape[1], 3])) 35 | 36 | if scale[0] == -1: 37 | output_tensor = [None, input_tensor] 38 | out_mask = torch.ones_like(output_tensor[1]) 39 | out_size = in_size 40 | 41 | else: 42 | out_mask, out_size = prepare_geometric(base_sz, scale, geo_shifts) 43 | 44 | output_tensor, _, _ = gan.test(input_tensor=input_tensor, 45 | input_size=in_size, 46 | output_size=out_size, 47 | rand_affine=geo_shifts, 48 | run_d_pred=False, 49 | run_reconstruct=False) 50 | 51 | out = out_mask * output_tensor[1] - 1 + out_mask 52 | margin = np.uint16((frame_shape - np.array(out_size)) / 2) if center else [0, 0] 53 | out_pad[margin[0]:margin[0] + out_size[0], 
margin[1]:margin[1] + out_size[1], :] = util.hist_match(util.tensor2im(out), util.tensor2im(input_tensor), util.tensor2im(out_mask)) 54 | return out_pad 55 | 56 | 57 | def generate_one_scene(gan, input_tensor, scene_script, frame_shape, center): 58 | frames = [] 59 | for i, (scale_v, scale_h, shift_l, shift_r) in enumerate(scene_script): 60 | output_image = generate_one_frame(gan, input_tensor, frame_shape, [scale_v, scale_h], [shift_l, shift_r], center) 61 | frames.append(output_image) 62 | return np.stack(frames, axis=0) 63 | 64 | 65 | def generate_full_video(video_script, frame_shape): 66 | conf = Config().parse(create_dir_flag=False) 67 | conf.name = 'supp_vid' 68 | conf.output_dir_path = util.prepare_result_dir(conf) 69 | n_scenes = len(video_script) 70 | 71 | for i, (nameses, scene_script_names, scene_script_params) in enumerate(video_script): 72 | if not isinstance(nameses, list): 73 | nameses = [[nameses]] 74 | if not isinstance(scene_script_names, list): 75 | scene_script_names = [scene_script_names] 76 | scene_script_params = [scene_script_params] 77 | scenes = [] 78 | for names, scene_script_name, scene_script_param in zip(nameses, scene_script_names, scene_script_params): 79 | partial_screen_scenes = [] 80 | 81 | for name in names: 82 | conf.input_image_path = [os.path.dirname(os.path.abspath(__file__)) + '/' + INPUT_DICT[name][0]] 83 | conf.test_params_path = os.path.dirname(os.path.abspath(__file__)) + INPUT_DICT[name][1] 84 | gan = InGAN(conf) 85 | gan.G.load_state_dict(torch.load(conf.test_params_path, map_location='cuda:0')['G']) 86 | [input_tensor] = util.read_data(conf) 87 | 88 | cur_frame_shape = frame_shape[:] 89 | concat_axis = 2 if scene_script_name == 'resize_round' else 1 90 | if len(names) > 1: 91 | cur_frame_shape[concat_axis - 1] /= 2 92 | 93 | cur_scene_script_param = scene_script_param[:] 94 | if scene_script_param[1] is None: 95 | cur_scene_script_param[1] = cur_frame_shape[0] * 1.0 / input_tensor.shape[2] 96 | print 'max scale vertical:', cur_scene_script_param[1] 97 | if cur_scene_script_param[3] is None: 98 | cur_scene_script_param[3] = cur_frame_shape[1] * 1.0 / input_tensor.shape[3] 99 | print 'max scale horizontal:', cur_scene_script_param[3] 100 | 101 | scene_script = make_scene_script(scene_script_name, *cur_scene_script_param) 102 | 103 | center = (cur_scene_script_param[4] is not None) 104 | 105 | 106 | scene = generate_one_scene(gan, input_tensor, scene_script, np.array([cur_frame_shape[0], cur_frame_shape[1]]), center) 107 | partial_screen_scenes.append(scene) 108 | 109 | print 'Done with %s, (scene %d/%d)' % (name, i + 1, n_scenes) 110 | 111 | 112 | scene = np.concatenate(partial_screen_scenes, axis=concat_axis) if len(partial_screen_scenes) > 1 else partial_screen_scenes[0] 113 | scenes.append(scene) 114 | 115 | scene = np.concatenate(scenes, axis=0) 116 | 117 | outputdict = {'-b:v': '30000000', '-r': '100.0', 118 | '-vf': 'drawtext="text=\'Input image\':fontcolor=red:fontsize=48:x=(w-text_w)/2:y=(h-text_h)*7/8:enable=\'between(t,0,2)\'"', 119 | '-preset': 'slow', '-profile:v': 'high444', '-level:v': '4.0', '-crf': '22'} 120 | if len(names) > 1: 121 | outputdict['-vf'] = 'drawtext="text=\'Input images\':fontcolor=red:fontsize=48:x=(w-text_w)/2:y=(h-text_h)/2.5:enable=\'between(t,0,2)\'"' 122 | 123 | if not scene_script_params[-1]: 124 | outputdict['-vf'] = 'drawtext="text=\'Input images\':fontcolor=red:fontsize=48:x=(w-text_w)/2:y=(h-text_h)/2.5:enable=\'between(t,0,0)\'"' 125 | 126 | writer = FFmpegWriter(conf.output_dir_path + '/vid%d_%s.mp4' 
% (i, '_'.join(names)), verbosity=1, 127 | outputdict=outputdict) 128 | for frame in scene: 129 | for j in range(3): 130 | writer.writeFrame(frame) 131 | writer.close() 132 | 133 | 134 | def prepare_geometric(base_sz, scale, geo_shifts): 135 | pad_l = np.abs(np.int(np.ceil(base_sz[3] * geo_shifts[0]))) 136 | pad_r = np.abs(np.int(np.ceil(base_sz[3] * geo_shifts[1]))) 137 | in_mask = torch.zeros(base_sz[0], base_sz[1], base_sz[2], pad_l + base_sz[3] + pad_r).cuda() 138 | in_size = in_mask.shape[2:] 139 | out_size = (np.uint32(np.floor(scale[0] * in_size[0] * 1.0 / MUST_DIVIDE) * MUST_DIVIDE), 140 | np.uint32(np.floor(scale[1] * in_size[1] * 1.0 / MUST_DIVIDE) * MUST_DIVIDE)) 141 | if pad_r > 0: 142 | in_mask[:, :, :, pad_l:-pad_r] = torch.ones(base_sz) 143 | else: 144 | in_mask[:, :, :, pad_l:] = torch.ones(base_sz) 145 | 146 | theta = homography_based_on_top_corners_x_shift(geo_shifts) 147 | target_size4d = torch.Size([in_mask.shape[0], in_mask.shape[1], out_size[0], out_size[1]]) 148 | grid = homography_grid(theta.expand(in_mask.shape[0], -1, -1), target_size4d) 149 | out_mask = f.grid_sample(in_mask, grid, mode='bilinear', padding_mode='zeros') 150 | return out_mask, out_size 151 | 152 | 153 | def main(): 154 | generate_full_video(VIDEO_SCRIPT, FRAME_SHAPE) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from networks import GeoTransform 2 | from PIL import Image 3 | import util 4 | from InGAN import InGAN 5 | from configs import Config 6 | from traceback import print_exc 7 | from skvideo.io import FFmpegWriter 8 | import os 9 | from non_rect import * 10 | 11 | 12 | def test_one_scale(gan, input_tensor, scale, must_divide, affine=None, return_tensor=False, size_instead_scale=False): 13 | with torch.no_grad(): 14 | in_size = input_tensor.shape[2:] 15 | if size_instead_scale: 16 | out_size = scale 17 | else: 18 | out_size = (np.uint32(np.floor(scale[0] * in_size[0] * 1.0 / must_divide) * must_divide), 19 | np.uint32(np.floor(scale[1] * in_size[1] * 1.0 / must_divide) * must_divide)) 20 | 21 | output_tensor, _, _ = gan.test(input_tensor=input_tensor, 22 | input_size=in_size, 23 | output_size=out_size, 24 | rand_affine=affine, 25 | run_d_pred=False, 26 | run_reconstruct=False) 27 | if return_tensor: 28 | return output_tensor[1] 29 | else: 30 | return util.tensor2im(output_tensor[1]) 31 | 32 | 33 | def concat_images(images, margin, input_spot): 34 | h_sizes = [im.shape[0] for im in zip(*images)[0]] 35 | w_sizes = [im.shape[1] for im in images[0]] 36 | h_total_size = np.sum(h_sizes) + margin * (len(images) - 1) 37 | w_total_size = np.sum(w_sizes) + margin * (len(images) - 1) 38 | 39 | collage = np.ones([h_total_size, w_total_size, 3]) * 255 40 | for i in range(len(images)): 41 | for j in range(len(images)): 42 | top_left_corner_h = int(np.sum(h_sizes[:j]) + j * margin) 43 | top_left_corner_w = int(np.sum(w_sizes[:i]) + i * margin) 44 | bottom_right_corner_h = int(top_left_corner_h + h_sizes[j]) 45 | bottom_right_corner_w = int(top_left_corner_w + w_sizes[i]) 46 | 47 | if [i, j] == input_spot: 48 | collage[top_left_corner_h - margin/2: bottom_right_corner_h + margin/2, 49 | top_left_corner_w - margin/2: bottom_right_corner_w + margin/2, 50 | :] = [255, 0, 0] 51 | collage[top_left_corner_h:bottom_right_corner_h, top_left_corner_w:bottom_right_corner_w] = images[j][i] 52 | 53 | return collage 54 | 55 | 56 | 
def generate_images_for_collage(gan, input_tensor, scales, must_divide): 57 | # NOTE: scales here is different from in the other funcs: here we only need 1d scales. 58 | # Prepare output images list 59 | output_images = [[[None] for _ in range(len(scales))] for _ in range(len(scales))] 60 | 61 | # Run over all scales and test the network for each one 62 | for i, scale_h in enumerate(scales): 63 | for j, scale_w in enumerate(scales): 64 | output_images[i][j] = test_one_scale(gan, input_tensor, [scale_h, scale_w], must_divide) 65 | return output_images 66 | 67 | 68 | def retarget_video(gan, input_tensor, scales, must_divide, output_dir_path): 69 | max_scale = np.max(np.array(scales)) 70 | frame_shape = np.uint32(np.array(input_tensor.shape[2:]) * max_scale) 71 | frame_shape[0] += (frame_shape[0] % 2) 72 | frame_shape[1] += (frame_shape[1] % 2) 73 | frames = np.zeros([len(scales), frame_shape[0], frame_shape[1], 3]) 74 | for i, (scale_h, scale_w) in enumerate(scales): 75 | output_image = test_one_scale(gan, input_tensor, [scale_h, scale_w], must_divide) 76 | frames[i, 0:output_image.shape[0], 0:output_image.shape[1], :] = output_image 77 | writer = FFmpegWriter(output_dir_path + '/vid.mp4', verbosity=1, outputdict={'-b': '30000000', '-r': '100.0'}) 78 | 79 | for i, _ in enumerate(scales): 80 | for j in range(3): 81 | writer.writeFrame(frames[i, :, :, :]) 82 | writer.close() 83 | 84 | 85 | def define_video_scales(scales): 86 | max_v, min_v, max_h, min_h = scales 87 | frames_per_resize = 10 88 | 89 | x = np.concatenate([ 90 | np.linspace(1, max_v, frames_per_resize), 91 | np.linspace(max_v, min_v, 2 * frames_per_resize), 92 | np.linspace(min_v, max_v, 2 * frames_per_resize), 93 | np.linspace(max_v, 1, frames_per_resize), 94 | np.linspace(1, 1, frames_per_resize), 95 | np.linspace(1, 1, 2 * frames_per_resize), 96 | np.linspace(1, 1, 2 * frames_per_resize), 97 | np.linspace(1, 1, frames_per_resize), 98 | np.linspace(1, max_v, frames_per_resize), 99 | np.linspace(max_v, min_v, 2 * frames_per_resize), 100 | np.linspace(min_v, max_v, 2 * frames_per_resize), 101 | np.linspace(max_v, 1, frames_per_resize), 102 | np.linspace(1, 1, frames_per_resize), 103 | np.linspace(1, max_v, frames_per_resize), 104 | np.linspace(max_v, max_v, 2 * frames_per_resize), 105 | np.linspace(max_v, min_v, 2 * frames_per_resize)]) 106 | y = np.concatenate([ 107 | np.linspace(1, 1, frames_per_resize), 108 | np.linspace(1, 1, 2 * frames_per_resize), 109 | np.linspace(1, 1, 2 * frames_per_resize), 110 | np.linspace(1, 1, frames_per_resize), 111 | np.linspace(1, max_h, frames_per_resize), 112 | np.linspace(max_h, min_h, 2 * frames_per_resize), 113 | np.linspace(min_h, max_h, 2 * frames_per_resize), 114 | np.linspace(max_h, 1, frames_per_resize), 115 | np.linspace(1, max_h, frames_per_resize), 116 | np.linspace(max_h, min_h, 2 * frames_per_resize), 117 | np.linspace(min_h, max_h, 2 * frames_per_resize), 118 | np.linspace(max_h, 1, frames_per_resize), 119 | np.linspace(1, max_h, frames_per_resize), 120 | np.linspace(max_h, max_h, frames_per_resize), 121 | np.linspace(max_h, min_h, 2 * frames_per_resize), 122 | np.linspace(min_h, min_h, 2 * frames_per_resize)]) 123 | 124 | return zip(x, y) 125 | 126 | 127 | def generate_collage_and_outputs(conf, gan, input_tensor): 128 | output_images = generate_images_for_collage(gan, input_tensor, conf.collage_scales, conf.must_divide) 129 | 130 | for i in range(len(output_images)): 131 | for j in range(len(output_images)): 132 | Image.fromarray(output_images[i][j], 
'RGB').save(conf.output_dir_path + '/test_%d_%d.png' % (i, j)) 133 | 134 | input_spot = conf.collage_input_spot 135 | output_images[input_spot[0]][input_spot[1]] = util.tensor2im(input_tensor) 136 | 137 | collage = concat_images(output_images, margin=10, input_spot=input_spot) 138 | 139 | Image.fromarray(np.uint8(collage), 'RGB').save(conf.output_dir_path + '/test_collage.png') 140 | 141 | 142 | def _make_homography_mask(in_mask, target_size, rand_h): 143 | theta = homography_based_on_top_corners_x_shift(rand_h) 144 | target_size4d = torch.Size([in_mask.shape[0], in_mask.shape[1], target_size[0], target_size[1]]) 145 | grid = homography_grid(theta.expand(in_mask.shape[0], -1, -1), target_size4d) 146 | out = f.grid_sample(in_mask, grid, mode='bilinear', padding_mode='border') 147 | return out 148 | 149 | 150 | def test_homo(conf, gan, input_tensor, must_divide=8): 151 | shift_range = np.arange(conf.non_rect_shift_range[0], conf.non_rect_shift_range[1], conf.non_rect_shift_range[2]) 152 | total = (len(conf.non_rect_scales)*len(shift_range))**2 153 | ind = 0 154 | for scale1 in conf.non_rect_scales: 155 | for scale2 in conf.non_rect_scales: 156 | scale = [scale1, scale2] 157 | for shift1 in shift_range: 158 | for shift2 in shift_range: 159 | ind += 1 160 | shifts = (shift1, shift2) 161 | sz = input_tensor.shape 162 | out_pad = np.uint8(255*np.ones([np.uint32(np.floor(sz[2]*scale[0])), np.uint32(np.floor(3*sz[3]*scale[1])), 3])) 163 | 164 | pad_l = np.abs(np.int(np.ceil(sz[3] * shifts[0]))) 165 | pad_r = np.abs(np.int(np.ceil(sz[3] * shifts[1]))) 166 | 167 | in_mask = torch.zeros(sz[0], sz[1], sz[2], pad_l + sz[3] + pad_r).cuda() 168 | input_for_regular = torch.zeros(sz[0], sz[1], sz[2], pad_l + sz[3] + pad_r).cuda() 169 | 170 | in_size = in_mask.shape[2:] 171 | 172 | out_size = (np.uint32(np.floor(scale[0] * in_size[0] * 1.0 / must_divide) * must_divide), 173 | np.uint32(np.floor(scale[1] * in_size[1] * 1.0 / must_divide) * must_divide)) 174 | 175 | if pad_r > 0: 176 | in_mask[:,:, :, pad_l:-pad_r] = torch.ones_like(input_tensor) 177 | input_for_regular[:, :, :, pad_l:-pad_r] = input_tensor 178 | else: 179 | in_mask[:, :, :, pad_l:] = torch.ones_like(input_tensor) 180 | input_for_regular[:, :, :, pad_l:] = input_tensor 181 | 182 | out = test_one_scale(gan, input_tensor, out_size, conf.must_divide, affine=shifts, return_tensor=True, size_instead_scale=True) 183 | # regular = transform(input_tensor, out_size, shifts) 184 | out_mask = _make_homography_mask(in_mask, out_size, shifts) 185 | 186 | out = util.tensor2im(out_mask * out + 1 - out_mask) 187 | # regular_out = util.tensor2im(out_mask * regular + 1 - out_mask) 188 | # out_pad[:, sz[3] - pad_l: sz[3] - pad_l + out_size[1], :] = out 189 | shift_str = "{1:0{0}d}_{3:0{2}d}".format(2 if shift1>=0 else 3, int(10*shift1), 2 if shift2>=0 else 3, int(10*shift2)) 190 | 191 | # out = np.rot90(out, 3) 192 | # regular_out = np.rot90(regular_out, 3) 193 | 194 | Image.fromarray(out, 'RGB').save(conf.output_dir_path + '/scale_%02d_%02d_transform %s_ingan.png' % (int(10*scale1), int(10*scale2), shift_str)) 195 | # Image.fromarray(regular_out, 'RGB').save(conf.output_dir_path + '/scale_%02d_%02d_transform %s_ref.png' % (scale1, scale2, shift_str)) 196 | print ind, '/', total, 'scale:', scale, 'shift:', shifts 197 | 198 | 199 | def main(): 200 | conf = Config().parse(create_dir_flag=False) 201 | conf.name = 'TEST_' + conf.name 202 | conf.output_dir_path = util.prepare_result_dir(conf) 203 | gan = InGAN(conf) 204 | 205 | try: 206 | 
gan.resume(conf.test_params_path, test_flag=True) 207 | [input_tensor] = util.read_data(conf) 208 | 209 | if conf.test_video: 210 | retarget_video(gan, input_tensor, define_video_scales(conf.test_vid_scales), 8, conf.output_dir_path) 211 | if conf.test_collage: 212 | generate_collage_and_outputs(conf, gan, input_tensor) 213 | if conf.test_non_rect: 214 | test_homo(conf, gan, input_tensor) 215 | 216 | print 'Done with %s' % conf.input_image_path 217 | 218 | except KeyboardInterrupt: 219 | raise 220 | except Exception as e: 221 | # print 'Something went wrong with %s (%d/%d), iter %dk' % (input_image_path, i, n_files, snapshot_iter) 222 | print_exc() 223 | 224 | 225 | if __name__ == '__main__': 226 | main() 227 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from configs import Config 2 | from InGAN import InGAN 3 | import os 4 | from util import Visualizer, read_data 5 | from traceback import print_exc 6 | 7 | 8 | # Load configuration 9 | conf = Config().parse() 10 | 11 | # Prepare data 12 | input_images = read_data(conf) 13 | 14 | # Create complete model 15 | gan = InGAN(conf) 16 | 17 | # If required, fine-tune from some checkpoint 18 | if conf.resume is not None: 19 | gan.resume(os.path.join(conf.resume)) 20 | 21 | # Define visualizer to monitor learning process 22 | visualizer = Visualizer(gan, conf, input_images) 23 | 24 | # Main training loop 25 | for i in range(conf.max_iters + 1): 26 | 27 | # Train a single iteration on the current data instance 28 | try: 29 | gan.train_one_iter(i, input_images) 30 | except KeyboardInterrupt: 31 | raise 32 | except Exception as e: 33 | print 'Something went wrong in iteration %d, While training.' % i 34 | print_exc() 35 | 36 | # Take care of all testing, saving and presenting of current results and status 37 | try: 38 | visualizer.test_and_display(i) 39 | except KeyboardInterrupt: 40 | raise 41 | except Exception as e: 42 | print 'Something went wrong in iteration %d, While testing or visualizing.' % i 43 | print_exc() 44 | 45 | # Save snapshot when needed 46 | try: 47 | if i > 0 and not i % conf.save_snapshot_freq: 48 | gan.save(os.path.join(conf.output_dir_path, 'checkpoint_%07d.pth.tar' % i)) 49 | del gan 50 | gan = InGAN(conf) 51 | gan.resume(os.path.join(conf.output_dir_path, 'checkpoint_%07d.pth.tar' % i)) 52 | visualizer.gan = gan 53 | except KeyboardInterrupt: 54 | raise 55 | except Exception as e: 56 | print 'Something went wrong in iteration %d, While saving snapshot.' 
% i 57 | print_exc() 58 | -------------------------------------------------------------------------------- /train_supp_mat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import Queue 4 | import subprocess 5 | 6 | base_dir = './side/' 7 | abl_args = {'geo_side': []} 8 | 9 | 10 | def experiment_was_not_already_exec(exp_name): 11 | for nm in os.listdir('results'): 12 | if nm.startswith(exp_name): 13 | return False 14 | return True 15 | 16 | 17 | class Worker(threading.Thread): 18 | def __init__(self, inQ, gpu_id): 19 | super(Worker, self).__init__() 20 | self.inQ = inQ 21 | self.daemon = True 22 | self.env = os.environ.copy() # copy of environment 23 | self.env['CUDA_VISIBLE_DEVICES'] = '{:d}'.format(gpu_id) 24 | self.start() 25 | 26 | def run(self): 27 | while True: 28 | try: 29 | exp_name, item = self.inQ.get() 30 | except Queue.Empty: 31 | break 32 | # verify that this experiment was not executed already 33 | if experiment_was_not_already_exec(exp_name): 34 | subprocess.call(item, env=self.env) 35 | self.inQ.task_done() 36 | 37 | 38 | def main(): 39 | q = Queue.Queue() 40 | workers = [Worker(q, gpu_id) for gpu_id in [0, 1]] 41 | for imgname in os.listdir(base_dir): 42 | full_img_name = os.path.join(base_dir, imgname) 43 | short_name = os.path.splitext(imgname)[0] 44 | cmd = ['python', 'train.py', '--input_image_path', full_img_name, '--gpu_id', '0'] 45 | for aname, aa in abl_args.items(): 46 | exp_name = '{}_{}'.format(short_name, aname) 47 | full_cmd = cmd + aa + ['--name', exp_name] 48 | q.put((exp_name, full_cmd)) 49 | q.join() 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from matplotlib import pyplot as plt, gridspec 4 | import os 5 | import glob 6 | from time import strftime, localtime 7 | from shutil import copy 8 | from scipy.misc import imresize 9 | import torch 10 | 11 | 12 | def read_data(conf): 13 | input_images = [read_shave_tensorize(path, conf.must_divide) for path in conf.input_image_path] 14 | return input_images 15 | 16 | 17 | def read_shave_tensorize(path, must_divide): 18 | input_np = (np.array(Image.open(path).convert('RGB')) / 255.0) 19 | 20 | input_np_shaved = input_np[:(input_np.shape[0] // must_divide) * must_divide, 21 | :(input_np.shape[1] // must_divide) * must_divide, 22 | :] 23 | 24 | input_tensor = im2tensor(input_np_shaved) 25 | 26 | return input_tensor 27 | 28 | 29 | def tensor2im(image_tensors, imtype=np.uint8): 30 | 31 | if not isinstance(image_tensors, list): 32 | image_tensors = [image_tensors] 33 | 34 | image_numpys = [] 35 | for image_tensor in image_tensors: 36 | # Note that tensors are shifted to be in [-1,1] 37 | image_numpy = image_tensor.detach().cpu().float().numpy() 38 | 39 | if np.ndim(image_numpy) == 4: 40 | image_numpy = image_numpy.transpose((0, 2, 3, 1)) 41 | 42 | image_numpy = np.round((image_numpy.squeeze(0) + 1) / 2.0 * 255.0) 43 | image_numpys.append(image_numpy.astype(imtype)) 44 | 45 | if len(image_numpys) == 1: 46 | image_numpys = image_numpys[0] 47 | 48 | return image_numpys 49 | 50 | 51 | def im2tensor(image_numpy, int_flag=False): 52 | # the int flag indicates whether the input image is integer (and [0,255]) or float ([0,1]) 53 | if int_flag: 54 | image_numpy /= 255.0 55 | # Undo the tensor shifting (see tensor2im 
function) 56 | transformed_image = np.transpose(image_numpy, (2, 0, 1)) * 2.0 - 1.0 57 | return torch.FloatTensor(transformed_image).unsqueeze(0).cuda() 58 | 59 | 60 | def random_size(orig_size, curriculum=True, i=None, iter_for_max_range=None, must_divide=8.0, 61 | min_scale=0.25, max_scale=2.0, max_transform_magniutude=0.3): 62 | cur_max_scale = 1.0 + (max_scale - 1.0) * np.clip(1.0 * i / iter_for_max_range, 0, 1) if curriculum else max_scale 63 | cur_min_scale = 1.0 + (min_scale - 1.0) * np.clip(1.0 * i / iter_for_max_range, 0, 1) if curriculum else min_scale 64 | cur_max_transform_magnitude = (max_transform_magniutude * np.clip(1.0 * i / iter_for_max_range, 0, 1) 65 | if curriculum else max_transform_magniutude) 66 | 67 | # set random transformation magnitude. scalar = affine, pair = homography. 68 | random_affine = -cur_max_transform_magnitude + 2 * cur_max_transform_magnitude * np.random.rand(2) 69 | 70 | # set new size for the output image 71 | new_size = np.array(orig_size) * (cur_min_scale + (cur_max_scale - cur_min_scale) * np.random.rand(2)) 72 | 73 | return tuple(np.uint32(np.ceil(new_size * 1.0 / must_divide) * must_divide)), random_affine 74 | 75 | 76 | def image_concat(g_preds, d_preds=None, size=None): 77 | hsize = g_preds[0].shape[0] + 6 if size is None else size[0] 78 | results = [] 79 | if d_preds is None: 80 | d_preds = [None] * len(g_preds) 81 | for g_pred, d_pred in zip(g_preds, d_preds): 82 | # noinspection PyUnresolvedReferences 83 | dsize = g_pred.shape[1] if size is None or size[1] is None else size[1] 84 | result = np.ones([(1 + (d_pred is not None)) * hsize, dsize, 3]) * 255 85 | if d_pred is not None: 86 | d_pred_new = imresize((np.concatenate([d_pred] * 3, 2) - 128) * 2, g_pred.shape[0:2], interp='nearest') 87 | result[hsize-g_pred.shape[0]:hsize+g_pred.shape[0], :g_pred.shape[1], :] = np.concatenate([g_pred, 88 | d_pred_new], 0) 89 | else: 90 | result[hsize - g_pred.shape[0]:, :, :] = g_pred 91 | results.append(np.uint8(np.round(result))) 92 | 93 | return np.concatenate(results, 1) 94 | 95 | 96 | def save_image(image_tensor, image_path): 97 | image_pil = Image.fromarray(tensor2im(image_tensor), 'RGB') 98 | image_pil.save(image_path) 99 | 100 | 101 | def get_scale_weights(i, max_i, start_factor, input_shape, min_size, num_scales_limit, scale_factor): 102 | num_scales = np.min([np.int(np.ceil(np.log(np.min(input_shape) * 1.0 / min_size) 103 | / np.log(scale_factor))), num_scales_limit]) 104 | 105 | # if i > max_i * 2: 106 | # i = max_i * 2 107 | 108 | factor = start_factor ** ((max_i - i) * 1.0 / max_i) 109 | 110 | un_normed_weights = factor ** np.arange(num_scales) 111 | weights = un_normed_weights / np.sum(un_normed_weights) 112 | # 113 | # np.clip(i, 0, max_i) 114 | # 115 | # un_normed_weights = np.exp(-((np.arange(num_scales) - (max_i - i) * num_scales * 1.0 / max_i) ** 2) / (2 * sigma ** 2)) 116 | # weights = un_normed_weights / np.sum(un_normed_weights) 117 | 118 | return weights 119 | 120 | 121 | class Visualizer: 122 | def __init__(self, gan, conf, test_inputs): 123 | self.gan = gan 124 | self.conf = conf 125 | self.G_loss = [None] * conf.max_iters 126 | self.D_loss_real = [None] * conf.max_iters 127 | self.D_loss_fake = [None] * conf.max_iters 128 | 129 | self.test_inputs = test_inputs 130 | self.test_input_sizes = [test_input.shape[2:] for test_input in test_inputs] 131 | 132 | if conf.reconstruct_loss_stop_iter > 0: 133 | self.Rec_loss = [None] * conf.max_iters 134 | 135 | def recreate_fig(self): 136 | self.fig = plt.figure(figsize=(18, 9)) 137 | gs 
= gridspec.GridSpec(8, 8) 138 | self.result = self.fig.add_subplot(gs[0:8, 0:4]) 139 | self.gan_loss = self.fig.add_subplot(gs[0:2, 5:8]) 140 | self.reconstruct_loss = self.fig.add_subplot(gs[3:5, 5:8]) 141 | self.reconstruction = self.fig.add_subplot(gs[6:8, 5:6]) 142 | self.real_example = self.fig.add_subplot(gs[7, 6]) 143 | self.d_map_real = self.fig.add_subplot(gs[7, 7]) 144 | 145 | # First plot data 146 | self.plot_gan_loss = self.gan_loss.plot([], [], 'b-', 147 | [], [], 'c--', 148 | [], [], 'r--') 149 | self.gan_loss.legend(('Generator loss', 'Discriminator loss (real image)', 'Discriminator loss (fake image)')) 150 | self.gan_loss.set_ylim(0, 1) 151 | 152 | if self.conf.reconstruct_loss_stop_iter > 0: 153 | self.plot_reconstruct_loss = self.reconstruct_loss.semilogy([], []) 154 | 155 | # Set titles 156 | self.gan_loss.set_title('Gan Losses') 157 | self.reconstruct_loss.set_title('Reconstruction Loss') 158 | self.reconstruction.set_title('Reconstruction') 159 | self.d_map_real.set_xlabel('Current Discriminator \n map for real example') 160 | self.real_example.set_xlabel('Real example') 161 | self.result.set_title('Current result') 162 | 163 | self.result.axes.get_xaxis().set_visible(False) 164 | self.result.axes.get_yaxis().set_visible(False) 165 | self.reconstruction.axes.get_xaxis().set_visible(False) 166 | self.reconstruction.axes.get_yaxis().set_visible(False) 167 | self.d_map_real.axes.get_yaxis().set_visible(False) 168 | self.real_example.axes.get_yaxis().set_visible(False) 169 | self.result.axes.get_yaxis().set_visible(False) 170 | 171 | def test_and_display(self, i): 172 | if not i % self.conf.print_freq and i > 0: 173 | self.G_loss[i-self.conf.print_freq:i] = self.gan.losses_G_gan.detach().cpu().float().numpy().tolist() 174 | self.D_loss_real[i-self.conf.print_freq:i] = self.gan.losses_D_real.detach().cpu().float().numpy().tolist() 175 | self.D_loss_fake[i-self.conf.print_freq:i] = self.gan.losses_D_fake.detach().cpu().float().numpy().tolist() 176 | if self.conf.reconstruct_loss_stop_iter > i: 177 | self.Rec_loss[i-self.conf.print_freq:i] = self.gan.losses_G_reconstruct.detach().cpu().float().numpy().tolist() 178 | 179 | if self.conf.reconstruct_loss_stop_iter < i: 180 | print('iter: %d, G_loss: %f, D_loss_real: %f, D_loss_fake: %f, LR: %f' % 181 | (i, self.G_loss[i-1], self.D_loss_real[i-1], self.D_loss_fake[i-1], 182 | self.gan.lr_scheduler_G.get_lr()[0])) 183 | else: 184 | print('iter: %d, G_loss: %f, D_loss_real: %f, D_loss_fake: %f, Rec_loss: %f, LR: %f' % 185 | (i, self.G_loss[i-1], self.D_loss_real[i-1], self.D_loss_fake[i-1], self.Rec_loss[i-1], 186 | self.gan.lr_scheduler_G.get_lr()[0])) 187 | 188 | if not i % self.conf.display_freq and i > 0: 189 | plt.gcf().clear() 190 | plt.close() 191 | self.recreate_fig() 192 | 193 | # choice = np.random.randint(0, len(self.test_inputs)) 194 | # test_input, test_input_size = self.test_inputs[choice], self.test_input_sizes[choice] 195 | # # Determine output size of G (dynamic change) 196 | # output_size, rand_h = random_size(orig_size=test_input_size, 197 | # curriculum=self.conf.curriculum, 198 | # i=i, 199 | # iter_for_max_range=self.conf.iter_for_max_range, 200 | # must_divide=self.conf.must_divide, 201 | # min_scale=self.conf.min_scale, 202 | # max_scale=self.conf.max_scale) 203 | # 204 | # g_preds, d_preds, reconstructs = self.gan.test(test_input, output_size, rand_h, test_input_size) 205 | 206 | g_preds = [self.gan.input_tensor_noised, self.gan.G_pred] 207 | d_preds = 
[self.gan.D.forward(self.gan.input_tensor_noised.detach(), self.gan.scale_weights), 208 | self.gan.d_pred_fake] 209 | reconstructs = self.gan.reconstruct 210 | input_size = self.gan.input_tensor_noised.shape[2:] 211 | 212 | result = image_concat(tensor2im(g_preds), tensor2im(d_preds), (input_size[0]*2, input_size[1]*2)) 213 | self.plot_gan_loss[0].set_data(range(i), self.G_loss[:i]) 214 | self.plot_gan_loss[1].set_data(range(i), self.D_loss_real[:i]) 215 | self.plot_gan_loss[2].set_data(range(i), self.D_loss_fake[:i]) 216 | self.gan_loss.set_xlim(0, i) 217 | 218 | if self.conf.reconstruct_loss_stop_iter > i: 219 | self.plot_reconstruct_loss[0].set_data(range(i), self.Rec_loss[:i]) 220 | self.reconstruct_loss.set_ylim(np.min(self.Rec_loss[:i]), np.max(self.Rec_loss[:i])) 221 | self.reconstruct_loss.set_xlim(0, i) 222 | 223 | self.result.imshow(np.clip(result, 0, 255), vmin=0, vmax=255) 224 | self.real_example.imshow(np.clip(tensor2im(self.gan.real_example[0:1, :, :, :]), 0, 255), vmin=0, vmax=255) 225 | self.d_map_real.imshow(self.gan.d_pred_real[0:1, :, :, :].detach().cpu().float().numpy().squeeze(), 226 | cmap='gray', vmin=0, vmax=1) 227 | if self.conf.reconstruct_loss_stop_iter > i: 228 | self.reconstruction.imshow(np.clip(image_concat([tensor2im(reconstructs)]), 0, 255), vmin=0, vmax=255) 229 | 230 | plt.savefig(self.conf.output_dir_path + '/monitor_%d' % i) 231 | 232 | save_image(self.gan.G_pred, self.conf.output_dir_path + '/result_iter_%d.png' % i) 233 | 234 | 235 | def prepare_result_dir(conf): 236 | # Create results directory 237 | conf.output_dir_path += '/' + conf.name + strftime('_%b_%d_%H_%M_%S', localtime()) 238 | os.makedirs(conf.output_dir_path) 239 | 240 | # Put a copy of all *.py files in results path, to be able to reproduce experimental results 241 | if conf.create_code_copy: 242 | local_dir = os.path.dirname(__file__) 243 | for py_file in glob.glob(local_dir + '/*.py'): 244 | copy(py_file, conf.output_dir_path) 245 | if conf.resume: 246 | copy(conf.resume, os.path.join(conf.output_dir_path, 'starting_checkpoint.pth.tar')) 247 | return conf.output_dir_path 248 | 249 | 250 | 251 | def homography_based_on_top_corners_x_shift(rand_h): 252 | p = np.array([[1., 1., -1, 0, 0, 0, -(-1. + rand_h[0]), -(-1. + rand_h[0]), -1. + rand_h[0]], 253 | [0, 0, 0, 1., 1., -1., 1., 1., -1.], 254 | [-1., -1., -1, 0, 0, 0, 1 + rand_h[1], 1 + rand_h[1], 1 + rand_h[1]], 255 | [0, 0, 0, -1, -1, -1, 1, 1, 1], 256 | [1, 0, -1, 0, 0, 0, 1, 0, -1], 257 | [0, 0, 0, 1, 0, -1, 0, 0, 0], 258 | [-1, 0, -1, 0, 0, 0, 1, 0, 1], 259 | [0, 0, 0, -1, 0, -1, 0, 0, 0], 260 | [0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.float32) 261 | b = np.zeros((9, 1), dtype=np.float32) 262 | b[8, 0] = 1. 263 | h = np.dot(np.linalg.inv(p), b) 264 | return torch.from_numpy(h).view(3, 3).cuda() 265 | 266 | 267 | def homography_grid(theta, size): 268 | r"""Generates a 2d flow field, given a batch of homography matrices :attr:`theta` 269 | Generally used in conjunction with :func:`grid_sample` to 270 | implement Spatial Transformer Networks. 
271 | 272 | Args: 273 | theta (Tensor): input batch of homography matrices (:math:`N \times 3 \times 3`) 274 | size (torch.Size): the target output image size (:math:`N \times C \times H \times W`) 275 | Example: torch.Size((32, 3, 24, 24)) 276 | 277 | Returns: 278 | output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) 279 | """ 280 | a = 1 281 | b = 1 282 | y, x = torch.meshgrid((torch.linspace(-b, b, np.int(size[-2]*a)), torch.linspace(-b, b, np.int(size[-1]*a)))) 283 | n = np.int(size[-2] * a) * np.int(size[-1] * a) 284 | hxy = torch.ones(n, 3, dtype=torch.float) 285 | hxy[:, 0] = x.contiguous().view(-1) 286 | hxy[:, 1] = y.contiguous().view(-1) 287 | out = hxy[None, ...].cuda().matmul(theta.transpose(1, 2)) 288 | # normalize 289 | out = out[:, :, :2] / out[:, :, 2:] 290 | return out.view(theta.shape[0], np.int(size[-2]*a), np.int(size[-1]*a), 2) 291 | 292 | 293 | def hist_match(source, template, mask_3ch): 294 | """ 295 | Adjust the pixel values of an image so that the histogram of its masked 296 | region matches that of a template image 297 | 298 | Arguments: 299 | ----------- 300 | source: np.ndarray 301 | Image to transform; the histogram is computed over the flattened 302 | array, using only pixels where mask_3ch > 128 303 | template: np.ndarray 304 | Template image; can have different dimensions from source 305 | Returns: 306 | ----------- 307 | matched: np.ndarray 308 | The transformed output image 309 | """ 310 | 311 | oldshape = source.shape 312 | source_masked = source.ravel()[mask_3ch.ravel() > 128] 313 | template = template.ravel() 314 | # get the set of unique pixel values and their corresponding indices and 315 | # counts 316 | s_values, bin_idx, s_counts = np.unique(source_masked, return_inverse=True, 317 | return_counts=True) 318 | t_values, t_counts = np.unique(template, return_counts=True) 319 | 320 | # take the cumsum of the counts and normalize by the number of pixels to 321 | # get the empirical cumulative distribution functions for the source and 322 | # template images (maps pixel value --> quantile) 323 | s_quantiles = np.cumsum(s_counts).astype(np.float64) 324 | s_quantiles /= s_quantiles[-1] 325 | t_quantiles = np.cumsum(t_counts).astype(np.float64) 326 | t_quantiles /= t_quantiles[-1] 327 | 328 | # interpolate linearly to find the pixel values in the template image 329 | # that correspond most closely to the quantiles in the source image 330 | interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) 331 | 332 | out = source.copy().ravel() 333 | out[mask_3ch.ravel() > 128] = interp_t_values[bin_idx] 334 | return out.reshape(oldshape) 335 | --------------------------------------------------------------------------------
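
A minimal usage sketch, assuming a CUDA device (the helpers in non_rect.py call .cuda() internally), of the non-rectangular warp that GeoTransform, supp_video.py and test.py all build on: solve a 3x3 homography from two corner x-shifts, expand it into a sampling grid with homography_grid, then resample with grid_sample. The image tensor, sizes and shift values below are illustrative only.

import torch
from torch.nn import functional as f
from non_rect import homography_based_on_top_corners_x_shift, homography_grid

image = torch.rand(1, 3, 64, 96).cuda()                          # dummy input in [0, 1], N x C x H x W
shifts = (0.2, -0.1)                                             # x-shifts of the two controlled corners
theta = homography_based_on_top_corners_x_shift(shifts)          # 3x3 homography solved from the corner constraints
target_size4d = torch.Size([1, 3, 64, 96])                       # desired output size, N x C x H x W
grid = homography_grid(theta.expand(1, -1, -1), target_size4d)   # N x H x W x 2 flow field
warped = f.grid_sample(image, grid, mode='bilinear', padding_mode='border')
print(warped.shape)                                              # torch.Size([1, 3, 64, 96])

GeoTransform additionally reflect-pads the input horizontally before warping, so the slanted borders have image content to sample from rather than only replicated edge pixels.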
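
A small worked example of get_scale_weights from util.py, with illustrative numbers: the number of discriminator scales follows from how many times scale_factor fits between min_size and the input size, the weights start skewed toward the full-resolution scale (factor = start_factor at i = 0), and they flatten to uniform as i approaches max_i.

from util import get_scale_weights

# log(256 / 16) / log(2) = 4 scales (capped at num_scales_limit)
early = get_scale_weights(i=0, max_i=10000, start_factor=0.1, input_shape=(256, 256),
                          min_size=16, num_scales_limit=5, scale_factor=2)
late = get_scale_weights(i=10000, max_i=10000, start_factor=0.1, input_shape=(256, 256),
                         min_size=16, num_scales_limit=5, scale_factor=2)
print(early)   # ~[0.9, 0.09, 0.009, 0.0009] -- the full-resolution scale dominates
print(late)    # [0.25, 0.25, 0.25, 0.25]    -- all scales weighted evenly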
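
A quick self-contained check of hist_match, assuming the repo's own environment; supp_video.py uses it to keep each warped frame tonally consistent with the input image, restricted to the valid (mask > 128) region. The arrays below are synthetic stand-ins for real frames.

import numpy as np
from util import hist_match

source = np.random.randint(0, 256, (32, 32, 3)).astype(np.uint8)    # stand-in for a generated frame
template = np.random.randint(0, 256, (32, 32, 3)).astype(np.uint8)  # stand-in for the input image
mask = np.full((32, 32, 3), 255, dtype=np.uint8)                     # match everywhere (mask > 128)
matched = hist_match(source, template, mask)
print(matched.shape, matched.dtype)                                  # (32, 32, 3) uint8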