├── .gitignore ├── LICENSE ├── README.md ├── architecture.png ├── data ├── __init__.py ├── aligned_dataset.py ├── aligned_dataset_resized.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── data_loader.py ├── image_folder.py └── single_dataset.py ├── download_models.sh ├── generate_masks.py ├── imgs ├── compare │ ├── 13_fake_B.png │ ├── 13_fake_B_flip.png │ ├── 13_real_A.png │ ├── 13_real_B.png │ ├── 18_fake_B.png │ ├── 18_fake_B_flip.png │ ├── 18_real_A.png │ └── 18_real_B.png ├── face_center │ ├── 0_fake_B.png │ ├── 0_real_A.png │ ├── 0_real_B.png │ ├── 106_fake_B.png │ ├── 106_real_A.png │ ├── 106_real_B.png │ ├── 111_fake_B.png │ ├── 111_real_A.png │ ├── 111_real_B.png │ ├── 11_fake_B.png │ ├── 11_real_A.png │ ├── 11_real_B.png │ ├── 14_fake_B.png │ ├── 14_real_A.png │ ├── 14_real_B.png │ ├── 1_fake_B.png │ ├── 1_real_A.png │ └── 1_real_B.png ├── face_random │ ├── 0_fake_B.png │ ├── 0_real_A.png │ ├── 0_real_B.png │ ├── 1_fake_B.png │ ├── 1_real_A.png │ └── 1_real_B.png ├── paris_center │ ├── 003_im_fake_B.png │ ├── 003_im_real_A.png │ ├── 003_im_real_B.png │ ├── 004_im_fake_B.png │ ├── 004_im_real_A.png │ ├── 004_im_real_B.png │ ├── 048_im_fake_B.png │ ├── 048_im_real_A.png │ └── 048_im_real_B.png └── paris_random │ ├── 006_im_fake_B.png │ ├── 006_im_real_A.png │ ├── 006_im_real_B.png │ ├── 055_im_fake_B.png │ ├── 055_im_real_A.png │ ├── 055_im_real_B.png │ ├── 073_im_fake_B.png │ ├── 073_im_real_A.png │ └── 073_im_real_B.png ├── models ├── __init__.py ├── face_shift_net │ ├── InnerFaceShiftTriple.py │ ├── InnerFaceShiftTripleFunction.py │ ├── __init__.py │ └── face_shiftnet_model.py ├── modules │ ├── __init__.py │ ├── denset_net.py │ ├── discrimators.py │ ├── losses.py │ ├── modules.py │ ├── shift_unet.py │ └── unet.py ├── networks.py ├── patch_soft_shift │ ├── __init__.py │ ├── innerPatchSoftShiftTriple.py │ ├── innerPatchSoftShiftTripleModule.py │ └── patch_soft_shiftnet_model.py ├── res_patch_soft_shift │ ├── __init__.py │ ├── innerResPatchSoftShiftTriple.py │ └── res_patch_soft_shiftnet_model.py ├── res_shift_net │ ├── __init__.py │ ├── innerResShiftTriple.py │ └── shiftnet_model.py └── shift_net │ ├── InnerCos.py │ ├── InnerCosFunction.py │ ├── InnerShiftTriple.py │ ├── InnerShiftTripleFunction.py │ ├── __init__.py │ ├── base_model.py │ └── shiftnet_model.py ├── notebooks ├── NewModule.ipynb ├── OptimizingShift.ipynb └── __init__.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── shift_layer.png ├── show_map.py ├── test.py ├── test_acc_shift.py ├── train.py └── util ├── NonparametricShift.py ├── __init__.py ├── html.py ├── png.py ├── poisson_blending.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.JPG 2 | *.jpg 3 | *.txt 4 | datasets/ 5 | checkpoints/ 6 | output/ 7 | results/ 8 | .vscode/ 9 | log/ 10 | logs/ 11 | *.swp 12 | *.pth 13 | *.pyc 14 | .idea/ 15 | *-checkpoint.py 16 | *.ipynb_checkpoints/ 17 | masks/ 18 | resized_paris/ 19 | fakeB/ 20 | shifted/ 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zhaoyi-Yan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the 
rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # New training strategy 2 | I release a new training strategy that helps deal with random-mask training by reducing color shifting, at the cost of about 30% extra training time. It is quite useful when we perform face inpainting. 3 | Set `--which_model_netG='face_unet_shift_triple'`, `--model='face_shiftnet'` and `--batchSize=1` to carry out the strategy (a full example command is sketched after the comparison table below). 4 | 5 | See some examples below; many approaches suffer from such `color shifting` when training with random masks on face datasets. 6 | 7 |
Input | Naive Shift | Flip Shift | Ground-truth
(comparison images are in `./imgs/compare/`: `13_real_A.png`, `13_fake_B.png`, `13_fake_B_flip.png`, `13_real_B.png`, and the corresponding `18_*.png` files)
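To carry out the strategy end-to-end, a training command might look like the sketch below (the dataset path and experiment name are illustrative placeholders; the flags are the ones this README lists, plus `--loadSize=256`, which it recommends for face datasets):

```bash
# Face-flip training strategy: face_shiftnet model with the flip-aware generator, batch size 1.
python train.py --batchSize=1 \
                --model='face_shiftnet' \
                --which_model_netG='face_unet_shift_triple' \
                --loadSize=256 \
                --name='face_flip_example' \
                --dataroot='./datasets/celeba-256/train'
```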
43 | 44 | 45 | 46 | 47 | Note: When you use the `face_flip` training strategy, it has some minor drawbacks: 48 | 1. It is not fully parallel compared with the original shift. 49 | 2. It can only be trained on the CPU or on a single GPU, and the batch size must be 1; otherwise an error occurs. 50 | 51 | If you want to overcome these drawbacks, you can optimize it by referring to the original shift. It is not difficult; however, I do not have time to do it. 52 | 53 | # Architecture 54 | (see `architecture.png`) 55 | 56 | # Shift layer 57 | (see `shift_layer.png`) 58 | 59 | ## Prerequisites 60 | - Linux or Windows. 61 | - Python 2 or Python 3. 62 | - CPU or NVIDIA GPU + CUDA CuDNN. 63 | - Tested on PyTorch >= **1.2** 64 | 65 | ## Getting Started 66 | ### Installation 67 | - Install PyTorch and dependencies from http://pytorch.org/ 68 | - Install the Python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate). 69 | 70 | ```bash 71 | pip install visdom 72 | pip install dominate 73 | ``` 74 | - Clone this repo: 75 | ```bash 76 | git clone https://github.com/Zhaoyi-Yan/Shift-Net_pytorch 77 | cd Shift-Net_pytorch 78 | 79 | ``` 80 | 81 | # Trained models 82 | Usually, I suggest you just pull the latest code and train by following the instructions. 83 | 84 | However, for now, several models have been trained and uploaded. 85 | 86 | | Mask | Paris | CelebaHQ_256 | 87 | | ---- | ---- | ---- | 88 | | center-mask | ok | ok | 89 | | random mask (from **partial conv**) | ok | ok | 90 | 91 | For the CelebaHQ_256 dataset: 92 | I select the first 2k images in CelebaHQ_256 for testing; the rest are for training. 93 | ``` 94 | python train.py --loadSize=256 --batchSize=1 --model='face_shiftnet' --name='celeb256' --which_model_netG='face_unet_shift_triple' --niter=30 --dataroot='./datasets/celeba-256/train' 95 | ``` 96 | Note: **`loadSize` should be `256` for face datasets, meaning the input image is directly resized to `256x256`.** 97 | 98 | The following are some results on CelebaHQ-256 and Paris. 99 | 100 | Specifically, for training models with random masks, we adopt the masks of **partial conv** (only masks whose masked-region ratio is 20~30% are used). 101 | 102 | 103 | 104 | 105 | 106 | 107 |
Input | Results | Ground-truth
(result images are in `./imgs/face_center/`, `./imgs/face_random/`, `./imgs/paris_center/`, and `./imgs/paris_random/`)
204 | 205 | For testing, please read the document carefully. 206 | 207 | A pretrained model for face center inpainting is available: 208 | ```bash 209 | bash download_models.sh 210 | ``` 211 | Rename `face_center_mask.pth` to `30_net_G.pth`, and put it in the folder `./log/face_center_mask_20_30` (create it if it does not exist). 212 | ```bash 213 | python test.py --which_model_netG='unet_shift_triple' --model='shiftnet' --name='face_center_mask_20_30' --which_epoch=30 --dataroot='./datasets/celeba-256/test' 214 | ``` 215 | 216 | For face random inpainting, the model is trained with `--which_model_netG='face_unet_shift_triple'` and `--model='face_shiftnet'`. Rename `face_flip_random.pth` to `30_net_G.pth` and set `--which_model_netG='face_unet_shift_triple'` and `--model='face_shiftnet'` when testing. 217 | 218 | Similarly, for Paris random inpainting, rename `paris_random_mask_20_30.pth` to `30_net_G.pth`, and put it in the folder `./log/paris_random_mask_20_30` (create it if it does not exist). 219 | Then test the model: 220 | ``` 221 | python test.py --which_epoch=30 --name='paris_random_mask_20_30' --offline_loading_mask=1 --testing_mask_folder='masks' --dataroot='./datasets/celeba-256/test' --norm='instance' 222 | ``` 223 | Note: your own masks should be prepared in the folder `testing_mask_folder` in advance. 224 | 225 | For other models, I think you know how to evaluate them. 226 | For models trained with a center mask, make sure `--mask_type='center' --offline_loading_mask=0`. 227 | 228 | 229 | ## Train models 230 | - Download your own inpainting datasets. Just put all the train/test images in some folder (e.g., ./xx/train/, ./xx/test/), and change `dataroot` in `options/base_options.py` to that path; that is all. 231 | 232 | - Train a model: 233 | Please read this paragraph carefully before running the code. 234 | 235 | Usually, we train/test `naive shift-net` with the `center` mask. 236 | 237 | ```bash 238 | python train.py --batchSize=1 --use_spectral_norm_D=1 --which_model_netD='basic' --mask_type='center' --which_model_netG='unet_shift_triple' --model='shiftnet' --shift_sz=1 --mask_thred=1 239 | ``` 240 | 241 | For some datasets, such as `CelebA`, some images are smaller than `256*256`, so you need to add `--loadSize=256` when training; **it is important**. 242 | 243 | - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. The checkpoints will be saved in `./log` by default. 244 | 245 | 246 | **DO NOT** set the batch size larger than 1 for `square` mask training; the performance degrades a lot (I don't know why...). 247 | For `random mask` (`mask_sub_type` is NOT `rect`, or your own random masks), the training batch size can be larger than 1 without hurting performance. 248 | 249 | Random mask training (both online and offline) is also supported. 250 | 251 | Personally, I suggest loading the masks offline (similar to **partial conv**). Please refer to the section **Masks**. 252 | 253 | ## Test the model 254 | 255 | **Keep the same settings as those used during the training phase to avoid errors or bad performance.** 256 | 257 | For example, if you train `patch soft shift-net`, then the following testing command is appropriate. 258 | ```bash 259 | python test.py --fuse=1/0 --which_model_netG='patch_soft_unet_shift_triple' --model='patch_soft_shiftnet' --shift_sz=3 --mask_thred=4 260 | ``` 261 | The test results will be saved to an HTML file here: `./results/`.
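As a quick recap of the face center-mask testing steps described above, the whole procedure can be sketched as follows (the download location and folder layout follow the defaults mentioned in this README; adjust the paths to your own setup):

```bash
# Download the released checkpoints (see download_models.sh).
bash download_models.sh
# Put the face center-mask weights where test.py expects them.
mkdir -p ./log/face_center_mask_20_30
mv face_center_mask.pth ./log/face_center_mask_20_30/30_net_G.pth
# Test with the same generator/model settings used for training.
python test.py --which_model_netG='unet_shift_triple' --model='shiftnet' \
               --name='face_center_mask_20_30' --which_epoch=30 \
               --dataroot='./datasets/celeba-256/test' \
               --mask_type='center' --offline_loading_mask=0
```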
262 | 263 | 264 | ## Masks 265 | Usually, **keep the same setting of masks between training and testing.** 266 | This is because the performance is highly related to the masks you applied in training. 267 | The consistency of training and testing masks is crucial to achieve good performance. 268 | 269 | | training | testing | 270 | | ---- | ---- | 271 | | center-mask | center-mask | 272 | | random-square | All | 273 | | random | All | 274 | | your own masks | your own masks | 275 | 276 | This means that if you train a model with `center-mask`, then you need to test it using `center-mask` (without even a one-pixel offset). For more info, you may refer to https://github.com/Zhaoyi-Yan/Shift-Net_pytorch/issues/125 277 | ### Training with online-generated masks 278 | We offer three types of online-generated masks: `center-mask, random_square and random_mask`. 279 | If you want to train on your own masks, similar to **partial conv**, refer to **Training on your own masks**. 280 | 281 | 282 | ### Training on your own masks 283 | It now supports both online generating and offline loading for training and testing. 284 | We generate masks online by default; however, set `--offline_loading_mask=1` when you want to train/test with your own prepared masks. 285 | **The prepared masks should be put in the folders `--training_mask_folder` and `--testing_mask_folder`.** 286 | 287 | ### Masks when training 288 | For each batch: 289 | - Generating online: masks are the same for each image in a batch (to save computation). 290 | - Loading offline: masks are loaded randomly for each image in a batch. 291 | 292 | ## Using Switchable Norm instead of Instance/Batch Norm 293 | For fixed mask training, `Switchable Norm` delivers better stability when batchSize > 1. **Please use switchable norm when you want to train with a large batch size; it is much more stable than instance norm or batch norm!** 294 | 295 | ### Extra variants 296 | 297 | **These 3 models are just for fun** 298 | 299 | For `res patch soft shift-net`: 300 | ```bash 301 | python train.py --batchSize=1 --which_model_netG='res_patch_soft_unet_shift_triple' --model='res_patch_soft_shiftnet' --shift_sz=3 --mask_thred=4 302 | ``` 303 | 304 | For `res naive shift-net`: 305 | ```bash 306 | python train.py --which_model_netG='res_unet_shift_triple' --model='res_shiftnet' --shift_sz=1 --mask_thred=1 307 | ``` 308 | 309 | For `patch soft shift-net`: 310 | ```bash 311 | python train.py --which_model_netG='patch_soft_unet_shift_triple' --model='patch_soft_shiftnet' --shift_sz=3 --mask_thred=4 312 | ``` 313 | 314 | DO NOT change `shift_sz` and `mask_thred`. Otherwise, it will error with high probability. 315 | 316 | For `patch soft shift-net` or `res patch soft shift-net`, you may set `fuse=1` to see whether it delivers better results (note: you need to keep the same setting between training and testing). 317 | 318 | 319 | ## New things that I want to add 320 | - [x] Make U-Net handle inputs of any size. 321 | - [x] Add more GANs, like spectral norm and relativistic GAN. 322 | - [x] Boost the efficiency of the shift layer. 323 | - [x] Directly resize the global_mask to get the mask in feature space. 324 | - [x] Visualization of flow. It is still experimental now. 325 | - [x] Extensions of Shift-Net. Still active in absorbing new features. 326 | - [x] Fix bug in guidance loss when adopting it in multi-gpu. 327 | - [x] Add composite L1 loss between mask loss and non-mask loss. 328 | - [x] Finish optimizing soft-shift. 329 | - [x] Add mask variants within a batch.
330 | - [x] Support online-generating/offline-loading of prepared masks for training/testing. 331 | - [x] Add VGG loss and TV loss 332 | - [x] Fix performance degradation when the batch size is larger than 1. 333 | - [x] Make it compatible with PyTorch 1.2 334 | - [ ] Training with mixed types of masks. 335 | - [ ] Try amp training 336 | - [ ] Try a self-attn discriminator (maybe it helps) 337 | 338 | ## Some extra things I have tried 339 | **Gated Conv**: I have tried gated conv (by replacing the normal convs of UNet with gated conv, except the innermost/outermost layer). 340 | However, I obtained no benefits. Maybe I should try replacing all layers with gated conv. I will try again when I am free. 341 | 342 | **Non-local block**: I added it, but the results seem worse. Maybe I haven't added the blocks at the proper position. (It increases the training time a lot, so I am not in favor of it.) 343 | 344 | ## Citation 345 | If you find this work useful or it gives you some insights, please cite: 346 | ``` 347 | @InProceedings{Yan_2018_Shift, 348 | author = {Yan, Zhaoyi and Li, Xiaoming and Li, Mu and Zuo, Wangmeng and Shan, Shiguang}, 349 | title = {Shift-Net: Image Inpainting via Deep Feature Rearrangement}, 350 | booktitle = {The European Conference on Computer Vision (ECCV)}, 351 | month = {September}, 352 | year = {2018} 353 | } 354 | ``` 355 | 356 | ## Acknowledgments 357 | We benefit a lot from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) 358 | -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/architecture.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/data/__init__.py -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import os.path 3 | import random 4 | import torchvision.transforms as transforms 5 | import torch 6 | import random 7 | from data.base_dataset import BaseDataset 8 | from data.image_folder import make_dataset 9 | from PIL import Image 10 | 11 | class AlignedDataset(BaseDataset): 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.dir_A = opt.dataroot 15 | self.A_paths = sorted(make_dataset(self.dir_A)) 16 | if self.opt.offline_loading_mask: 17 | self.mask_folder = self.opt.training_mask_folder if self.opt.isTrain else self.opt.testing_mask_folder 18 | self.mask_paths = sorted(make_dataset(self.mask_folder)) 19 | 20 | assert(opt.resize_or_crop == 'resize_and_crop') 21 | 22 | transform_list = [transforms.ToTensor(), 23 | transforms.Normalize((0.5, 0.5, 0.5), 24 | (0.5, 0.5, 0.5))] 25 | 26 | self.transform = transforms.Compose(transform_list) 27 | 28 | def __getitem__(self, index): 29 | A_path = self.A_paths[index] 30 | A = Image.open(A_path).convert('RGB') 31 | w, h = A.size 32 | 33 | if w < h: 34 | ht_1 = self.opt.loadSize * h // w 35 | wd_1 = self.opt.loadSize 36 | A = A.resize((wd_1, ht_1), Image.BICUBIC) 37 | else: 38 | wd_1 = self.opt.loadSize * w // h 39 | ht_1 = self.opt.loadSize 40 | A = A.resize((wd_1, ht_1), 
Image.BICUBIC) 41 | 42 | A = self.transform(A) 43 | h = A.size(1) 44 | w = A.size(2) 45 | w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) 46 | h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) 47 | 48 | A = A[:, h_offset:h_offset + self.opt.fineSize, 49 | w_offset:w_offset + self.opt.fineSize] 50 | 51 | if (not self.opt.no_flip) and random.random() < 0.5: 52 | A = torch.flip(A, [2]) 53 | 54 | # let B directly equals to A 55 | B = A.clone() 56 | A_flip = torch.flip(A, [2]) 57 | B_flip = A_flip.clone() 58 | 59 | # Just zero the mask is fine if not offline_loading_mask. 60 | mask = A.clone().zero_() 61 | if self.opt.offline_loading_mask: 62 | if self.opt.isTrain: 63 | mask = Image.open(self.mask_paths[random.randint(0, len(self.mask_paths)-1)]) 64 | else: 65 | mask = Image.open(self.mask_paths[index % len(self.mask_paths)]) 66 | mask = mask.resize((self.opt.fineSize, self.opt.fineSize), Image.NEAREST) 67 | mask = transforms.ToTensor()(mask) 68 | 69 | return {'A': A, 'B': B, 'A_F': A_flip, 'B_F': B_flip, 'M': mask, 70 | 'A_paths': A_path} 71 | 72 | def __len__(self): 73 | return len(self.A_paths) 74 | 75 | def name(self): 76 | return 'AlignedDataset' 77 | -------------------------------------------------------------------------------- /data/aligned_dataset_resized.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import os.path 3 | import random 4 | import torchvision.transforms as transforms 5 | import torch 6 | from data.base_dataset import BaseDataset 7 | from data.image_folder import make_dataset 8 | from PIL import Image 9 | 10 | class AlignedDatasetResized(BaseDataset): 11 | def initialize(self, opt): 12 | self.opt = opt 13 | self.root = opt.dataroot 14 | self.dir_A = opt.dataroot # More Flexible for users 15 | 16 | self.A_paths = sorted(make_dataset(self.dir_A)) 17 | 18 | assert(opt.resize_or_crop == 'resize_and_crop') 19 | 20 | transform_list = [transforms.ToTensor(), 21 | transforms.Normalize((0.5, 0.5, 0.5), 22 | (0.5, 0.5, 0.5))] 23 | 24 | self.transform = transforms.Compose(transform_list) 25 | 26 | def __getitem__(self, index): 27 | A_path = self.A_paths[index] 28 | A = Image.open(A_path).convert('RGB') 29 | 30 | A = A.resize ((self.opt.fineSize, self.opt.fineSize), Image.BICUBIC) 31 | 32 | A = self.transform(A) 33 | 34 | #if (not self.opt.no_flip) and random.random() < 0.5: 35 | # idx = [i for i in range(A.size(2) - 1, -1, -1)] # size(2)-1, size(2)-2, ... 
, 0 36 | # idx = torch.LongTensor(idx) 37 | # A = A.index_select(2, idx) 38 | 39 | # let B directly equals A 40 | B = A.clone() 41 | return {'A': A, 'B': B, 42 | 'A_paths': A_path} 43 | 44 | def __len__(self): 45 | return len(self.A_paths) 46 | 47 | def name(self): 48 | return 'AlignedDatasetResized' 49 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def load_data(): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | class BaseDataset(data.Dataset): 4 | def __init__(self): 5 | super(BaseDataset, self).__init__() 6 | 7 | def name(self): 8 | return 'BaseDataset' 9 | 10 | def initialize(self, opt): 11 | pass 12 | 13 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import torch.utils.data 3 | from data.base_data_loader import BaseDataLoader 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | if opt.dataset_mode == 'aligned': 8 | from data.aligned_dataset import AlignedDataset 9 | dataset = AlignedDataset() 10 | 11 | elif opt.dataset_mode == 'aligned_resized': 12 | from data.aligned_dataset_resized import AlignedDatasetResized 13 | dataset = AlignedDatasetResized() 14 | 15 | elif opt.dataset_mode == 'single': 16 | from data.single_dataset import SingleDataset 17 | dataset = SingleDataset() 18 | 19 | else: 20 | raise ValueError("Dataset [%s] not recognized." 
% opt.dataset_mode) 21 | 22 | print("dataset [%s] was created" % (dataset.name())) 23 | dataset.initialize(opt) 24 | return dataset 25 | 26 | 27 | class CustomDatasetDataLoader(BaseDataLoader): 28 | def name(self): 29 | return 'CustomDatasetDataLoader' 30 | 31 | def initialize(self, opt): 32 | BaseDataLoader.initialize(self, opt) 33 | self.dataset = CreateDataset(opt) 34 | 35 | self.dataloader = torch.utils.data.DataLoader( 36 | self.dataset, 37 | batch_size=opt.batchSize, 38 | shuffle=not opt.serial_batches, 39 | num_workers=int(opt.nThreads)) 40 | 41 | def load_data(self): 42 | return self 43 | 44 | def __len__(self): 45 | return min(len(self.dataset), self.opt.max_dataset_size) 46 | 47 | def __iter__(self): 48 | for i, data in enumerate(self.dataloader): 49 | if i*self.opt.batchSize >= self.opt.max_dataset_size: 50 | break 51 | yield data -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/single_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | 7 | 8 | class SingleDataset(BaseDataset): 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.root = opt.dataroot 12 | self.dir_A = os.path.join(opt.dataroot) 13 | 14 | # make_dataset returns paths of all images in one folder 15 | self.A_paths = make_dataset(self.dir_A) 16 | 17 | self.A_paths = sorted(self.A_paths) 18 | 19 | transform_list = [] 20 | if opt.resize_or_crop == 'resize_and_crop': 21 | transform_list.append(transforms.Scale(opt.loadSize)) 22 | 23 | if opt.isTrain and not opt.no_flip: 24 | transform_list.append(transforms.RandomHorizontalFlip()) 25 | 26 | if opt.resize_or_crop != 'no_resize': 27 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 28 | 29 | # Make it between [-1, 1], beacuse [(0-0.5)/0.5, (1-0.5)/0.5] 30 | transform_list += [transforms.ToTensor(), 31 | transforms.Normalize((0.5, 0.5, 0.5), 32 | (0.5, 0.5, 0.5))] 33 | 34 | self.transform = transforms.Compose(transform_list) 35 | 36 | def __getitem__(self, index): 37 | A_path = self.A_paths[index] 38 | 39 | A_img = Image.open(A_path).convert('RGB') 40 | 41 | A = self.transform(A_img) 42 | if self.opt.which_direction == 'BtoA': 43 | input_nc = self.opt.output_nc 44 | else: 45 | input_nc = self.opt.input_nc 46 | 47 | return {'A': A, 'A_paths': A_path} 48 | 49 | def __len__(self): 50 | return len(self.A_paths) 51 | 52 | def name(self): 53 | return 'SingleImageDataset' 54 | -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | # face model (Trained on CelebaHQ-256, the first 2k images are for testing, the rest are for training.) 2 | wget -c https://drive.google.com/open?id=1qvsWHVO9iXpEAPtwyRB25mklTmD0jgPV 3 | # face random mask model 4 | wget -c https://drive.google.com/open?id=1Pz9gkm2VYaEK3qMXnszJufsvRqbcXrjS 5 | # paris random mask model 6 | wget -c https://drive.google.com/open?id=14MzixaqYUdJNL5xGdVhSKI9jOfvGdr3M 7 | # paris center mask model 8 | wget -c https://drive.google.com/open?id=1nDkCdsqUdiEXfSjZ_P915gWeZELK0fo_ 9 | -------------------------------------------------------------------------------- /generate_masks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import numpy as np 3 | from options.train_options import TrainOptions 4 | import util.util as util 5 | import os 6 | from PIL import Image 7 | import glob 8 | 9 | mask_folder = 'masks/testing_masks' 10 | test_folder = './datasets/Paris/test' 11 | util.mkdir(mask_folder) 12 | 13 | opt = TrainOptions().parse() 14 | 15 | f = glob.glob(test_folder+'/*.png') 16 | print(f) 17 | 18 | for fl in f: 19 | mask = torch.zeros(opt.fineSize, opt.fineSize) 20 | if opt.mask_sub_type == 'fractal': 21 | assert 1==2, "It is broken now..." 22 | mask = util.create_walking_mask() # create an initial random mask. 
23 | 24 | elif opt.mask_sub_type == 'rect': 25 | mask, rand_t, rand_l = util.create_rand_mask(opt) 26 | 27 | elif opt.mask_sub_type == 'island': 28 | mask = util.wrapper_gmask(opt) 29 | 30 | print('Generating mask for test image: '+os.path.basename(fl)) 31 | util.save_image(mask.squeeze().numpy()*255, os.path.join(mask_folder, os.path.splitext(os.path.basename(fl))[0]+'_mask.png')) 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /imgs/compare/13_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_fake_B.png -------------------------------------------------------------------------------- /imgs/compare/13_fake_B_flip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_fake_B_flip.png -------------------------------------------------------------------------------- /imgs/compare/13_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_real_A.png -------------------------------------------------------------------------------- /imgs/compare/13_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_real_B.png -------------------------------------------------------------------------------- /imgs/compare/18_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_fake_B.png -------------------------------------------------------------------------------- /imgs/compare/18_fake_B_flip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_fake_B_flip.png -------------------------------------------------------------------------------- /imgs/compare/18_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_real_A.png -------------------------------------------------------------------------------- /imgs/compare/18_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_real_B.png -------------------------------------------------------------------------------- /imgs/face_center/0_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/0_fake_B.png -------------------------------------------------------------------------------- /imgs/face_center/0_real_A.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/0_real_A.png -------------------------------------------------------------------------------- /imgs/face_center/0_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/0_real_B.png -------------------------------------------------------------------------------- /imgs/face_center/106_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/106_fake_B.png -------------------------------------------------------------------------------- /imgs/face_center/106_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/106_real_A.png -------------------------------------------------------------------------------- /imgs/face_center/106_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/106_real_B.png -------------------------------------------------------------------------------- /imgs/face_center/111_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/111_fake_B.png -------------------------------------------------------------------------------- /imgs/face_center/111_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/111_real_A.png -------------------------------------------------------------------------------- /imgs/face_center/111_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/111_real_B.png -------------------------------------------------------------------------------- /imgs/face_center/11_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/11_fake_B.png -------------------------------------------------------------------------------- /imgs/face_center/11_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/11_real_A.png -------------------------------------------------------------------------------- /imgs/face_center/11_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/11_real_B.png -------------------------------------------------------------------------------- 
/imgs/face_center/14_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/14_fake_B.png -------------------------------------------------------------------------------- /imgs/face_center/14_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/14_real_A.png -------------------------------------------------------------------------------- /imgs/face_center/14_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/14_real_B.png -------------------------------------------------------------------------------- /imgs/face_center/1_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/1_fake_B.png -------------------------------------------------------------------------------- /imgs/face_center/1_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/1_real_A.png -------------------------------------------------------------------------------- /imgs/face_center/1_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/1_real_B.png -------------------------------------------------------------------------------- /imgs/face_random/0_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/0_fake_B.png -------------------------------------------------------------------------------- /imgs/face_random/0_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/0_real_A.png -------------------------------------------------------------------------------- /imgs/face_random/0_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/0_real_B.png -------------------------------------------------------------------------------- /imgs/face_random/1_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/1_fake_B.png -------------------------------------------------------------------------------- /imgs/face_random/1_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/1_real_A.png 
-------------------------------------------------------------------------------- /imgs/face_random/1_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/1_real_B.png -------------------------------------------------------------------------------- /imgs/paris_center/003_im_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/003_im_fake_B.png -------------------------------------------------------------------------------- /imgs/paris_center/003_im_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/003_im_real_A.png -------------------------------------------------------------------------------- /imgs/paris_center/003_im_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/003_im_real_B.png -------------------------------------------------------------------------------- /imgs/paris_center/004_im_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/004_im_fake_B.png -------------------------------------------------------------------------------- /imgs/paris_center/004_im_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/004_im_real_A.png -------------------------------------------------------------------------------- /imgs/paris_center/004_im_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/004_im_real_B.png -------------------------------------------------------------------------------- /imgs/paris_center/048_im_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/048_im_fake_B.png -------------------------------------------------------------------------------- /imgs/paris_center/048_im_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/048_im_real_A.png -------------------------------------------------------------------------------- /imgs/paris_center/048_im_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/048_im_real_B.png -------------------------------------------------------------------------------- /imgs/paris_random/006_im_fake_B.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/006_im_fake_B.png -------------------------------------------------------------------------------- /imgs/paris_random/006_im_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/006_im_real_A.png -------------------------------------------------------------------------------- /imgs/paris_random/006_im_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/006_im_real_B.png -------------------------------------------------------------------------------- /imgs/paris_random/055_im_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/055_im_fake_B.png -------------------------------------------------------------------------------- /imgs/paris_random/055_im_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/055_im_real_A.png -------------------------------------------------------------------------------- /imgs/paris_random/055_im_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/055_im_real_B.png -------------------------------------------------------------------------------- /imgs/paris_random/073_im_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/073_im_fake_B.png -------------------------------------------------------------------------------- /imgs/paris_random/073_im_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/073_im_real_A.png -------------------------------------------------------------------------------- /imgs/paris_random/073_im_real_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/073_im_real_B.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | def create_model(opt): 2 | model = None 3 | print(opt.model) 4 | if opt.model == 'shiftnet': 5 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized') 6 | from models.shift_net.shiftnet_model import ShiftNetModel 7 | model = ShiftNetModel() 8 | 9 | elif opt.model == 'res_shiftnet': 10 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized') 11 | from 
models.res_shift_net.shiftnet_model import ResShiftNetModel 12 | model = ResShiftNetModel() 13 | 14 | elif opt.model == 'patch_soft_shiftnet': 15 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized') 16 | from models.patch_soft_shift.patch_soft_shiftnet_model import PatchSoftShiftNetModel 17 | model = PatchSoftShiftNetModel() 18 | 19 | elif opt.model == 'res_patch_soft_shiftnet': 20 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized') 21 | from models.res_patch_soft_shift.res_patch_soft_shiftnet_model import ResPatchSoftShiftNetModel 22 | model = ResPatchSoftShiftNetModel() 23 | 24 | elif opt.model == 'face_shiftnet': 25 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized') 26 | from models.face_shift_net.face_shiftnet_model import FaceShiftNetModel 27 | model = FaceShiftNetModel() 28 | else: 29 | raise ValueError("Model [%s] not recognized." % opt.model) 30 | model.initialize(opt) 31 | print("model [%s] was created" % (model.name())) 32 | return model 33 | -------------------------------------------------------------------------------- /models/face_shift_net/InnerFaceShiftTriple.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import util.util as util 4 | from .InnerFaceShiftTripleFunction import InnerFaceShiftTripleFunction 5 | 6 | 7 | class InnerFaceShiftTriple(nn.Module): 8 | def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3, device='gpu'): 9 | super(InnerFaceShiftTriple, self).__init__() 10 | 11 | self.shift_sz = shift_sz 12 | self.stride = stride 13 | self.mask_thred = mask_thred 14 | self.triple_weight = triple_weight 15 | self.layer_to_last = layer_to_last 16 | self.device = device 17 | self.show_flow = False # default false. Do not change it to be true, it is computation-heavy. 18 | self.flow_srcs = None # Indicating the flow src(pixles in non-masked region that will shift into the masked region) 19 | 20 | 21 | def set_mask(self, mask_global): 22 | self.mask_all = util.cal_feat_mask(mask_global, self.layer_to_last) 23 | 24 | def _split_mask(self, cur_bsize): 25 | # get the visible indexes of gpus and assign correct mask to set of images 26 | cur_device = torch.cuda.current_device() 27 | self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :] 28 | 29 | 30 | # If mask changes, then need to set cal_fix_flag true each iteration. 
31 | def forward(self, input, flip_feat=None): 32 | self.bz, self.c, self.h, self.w = input.size() 33 | if self.device != 'cpu': 34 | self._split_mask(self.bz) 35 | else: 36 | self.cur_mask = self.mask_all 37 | self.mask = self.cur_mask 38 | self.mask_flip = torch.flip(self.mask, [3]) 39 | 40 | self.flag = util.cal_flag_given_mask_thred(self.mask, self.shift_sz, self.stride, self.mask_thred) 41 | self.flag_flip = util.cal_flag_given_mask_thred(self.mask_flip, self.shift_sz, self.stride, self.mask_thred) 42 | 43 | final_out = InnerFaceShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.flag_flip, self.show_flow, flip_feat) 44 | if self.show_flow: 45 | self.flow_srcs = InnerFaceShiftTripleFunction.get_flow_src() 46 | 47 | innerFeat = input.clone().narrow(1, self.c//2, self.c//2) 48 | return final_out, innerFeat 49 | 50 | def get_flow(self): 51 | return self.flow_srcs 52 | 53 | def set_flow_true(self): 54 | self.show_flow = True 55 | 56 | def set_flow_false(self): 57 | self.show_flow = False 58 | 59 | def __repr__(self): 60 | return self.__class__.__name__+ '(' \ 61 | + ' ,triple_weight ' + str(self.triple_weight) + ')' 62 | -------------------------------------------------------------------------------- /models/face_shift_net/InnerFaceShiftTripleFunction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from util.NonparametricShift import Modified_NonparametricShift 3 | import torch 4 | import util.util as util 5 | import time 6 | 7 | # This script offers a version of shift from multi-references. 8 | class InnerFaceShiftTripleFunction(torch.autograd.Function): 9 | ctx = None 10 | 11 | @staticmethod 12 | def forward(ctx, input, shift_sz, stride, triple_w, flag, flag_flip, show_flow, flip_feat=None): 13 | InnerFaceShiftTripleFunction.ctx = ctx 14 | assert input.dim() == 4, "Input Dim has to be 4" 15 | ctx.triple_w = triple_w 16 | ctx.flag = flag 17 | ctx.flag_flip = flag_flip 18 | ctx.show_flow = show_flow 19 | 20 | ctx.bz, c_real, ctx.h, ctx.w = input.size() 21 | c = c_real 22 | 23 | ctx.ind_lst = torch.Tensor(ctx.bz, ctx.h * ctx.w, ctx.h * ctx.w).zero_().to(input) 24 | ctx.ind_lst_flip = ctx.ind_lst.clone() 25 | 26 | # former and latter are all tensors 27 | former_all = input.narrow(1, 0, c//2) ### decoder feature 28 | latter_all = input.narrow(1, c//2, c//2) ### encoder feature 29 | shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all).zero_() # addition feature 30 | 31 | if not flip_feat is None: 32 | assert flip_feat.size() == former_all.size(), "flip_feat size should be equal to former size" 33 | 34 | ctx.flag = ctx.flag.to(input).long() 35 | ctx.flag_flip = ctx.flag_flip.to(input).long() 36 | 37 | # None batch version 38 | Nonparm = Modified_NonparametricShift() 39 | ctx.shift_offsets = [] 40 | 41 | for idx in range(ctx.bz): 42 | flag_cur = ctx.flag[idx] 43 | flag_cur_flip = ctx.flag_flip[idx] 44 | latter = latter_all.narrow(0, idx, 1) ### encoder feature 45 | former = former_all.narrow(0, idx, 1) ### decoder feature 46 | 47 | #GET COSINE, RESHAPED LATTER AND ITS INDEXES 48 | cosine, latter_windows, i_2, i_3, i_1 = Nonparm.cosine_similarity(former.clone().squeeze(), latter.clone().squeeze(), 1, stride, flag_cur) 49 | cosine_flip, latter_windows_flip, _, _, _ = Nonparm.cosine_similarity(former.clone().squeeze(), flip_feat.clone().squeeze(), 1, stride, flag_cur_flip) 50 | 51 | # compare which is the bigger one. 
52 | cosine_con = torch.cat([cosine, cosine_flip], dim=1) 53 | _, indexes_con = torch.max(cosine_con, dim=1) 54 | # then ori_larger is (non_mask_count*1), 55 | # 1:indicating the original feat is better for shift. 56 | # 0:indicating the flippled feat is a better one. 57 | ori_larger = (indexes_con < cosine.size(1)).long().view(-1,1) 58 | 59 | 60 | ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY 61 | _, indexes = torch.max(cosine, dim=1) 62 | _, indexes_flip = torch.max(cosine_flip, dim=1) 63 | 64 | # SET TRANSITION MATRIX 65 | mask_indexes = (flag_cur == 1).nonzero(as_tuple=False) 66 | non_mask_indexes = (flag_cur == 0).nonzero(as_tuple=False)[indexes] 67 | # then remove some indexes where we should select flip feat according to ori_larger 68 | mask_indexes_select_index = (mask_indexes.squeeze() * ori_larger.squeeze()).nonzero(as_tuple=False) 69 | mask_indexes_select = mask_indexes[mask_indexes_select_index].squeeze() 70 | ctx.ind_lst[idx][mask_indexes_select, non_mask_indexes] = 1 71 | 72 | 73 | 74 | non_mask_indexes_flip = (flag_cur_flip == 0).nonzero(as_tuple=False)[indexes_flip] 75 | # then remove some indexes where we should select ori feat according to 1-ori_larger 76 | mask_indexes_flip_select_index = (mask_indexes.squeeze() * (1 - ori_larger.squeeze())).nonzero(as_tuple=False) 77 | mask_indexes_flip_select = mask_indexes[mask_indexes_flip_select_index].squeeze() 78 | ctx.ind_lst_flip[idx][mask_indexes_flip_select, non_mask_indexes_flip] = 1 79 | 80 | 81 | # GET FINAL SHIFT FEATURE 82 | ori_tmp = Nonparm._paste(latter_windows, ctx.ind_lst[idx], i_2, i_3, i_1) 83 | ori_tmp_flip = Nonparm._paste(latter_windows_flip, ctx.ind_lst_flip[idx], i_2, i_3, i_1) 84 | 85 | # combine the two features by directly adding, it is ok. 86 | shift_masked_all[idx] = ori_tmp + ori_tmp_flip 87 | 88 | if ctx.show_flow: 89 | shift_offset = torch.stack([non_mask_indexes.squeeze() // ctx.w, non_mask_indexes.squeeze() % ctx.w], dim=-1) 90 | ctx.shift_offsets.append(shift_offset) 91 | 92 | if ctx.show_flow: 93 | # Note: Here we assume that each mask is the same for the same batch image. 94 | ctx.shift_offsets = torch.cat(ctx.shift_offsets, dim=0).float() # make it cudaFloatTensor 95 | # Assume mask is the same for each image in a batch. 96 | mask_nums = ctx.shift_offsets.size(0)//ctx.bz 97 | ctx.flow_srcs = torch.zeros(ctx.bz, 3, ctx.h, ctx.w).type_as(input) 98 | 99 | for idx in range(ctx.bz): 100 | shift_offset = ctx.shift_offsets.narrow(0, idx*mask_nums, mask_nums) 101 | # reconstruct the original shift_map. 102 | shift_offsets_map = torch.zeros(1, ctx.h, ctx.w, 2).type_as(input) 103 | shift_offsets_map[:, (flag_cur == 1).nonzero(as_tuple=False).squeeze() // ctx.w, (flag_cur == 1).nonzero(as_tuple=False).squeeze() % ctx.w, :] = \ 104 | shift_offset.unsqueeze(0) 105 | # It is indicating the pixels(non-masked) that will shift the the masked region. 106 | flow_src = util.highlight_flow(shift_offsets_map, flag_cur.unsqueeze(0)) 107 | ctx.flow_srcs[idx] = flow_src 108 | 109 | return torch.cat((former_all, latter_all, shift_masked_all), 1) 110 | 111 | 112 | @staticmethod 113 | def get_flow_src(): 114 | return InnerFaceShiftTripleFunction.ctx.flow_srcs 115 | 116 | 117 | # How it works, the extra grad from feat_flip will be enchaned the grad of the second part of the layer (when input I). 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | ind_lst = ctx.ind_lst 121 | ind_lst_flip = ctx.ind_lst_flip 122 | 123 | c = grad_output.size(1) 124 | 125 | # # the former and the latter are keep original. 
Only the third part is shifted. 126 | grad_former_all = grad_output[:, 0:c//3, :, :] 127 | grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone() 128 | grad_shifted_all = grad_output[:, c*2//3:c, :, :].clone() 129 | 130 | for idx in range(ctx.bz): 131 | 132 | # C: content, pixels in the masked region of the former part. 133 | # S: style, pixels in the non-masked region of the latter part. 134 | # N: the shifted feature, the new feature that will be used as the third part of the feature maps. 135 | # W_mat: ind_lst[idx], the shift matrix. 136 | # Note: **only the masked region in N has values**. 137 | 138 | # The gradient of the shift feature should be added back to the latter part (to be precise: S). 139 | # `ind_lst[idx][i,j] = 1` means that the i-th pixel will **be replaced** by the j-th pixel in the forward pass. 140 | # When applying `S mm W_mat`, S is transferred to N. 141 | # (pixels in the non-masked region of the latter part are shifted into the masked region of the third part.) 142 | # However, we need to transfer the gradient of the third part back to S. 143 | # This means the gradient in S will **`be replaced`** (to be precise, enhanced) by that of N. 144 | 145 | # So we need to transpose `W_mat` 146 | W_mat_t = ind_lst[idx].t() 147 | W_mat_t_flip = ind_lst_flip[idx].t() 148 | 149 | grad = grad_shifted_all[idx].view(c//3, -1).t() 150 | 151 | # only the grad of the points that these (non-masked) locations contribute to is kept 152 | grad_shifted_weighted = torch.mm(W_mat_t, grad) 153 | # only the grad of the points that these (non-masked, flipped) locations contribute to is kept 154 | grad_shifted_weighted_flip = torch.mm(W_mat_t_flip, grad) 155 | 156 | # Then transpose it back 157 | grad_shifted_weighted = grad_shifted_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w) 158 | grad_shifted_weighted_flip = grad_shifted_weighted_flip.t().contiguous().view(1, c//3, ctx.h, ctx.w) 159 | 160 | grad_shifted_weighted_all = grad_shifted_weighted + grad_shifted_weighted_flip 161 | 162 | grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_shifted_weighted_all.mul(ctx.triple_w)) 163 | 164 | # note: grad_input has the same number of channels as the original input (former + latter), as there is no mask input for now. 165 | grad_input = torch.cat([grad_former_all, grad_latter_all], 1) 166 | 167 | return grad_input, None, None, None, None, None, None, None 168 | -------------------------------------------------------------------------------- /models/face_shift_net/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/face_shift_net/face_shiftnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import util.util as util 4 | from models import networks 5 | from models.shift_net.base_model import BaseModel 6 | import time 7 | import torchvision.transforms as transforms 8 | import os 9 | import numpy as np 10 | from PIL import Image 11 | 12 | class FaceShiftNetModel(BaseModel): 13 | def name(self): 14 | return 'FaceShiftNetModel' 15 | 16 | 17 | def create_random_mask(self): 18 | if self.opt.mask_type == 'random': 19 | if self.opt.mask_sub_type == 'fractal': 20 | assert 1==2, "It is broken somehow, use another mask_sub_type please" 21 | mask = util.create_walking_mask() # create an initial random mask.
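# The 'random' mask sub-types handled by this method:
#   'fractal' : a "walking" (random-walk style) mask from util.create_walking_mask (currently
#               broken, hence the assert above),
#   'rect'    : a rectangular hole at a random position; rand_t / rand_l record its top-left corner,
#   'island'  : free-form blobs produced by util.wrapper_gmask.
# Only the 'rect' branch stores rand_t / rand_l, because the cropped-patch D and L1 losses below
# need to know where the hole is.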
22 | 23 | elif self.opt.mask_sub_type == 'rect': 24 | mask, rand_t, rand_l = util.create_rand_mask(self.opt) 25 | self.rand_t = rand_t 26 | self.rand_l = rand_l 27 | return mask 28 | 29 | elif self.opt.mask_sub_type == 'island': 30 | mask = util.wrapper_gmask(self.opt) 31 | return mask 32 | 33 | def initialize(self, opt): 34 | BaseModel.initialize(self, opt) 35 | self.opt = opt 36 | self.isTrain = opt.isTrain 37 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 38 | self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv'] 39 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 40 | if self.opt.show_flow: 41 | self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs'] 42 | else: 43 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 44 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 45 | if self.isTrain: 46 | self.model_names = ['G', 'D'] 47 | else: # during test time, only load Gs 48 | self.model_names = ['G'] 49 | 50 | 51 | # batchsize should be 1 for mask_global 52 | self.mask_global = torch.zeros(self.opt.batchSize, 1, \ 53 | opt.fineSize, opt.fineSize, dtype=torch.bool) 54 | 55 | # Here we need to set an artificial mask_global(center hole is ok.) 56 | self.mask_global.zero_() 57 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\ 58 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1 59 | 60 | if len(opt.gpu_ids) > 0: 61 | self.use_gpu = True 62 | self.mask_global = self.mask_global.to(self.device) 63 | 64 | # load/define networks 65 | # self.ng_innerCos_list is the guidance loss list in netG inner layers. 66 | # self.ng_shift_list is the mask list constructing shift operation. 
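# When opt.add_mask2input is on, set_input() concatenates the inverted mask to the image as an
# extra channel, so the generator must be built with one more input channel. For example
# (assuming the default input_nc = 3), netG then receives 4-channel tensors of shape
# (batchSize, 4, fineSize, fineSize).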
67 | if opt.add_mask2input: 68 | input_nc = opt.input_nc + 1 69 | else: 70 | input_nc = opt.input_nc 71 | 72 | self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(input_nc, opt.output_nc, opt.ngf, 73 | opt.which_model_netG, opt, self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type, self.gpu_ids, opt.init_gain) 74 | 75 | if self.isTrain: 76 | use_sigmoid = False 77 | if opt.gan_type == 'vanilla': 78 | use_sigmoid = True # only vanilla GAN using BCECriterion 79 | # don't use cGAN 80 | self.netD = networks.define_D(opt.input_nc, opt.ndf, 81 | opt.which_model_netD, 82 | opt.n_layers_D, opt.norm, use_sigmoid, opt.use_spectral_norm_D, opt.init_type, self.gpu_ids, opt.init_gain) 83 | 84 | # add style extractor 85 | self.vgg16_extractor = util.VGG16FeatureExtractor() 86 | if len(opt.gpu_ids) > 0: 87 | self.vgg16_extractor = self.vgg16_extractor.to(self.gpu_ids[0]) 88 | self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor, self.gpu_ids) 89 | 90 | 91 | if self.isTrain: 92 | self.old_lr = opt.lr 93 | # define loss functions 94 | self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device) 95 | self.criterionL1 = torch.nn.L1Loss() 96 | self.criterionL1_mask = networks.Discounted_L1(opt).to(self.device) # make weights/buffers transfer to the correct device 97 | # VGG loss 98 | self.criterionL2_style_loss = torch.nn.MSELoss() 99 | self.criterionL2_content_loss = torch.nn.MSELoss() 100 | # TV loss 101 | self.tv_criterion = networks.TVLoss(self.opt.tv_weight) 102 | 103 | # initialize optimizers 104 | self.schedulers = [] 105 | self.optimizers = [] 106 | if self.opt.gan_type == 'wgan_gp': 107 | opt.beta1 = 0 108 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 109 | lr=opt.lr, betas=(opt.beta1, 0.9)) 110 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 111 | lr=opt.lr, betas=(opt.beta1, 0.9)) 112 | else: 113 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 114 | lr=opt.lr, betas=(opt.beta1, 0.999)) 115 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 116 | lr=opt.lr, betas=(opt.beta1, 0.999)) 117 | self.optimizers.append(self.optimizer_G) 118 | self.optimizers.append(self.optimizer_D) 119 | for optimizer in self.optimizers: 120 | self.schedulers.append(networks.get_scheduler(optimizer, opt)) 121 | 122 | if not self.isTrain or opt.continue_train: 123 | self.load_networks(opt.which_epoch) 124 | 125 | self.print_networks(opt.verbose) 126 | 127 | def set_input(self, input): 128 | self.image_paths = input['A_paths'] 129 | real_A = input['A'].to(self.device) 130 | real_B = input['B'].to(self.device) 131 | real_A_flip = input['A_F'].to(self.device) 132 | # directly load mask offline 133 | self.mask_global = input['M'].to(self.device).byte() 134 | self.mask_global = self.mask_global.narrow(1,0,1).bool() 135 | 136 | # create mask online 137 | if not self.opt.offline_loading_mask: 138 | if self.opt.mask_type == 'center': 139 | self.mask_global.zero_() 140 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\ 141 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1 142 | self.rand_t, self.rand_l = int(self.opt.fineSize/4) + self.opt.overlap, int(self.opt.fineSize/4) + self.opt.overlap 143 | elif self.opt.mask_type == 'random': 144 | self.mask_global = self.create_random_mask().type_as(self.mask_global).view(1, *self.mask_global.size()[-3:]) 145 | 
# As generating random masks online are computation-heavy 146 | # So just generate one ranodm mask for a batch images. 147 | self.mask_global = self.mask_global.expand(self.opt.batchSize, *self.mask_global.size()[-3:]) 148 | else: 149 | raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type) 150 | # For loading mask offline, we also need to change 'opt.mask_type' and 'opt.mask_sub_type' 151 | # to avoid forgetting such settings. 152 | else: 153 | self.opt.mask_type = 'random' 154 | self.opt.mask_sub_type = 'island' 155 | 156 | self.set_latent_mask(self.mask_global) 157 | 158 | real_A.narrow(1,0,1).masked_fill_(self.mask_global, 0.)#2*123.0/255.0 - 1.0 159 | real_A.narrow(1,1,1).masked_fill_(self.mask_global, 0.)#2*104.0/255.0 - 1.0 160 | real_A.narrow(1,2,1).masked_fill_(self.mask_global, 0.)#2*117.0/255.0 - 1.0 161 | 162 | self.mask_global_flip = torch.flip(self.mask_global.float(), [3]).bool() 163 | real_A_flip.narrow(1,0,1).masked_fill_(self.mask_global_flip, 0.)#2*123.0/255.0 - 1.0 164 | real_A_flip.narrow(1,1,1).masked_fill_(self.mask_global_flip, 0.)#2*104.0/255.0 - 1.0 165 | real_A_flip.narrow(1,2,1).masked_fill_(self.mask_global_flip, 0.)#2*117.0/255.0 - 1.0 166 | 167 | 168 | if self.opt.add_mask2input: 169 | # make it 4 dimensions. 170 | # Mention: the extra dim, the masked part is filled with 0, non-mask part is filled with 1. 171 | real_A = torch.cat((real_A, (~self.mask_global).expand(real_A.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1) 172 | real_A_flip = torch.cat((real_A_flip, (~self.mask_global_flip).expand(real_A_flip.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1) 173 | 174 | self.real_A = real_A 175 | self.real_B = real_B 176 | self.real_A_flip = real_A_flip 177 | 178 | 179 | def set_latent_mask(self, mask_global): 180 | for ng_shift in self.ng_shift_list: # ITERATE OVER THE LIST OF ng_shift_list 181 | ng_shift.set_mask(mask_global) 182 | for ng_innerCos in self.ng_innerCos_list: # ITERATE OVER THE LIST OF ng_innerCos_list: 183 | ng_innerCos.set_mask(mask_global) 184 | 185 | def set_gt_latent(self): 186 | if not self.opt.skip: 187 | if self.opt.add_mask2input: 188 | # make it 4 dimensions. 189 | # Mention: the extra dim, the masked part is filled with 0, non-mask part is filled with 1. 190 | real_B = torch.cat([self.real_B, (~self.mask_global).expand(self.real_B.size(0), 1, self.real_B.size(2), self.real_B.size(3)).type_as(self.real_B)], dim=1) 191 | else: 192 | real_B = self.real_B 193 | self.netG(real_B) # input ground truth 194 | 195 | 196 | def forward(self): 197 | _, flip_feat = self.netG(self.real_A_flip) 198 | # set guidance here 199 | self.set_gt_latent() 200 | self.fake_B, _ = self.netG(self.real_A, flip_feat) 201 | 202 | # Just assume one shift layer. 203 | def set_flow_src(self): 204 | self.flow_srcs = self.ng_shift_list[0].get_flow() 205 | self.flow_srcs = F.interpolate(self.flow_srcs, scale_factor=8, mode='nearest') 206 | # Just to avoid forgetting setting show_map_false 207 | self.set_show_map_false() 208 | 209 | # Just assume one shift layer. 210 | def set_show_map_true(self): 211 | self.ng_shift_list[0].set_flow_true() 212 | 213 | def set_show_map_false(self): 214 | self.ng_shift_list[0].set_flow_false() 215 | 216 | def get_image_paths(self): 217 | return self.image_paths 218 | 219 | def backward_D(self): 220 | fake_B = self.fake_B 221 | # Real 222 | real_B = self.real_B # GroundTruth 223 | 224 | # Has been verfied, for square mask, let D discrinate masked patch, improves the results. 
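# A small worked example of the crop below (illustrative, assuming the common settings
# fineSize=256 and overlap=4): the patch is fineSize//2 - 2*overlap = 120 pixels per side and
# starts at rand_t = rand_l = fineSize//4 + overlap = 68, i.e. fake_B[:, :, 68:188, 68:188],
# which is exactly the masked hole of the center mask. Feeding only this region to D keeps the
# discriminator focused on the inpainted content instead of the untouched surroundings.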
225 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect': 226 | # Using the cropped fake_B as the input of D. 227 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 228 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 229 | 230 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 231 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 232 | 233 | self.pred_fake = self.netD(fake_B.detach()) 234 | self.pred_real = self.netD(real_B) 235 | 236 | if self.opt.gan_type == 'wgan_gp': 237 | gradient_penalty, _ = util.cal_gradient_penalty(self.netD, real_B, fake_B.detach(), self.device, constant=1, lambda_gp=self.opt.gp_lambda) 238 | self.loss_D_fake = torch.mean(self.pred_fake) 239 | self.loss_D_real = -torch.mean(self.pred_real) 240 | 241 | self.loss_D = self.loss_D_fake + self.loss_D_real + gradient_penalty 242 | else: 243 | if self.opt.gan_type in ['vanilla', 'lsgan']: 244 | self.loss_D_fake = self.criterionGAN(self.pred_fake, False) 245 | self.loss_D_real = self.criterionGAN (self.pred_real, True) 246 | 247 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 248 | 249 | elif self.opt.gan_type == 're_s_gan': 250 | self.loss_D = self.criterionGAN(self.pred_real - self.pred_fake, True) 251 | 252 | elif self.opt.gan_type == 're_avg_gan': 253 | self.loss_D = (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), True) \ 254 | + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), False)) / 2. 255 | # for `re_avg_gan`, need to retain graph of D. 256 | if self.opt.gan_type == 're_avg_gan': 257 | self.loss_D.backward(retain_graph=True) 258 | else: 259 | self.loss_D.backward() 260 | 261 | 262 | def backward_G(self): 263 | # First, G(A) should fake the discriminator 264 | fake_B = self.fake_B 265 | # Has been verfied, for square mask, let D discrinate masked patch, improves the results. 266 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect': 267 | # Using the cropped fake_B as the input of D. 268 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 269 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 270 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 271 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 272 | else: 273 | real_B = self.real_B 274 | 275 | pred_fake = self.netD(fake_B) 276 | 277 | 278 | if self.opt.gan_type == 'wgan_gp': 279 | self.loss_G_GAN = -torch.mean(pred_fake) 280 | else: 281 | if self.opt.gan_type in ['vanilla', 'lsgan']: 282 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.gan_weight 283 | 284 | elif self.opt.gan_type == 're_s_gan': 285 | pred_real = self.netD (real_B) 286 | self.loss_G_GAN = self.criterionGAN (pred_fake - pred_real, True) * self.opt.gan_weight 287 | 288 | elif self.opt.gan_type == 're_avg_gan': 289 | self.pred_real = self.netD(real_B) 290 | self.loss_G_GAN = (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), False) \ 291 | + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), True)) / 2. 292 | self.loss_G_GAN *= self.opt.gan_weight 293 | 294 | 295 | # If we change the mask as 'center with random position', then we can replacing loss_G_L1_m with 'Discounted L1'. 
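# Overall generator objective assembled below (weights are the options used in this file):
#   loss_G = lambda_A * L1(fake_B, real_B)
#          + mask_weight_G * (discounted) L1 on the hole patch   (only for 'center'/'rect' masks)
#          + adversarial loss on the (cropped) fake_B, scaled by gan_weight for non-WGAN losses
#          + style_weight * Gram-matrix VGG style loss + content_weight * VGG content loss
#          + tv_weight * total-variation loss on the masked area of fake_B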
296 | self.loss_G_L1, self.loss_G_L1_m = 0, 0 297 | self.loss_G_L1 += self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A 298 | # calcuate mask construction loss 299 | # When mask_type is 'center' or 'random_with_rect', we can add additonal mask region construction loss (traditional L1). 300 | # Only when 'discounting_loss' is 1, then the mask region construction loss changes to 'discounting L1' instead of normal L1. 301 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect': 302 | mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 303 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 304 | mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 305 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 306 | # Using Discounting L1 loss 307 | self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight_G 308 | 309 | self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN 310 | 311 | # Then, add TV loss 312 | self.loss_tv = self.tv_criterion(self.fake_B*self.mask_global.float()) 313 | 314 | # Finally, add style loss 315 | vgg_ft_fakeB = self.vgg16_extractor(fake_B) 316 | vgg_ft_realB = self.vgg16_extractor(real_B) 317 | self.loss_style = 0 318 | self.loss_content = 0 319 | 320 | for i in range(3): 321 | self.loss_style += self.criterionL2_style_loss(util.gram_matrix(vgg_ft_fakeB[i]), util.gram_matrix(vgg_ft_realB[i])) 322 | self.loss_content += self.criterionL2_content_loss(vgg_ft_fakeB[i], vgg_ft_realB[i]) 323 | 324 | self.loss_style *= self.opt.style_weight 325 | self.loss_content *= self.opt.content_weight 326 | 327 | self.loss_G += (self.loss_style + self.loss_content + self.loss_tv) 328 | 329 | self.loss_G.backward() 330 | 331 | def optimize_parameters(self): 332 | self.forward() 333 | # update D 334 | self.set_requires_grad(self.netD, True) 335 | self.optimizer_D.zero_grad() 336 | self.backward_D() 337 | self.optimizer_D.step() 338 | 339 | # update G 340 | self.set_requires_grad(self.netD, False) 341 | self.optimizer_G.zero_grad() 342 | self.backward_G() 343 | self.optimizer_G.step() 344 | 345 | 346 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .discrimators import * 2 | from .losses import * 3 | from .modules import * 4 | from .shift_unet import * 5 | from .unet import * -------------------------------------------------------------------------------- /models/modules/denset_net.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .modules import * 6 | import torch.utils.model_zoo as model_zoo 7 | from collections import OrderedDict 8 | 9 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 10 | 11 | 12 | model_urls = { 13 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 14 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 15 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 16 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 17 | } 18 | 19 | 20 | def densenet121(pretrained=False, use_spectral_norm=True, **kwargs): 21 | r"""Densenet-121 model from 
22 | `"Densely Connected Convolutional Networks" `_ 23 | 24 | Args: 25 | pretrained (bool): If True, returns a model pre-trained on ImageNet 26 | """ 27 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), use_spectral_norm=use_spectral_norm, 28 | **kwargs) 29 | if pretrained: 30 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 31 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 32 | # They are also in the checkpoints in model_urls. This pattern is used 33 | # to find such keys. 34 | pattern = re.compile( 35 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 36 | state_dict = model_zoo.load_url(model_urls['densenet121']) 37 | for key in list(state_dict.keys()): 38 | res = pattern.match(key) 39 | if res: 40 | new_key = res.group(1) + res.group(2) 41 | state_dict[new_key] = state_dict[key] 42 | del state_dict[key] 43 | model.load_state_dict(state_dict, strict=False) 44 | return model 45 | 46 | 47 | 48 | def densenet169(pretrained=False, **kwargs): 49 | r"""Densenet-169 model from 50 | `"Densely Connected Convolutional Networks" `_ 51 | 52 | Args: 53 | pretrained (bool): If True, returns a model pre-trained on ImageNet 54 | """ 55 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 56 | **kwargs) 57 | if pretrained: 58 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 59 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 60 | # They are also in the checkpoints in model_urls. This pattern is used 61 | # to find such keys. 62 | pattern = re.compile( 63 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 64 | state_dict = model_zoo.load_url(model_urls['densenet169']) 65 | for key in list(state_dict.keys()): 66 | res = pattern.match(key) 67 | if res: 68 | new_key = res.group(1) + res.group(2) 69 | state_dict[new_key] = state_dict[key] 70 | del state_dict[key] 71 | model.load_state_dict(state_dict) 72 | return model 73 | 74 | 75 | 76 | def densenet201(pretrained=False, **kwargs): 77 | r"""Densenet-201 model from 78 | `"Densely Connected Convolutional Networks" `_ 79 | 80 | Args: 81 | pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | """ 83 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 84 | **kwargs) 85 | if pretrained: 86 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 87 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 88 | # They are also in the checkpoints in model_urls. This pattern is used 89 | # to find such keys. 
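# Example of the renaming this pattern performs (illustrative):
#   'features.denseblock1.denselayer1.norm.1.weight' -> 'features.denseblock1.denselayer1.norm1.weight'
# i.e. the dot between the sub-module name and its index is dropped, so keys from old
# torchvision checkpoints match the current module names.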
90 | pattern = re.compile( 91 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 92 | state_dict = model_zoo.load_url(model_urls['densenet201']) 93 | for key in list(state_dict.keys()): 94 | res = pattern.match(key) 95 | if res: 96 | new_key = res.group(1) + res.group(2) 97 | state_dict[new_key] = state_dict[key] 98 | del state_dict[key] 99 | model.load_state_dict(state_dict) 100 | return model 101 | 102 | 103 | 104 | def densenet161(pretrained=False, **kwargs): 105 | r"""Densenet-161 model from 106 | `"Densely Connected Convolutional Networks" `_ 107 | 108 | Args: 109 | pretrained (bool): If True, returns a model pre-trained on ImageNet 110 | """ 111 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 112 | **kwargs) 113 | if pretrained: 114 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 115 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 116 | # They are also in the checkpoints in model_urls. This pattern is used 117 | # to find such keys. 118 | pattern = re.compile( 119 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 120 | state_dict = model_zoo.load_url(model_urls['densenet161']) 121 | for key in list(state_dict.keys()): 122 | res = pattern.match(key) 123 | if res: 124 | new_key = res.group(1) + res.group(2) 125 | state_dict[new_key] = state_dict[key] 126 | del state_dict[key] 127 | model.load_state_dict(state_dict) 128 | return model 129 | 130 | 131 | 132 | class _DenseLayer(nn.Sequential): 133 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, use_spectral_norm): 134 | super(_DenseLayer, self).__init__() 135 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 136 | self.add_module('relu1', nn.ReLU()), 137 | self.add_module('conv1', spectral_norm(nn.Conv2d(num_input_features, bn_size * 138 | growth_rate, kernel_size=1, stride=1, bias=False), use_spectral_norm)), 139 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 140 | self.add_module('relu2', nn.ReLU()), 141 | self.add_module('conv2', spectral_norm(nn.Conv2d(bn_size * growth_rate, growth_rate, 142 | kernel_size=3, stride=1, padding=1, bias=False), use_spectral_norm)), 143 | self.drop_rate = drop_rate 144 | 145 | def forward(self, x): 146 | new_features = super(_DenseLayer, self).forward(x) 147 | if self.drop_rate > 0: 148 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 149 | return torch.cat([x, new_features], 1) 150 | 151 | 152 | class _DenseBlock(nn.Sequential): 153 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, use_spectral_norm): 154 | super(_DenseBlock, self).__init__() 155 | for i in range(num_layers): 156 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, use_spectral_norm) 157 | self.add_module('denselayer%d' % (i + 1), layer) 158 | 159 | 160 | class _Transition(nn.Sequential): 161 | def __init__(self, num_input_features, num_output_features, use_spectral_norm): 162 | super(_Transition, self).__init__() 163 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 164 | self.add_module('relu', nn.ReLU()) 165 | self.add_module('conv', spectral_norm(nn.Conv2d(num_input_features, num_output_features, 166 | kernel_size=1, stride=1, bias=False), use_spectral_norm)) 167 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 168 | 169 | 170 | class DenseNet(nn.Module): 
171 | r"""Densenet-BC model class, based on 172 | `"Densely Connected Convolutional Networks" `_ 173 | 174 | Args: 175 | growth_rate (int) - how many filters to add each layer (`k` in paper) 176 | block_config (list of 4 ints) - how many layers in each pooling block 177 | num_init_features (int) - the number of filters to learn in the first convolution layer 178 | bn_size (int) - multiplicative factor for number of bottle neck layers 179 | (i.e. bn_size * k features in the bottleneck layer) 180 | drop_rate (float) - dropout rate after each dense layer 181 | num_classes (int) - number of classification classes 182 | """ 183 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), use_spectral_norm=True, 184 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 185 | 186 | super(DenseNet, self).__init__() 187 | 188 | # First convolution 189 | self.features = nn.Sequential(OrderedDict([ 190 | ('conv0', spectral_norm(nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False), use_spectral_norm)), 191 | ('norm0', nn.BatchNorm2d(num_init_features)), 192 | ('relu0', nn.ReLU()), 193 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 194 | ])) 195 | 196 | # Each denseblock 197 | num_features = num_init_features 198 | for i, num_layers in enumerate(block_config): 199 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 200 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, use_spectral_norm=use_spectral_norm) 201 | self.features.add_module('denseblock%d' % (i + 1), block) 202 | num_features = num_features + num_layers * growth_rate 203 | if i != len(block_config) - 1: 204 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, use_spectral_norm=use_spectral_norm) 205 | self.features.add_module('transition%d' % (i + 1), trans) 206 | num_features = num_features // 2 207 | 208 | # Final batch norm 209 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 210 | 211 | self.conv_last = spectral_norm(nn.Conv2d(num_features, 256, kernel_size=3), use_spectral_norm) 212 | 213 | # Linear layer 214 | # self.classifier = nn.Linear(num_features, num_classes) 215 | 216 | # Official init from torch repo. 217 | for m in self.modules(): 218 | if isinstance(m, nn.Conv2d): 219 | nn.init.kaiming_normal_(m.weight.data) 220 | elif isinstance(m, nn.BatchNorm2d): 221 | m.weight.data.fill_(1) 222 | m.bias.data.zero_() 223 | elif isinstance(m, nn.Linear): 224 | m.bias.data.zero_() 225 | 226 | def forward(self, x): 227 | features = self.features(x) 228 | features = self.conv_last(features) 229 | return features 230 | -------------------------------------------------------------------------------- /models/modules/discrimators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | from .denset_net import * 4 | 5 | from .modules import * 6 | ################################### This is for D ################################### 7 | # Defines the PatchGAN discriminator with the specified arguments. 
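# With the default n_layers=3 and 4x4 kernels this is the usual 70x70 PatchGAN: three stride-2
# convolutions followed by two stride-1 ones, so every value of the final 1-channel map scores a
# 70x70 patch of the input rather than the whole image. Rough spatial sizes for a 256x256 input
# (illustrative): 256 -> 128 -> 64 -> 32 -> 31 -> 30.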
8 | class NLayerDiscriminator(nn.Module): 9 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True): 10 | super(NLayerDiscriminator, self).__init__() 11 | if type(norm_layer) == functools.partial: 12 | use_bias = norm_layer.func == nn.InstanceNorm2d 13 | else: 14 | use_bias = norm_layer == nn.InstanceNorm2d 15 | 16 | kw = 4 17 | padw = 1 18 | sequence = [ 19 | spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), use_spectral_norm), 20 | nn.LeakyReLU(0.2, True) 21 | ] 22 | 23 | nf_mult = 1 24 | nf_mult_prev = 1 25 | for n in range(1, n_layers): 26 | nf_mult_prev = nf_mult 27 | nf_mult = min(2**n, 8) 28 | sequence += [ 29 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 30 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), use_spectral_norm), 31 | norm_layer(ndf * nf_mult), 32 | nn.LeakyReLU(0.2, True) 33 | ] 34 | 35 | nf_mult_prev = nf_mult 36 | nf_mult = min(2**n_layers, 8) 37 | sequence += [ 38 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 39 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), use_spectral_norm), 40 | norm_layer(ndf * nf_mult), 41 | nn.LeakyReLU(0.2, True) 42 | ] 43 | sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), use_spectral_norm)] 44 | 45 | if use_sigmoid: 46 | sequence += [nn.Sigmoid()] 47 | 48 | self.model = nn.Sequential(*sequence) 49 | 50 | def forward(self, input): 51 | return self.model(input) 52 | 53 | 54 | # Defines a densetnet inspired discriminator (Should improve its ability to create stronger representation) 55 | class DenseNetDiscrimator(nn.Module): 56 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True): 57 | super(DenseNetDiscrimator, self).__init__() 58 | self.model = densenet121(pretrained=True, use_spectral_norm=use_spectral_norm) 59 | self.use_sigmoid = use_sigmoid 60 | if self.use_sigmoid: 61 | self.sigmoid = nn.Sigmoid() 62 | 63 | def forward(self, input): 64 | if self.use_sigmoid: 65 | return self.sigmoid(self.model(input)) 66 | else: 67 | return self.model(input) 68 | -------------------------------------------------------------------------------- /models/modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 6 | # When LSGAN is used, it is basically same as MSELoss, 7 | # but it abstracts away the need to create the target label tensor 8 | # that has the same size as the input 9 | class GANLoss(nn.Module): 10 | def __init__(self, gan_type='wgan_gp', target_real_label=1.0, target_fake_label=0.0): 11 | super(GANLoss, self).__init__() 12 | self.register_buffer('real_label', torch.tensor(target_real_label)) 13 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 14 | self.gan_type = gan_type 15 | if gan_type == 'wgan_gp': 16 | self.loss = nn.MSELoss() 17 | elif gan_type == 'lsgan': 18 | self.loss = nn.MSELoss() 19 | elif gan_type == 'vanilla': 20 | self.loss = nn.BCELoss() 21 | ####################################################################### 22 | ### Relativistic GAN - https://github.com/AlexiaJM/RelativisticGAN ### 23 | ####################################################################### 24 | # When Using `BCEWithLogitsLoss()`, remove the sigmoid layer in D. 
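# Brief sketch of the two relativistic variants below (see backward_D / backward_G in the model
# files for how the predictions are combined):
#   're_s_gan'   (relativistic standard GAN): BCEWithLogits on D(real) - D(fake),
#                i.e. "is the real sample more realistic than the fake one?"
#   're_avg_gan' (relativistic average GAN):  BCEWithLogits on D(real) - mean(D(fake)) and on
#                D(fake) - mean(D(real)), averaged over the two terms.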
25 | elif gan_type == 're_s_gan': 26 | self.loss = nn.BCEWithLogitsLoss() 27 | elif gan_type == 're_avg_gan': 28 | self.loss = nn.BCEWithLogitsLoss() 29 | else: 30 | raise ValueError("GAN type [%s] not recognized." % gan_type) 31 | 32 | def get_target_tensor(self, prediction, target_is_real): 33 | if target_is_real: 34 | target_tensor = self.real_label 35 | else: 36 | target_tensor = self.fake_label 37 | return target_tensor.expand_as(prediction) 38 | 39 | def __call__(self, prediction, target_is_real): 40 | if self.gan_type == 'wgan_gp': 41 | if target_is_real: 42 | loss = -prediction.mean() 43 | else: 44 | loss = prediction.mean() 45 | else: 46 | target_tensor = self.get_target_tensor(prediction, target_is_real) 47 | loss = self.loss(prediction, target_tensor) 48 | return loss 49 | 50 | ################# Discounting loss ######################### 51 | ###################################################### 52 | class Discounted_L1(nn.Module): 53 | def __init__(self, opt): 54 | super(Discounted_L1, self).__init__() 55 | # Register discounting template as a buffer 56 | self.register_buffer('discounting_mask', torch.tensor(spatial_discounting_mask(opt.fineSize//2 - opt.overlap * 2, opt.fineSize//2 - opt.overlap * 2, 0.9, opt.discounting))) 57 | self.L1 = nn.L1Loss() 58 | 59 | def forward(self, input, target): 60 | self._assert_no_grad(target) 61 | input_tmp = input * self.discounting_mask 62 | target_tmp = target * self.discounting_mask 63 | return self.L1(input_tmp, target_tmp) 64 | 65 | 66 | def _assert_no_grad(self, variable): 67 | assert not variable.requires_grad, \ 68 | "nn criterions don't compute the gradient w.r.t. targets - please " \ 69 | "mark these variables as volatile or not requiring gradients" 70 | 71 | 72 | def spatial_discounting_mask(mask_width, mask_height, discounting_gamma, discounting=1): 73 | """Generate spatial discounting mask constant. 74 | Spatial discounting mask is first introduced in publication: 75 | Generative Image Inpainting with Contextual Attention, Yu et al. 
76 | Returns: 77 | tf.Tensor: spatial discounting mask 78 | """ 79 | gamma = discounting_gamma 80 | shape = [1, 1, mask_width, mask_height] 81 | if discounting: 82 | print('Use spatial discounting l1 loss.') 83 | mask_values = np.ones((mask_width, mask_height), dtype='float32') 84 | for i in range(mask_width): 85 | for j in range(mask_height): 86 | mask_values[i, j] = max( 87 | gamma**min(i, mask_width-i), 88 | gamma**min(j, mask_height-j)) 89 | mask_values = np.expand_dims(mask_values, 0) 90 | mask_values = np.expand_dims(mask_values, 1) 91 | mask_values = mask_values 92 | else: 93 | mask_values = np.ones(shape, dtype='float32') 94 | 95 | return mask_values 96 | 97 | class TVLoss(nn.Module): 98 | def __init__(self, tv_loss_weight=1): 99 | super(TVLoss, self).__init__() 100 | self.tv_loss_weight = tv_loss_weight 101 | 102 | def forward(self, x): 103 | bz, _, h, w = x.size() 104 | count_h = self._tensor_size(x[:, :, 1:, :]) 105 | count_w = self._tensor_size(x[:, :, :, 1:]) 106 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h - 1, :]), 2).sum() 107 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w - 1]), 2).sum() 108 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / bz 109 | 110 | @staticmethod 111 | def _tensor_size(t): 112 | return t.size(1) * t.size(2) * t.size(3) -------------------------------------------------------------------------------- /models/modules/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import Parameter 6 | 7 | 8 | class Self_Attn (nn.Module): 9 | """ Self attention Layer""" 10 | 11 | ''' 12 | https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py 13 | ''' 14 | 15 | def __init__(self, in_dim, activation, with_attention=False): 16 | super (Self_Attn, self).__init__ () 17 | self.chanel_in = in_dim 18 | self.activation = activation 19 | self.with_attention = with_attention 20 | 21 | self.query_conv = nn.Conv2d (in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 22 | self.key_conv = nn.Conv2d (in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 23 | self.value_conv = nn.Conv2d (in_channels=in_dim, out_channels=in_dim, kernel_size=1) 24 | self.gamma = nn.Parameter (torch.zeros (1)) 25 | 26 | self.softmax = nn.Softmax (dim=-1) # 27 | 28 | def forward(self, x): 29 | """ 30 | inputs : 31 | x : input feature maps( B X C X W X H) 32 | returns : 33 | out : self attention value + input feature 34 | attention: B X N X N (N is Width*Height) 35 | """ 36 | m_batchsize, C, width, height = x.size () 37 | proj_query = self.query_conv (x).view (m_batchsize, -1, width * height).permute (0, 2, 1) # B X CX(N) 38 | proj_key = self.key_conv (x).view (m_batchsize, -1, width * height) # B X C x (*W*H) 39 | energy = torch.bmm (proj_query, proj_key) # transpose check 40 | attention = self.softmax (energy) # BX (N) X (N) 41 | proj_value = self.value_conv (x).view (m_batchsize, -1, width * height) # B X C X N 42 | 43 | out = torch.bmm (proj_value, attention.permute (0, 2, 1)) 44 | out = out.view (m_batchsize, C, width, height) 45 | 46 | out = self.gamma * out + x 47 | 48 | if self.with_attention: 49 | return out, attention 50 | else: 51 | return out 52 | 53 | def l2normalize(v, eps=1e-12): 54 | return v / (v.norm() + eps) 55 | 56 | def spectral_norm(module, mode=True): 57 | if mode: 58 | return nn.utils.spectral_norm(module) 59 | 60 | return module 61 | 62 | 63 | class SwitchNorm2d(nn.Module): 64 | def 
__init__(self, num_features, eps=1e-5, momentum=0.9, using_moving_average=True, using_bn=True, 65 | last_gamma=False): 66 | super(SwitchNorm2d, self).__init__() 67 | self.eps = eps 68 | self.momentum = momentum 69 | self.using_moving_average = using_moving_average 70 | self.using_bn = using_bn 71 | self.last_gamma = last_gamma 72 | self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1)) 73 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 74 | if self.using_bn: 75 | self.mean_weight = nn.Parameter(torch.ones(3)) 76 | self.var_weight = nn.Parameter(torch.ones(3)) 77 | else: 78 | self.mean_weight = nn.Parameter(torch.ones(2)) 79 | self.var_weight = nn.Parameter(torch.ones(2)) 80 | if self.using_bn: 81 | self.register_buffer('running_mean', torch.zeros(1, num_features, 1)) 82 | self.register_buffer('running_var', torch.zeros(1, num_features, 1)) 83 | 84 | self.reset_parameters() 85 | 86 | def reset_parameters(self): 87 | if self.using_bn: 88 | self.running_mean.zero_() 89 | self.running_var.zero_() 90 | if self.last_gamma: 91 | self.weight.data.fill_(0) 92 | else: 93 | self.weight.data.fill_(1) 94 | self.bias.data.zero_() 95 | 96 | def _check_input_dim(self, input): 97 | if input.dim() != 4: 98 | raise ValueError('expected 4D input (got {}D input)' 99 | .format(input.dim())) 100 | 101 | def forward(self, x): 102 | self._check_input_dim(x) 103 | N, C, H, W = x.size() 104 | x = x.view(N, C, -1) 105 | mean_in = x.mean(-1, keepdim=True) 106 | var_in = x.var(-1, keepdim=True) 107 | 108 | mean_ln = mean_in.mean(1, keepdim=True) 109 | temp = var_in + mean_in ** 2 110 | var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2 111 | 112 | if self.using_bn: 113 | if self.training: 114 | mean_bn = mean_in.mean(0, keepdim=True) 115 | var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2 116 | if self.using_moving_average: 117 | self.running_mean.mul_(self.momentum) 118 | self.running_mean.add_((1 - self.momentum) * mean_bn.data) 119 | self.running_var.mul_(self.momentum) 120 | self.running_var.add_((1 - self.momentum) * var_bn.data) 121 | else: 122 | self.running_mean.add_(mean_bn.data) 123 | self.running_var.add_(mean_bn.data ** 2 + var_bn.data) 124 | else: 125 | mean_bn = torch.autograd.Variable(self.running_mean) 126 | var_bn = torch.autograd.Variable(self.running_var) 127 | 128 | softmax = nn.Softmax(0) 129 | mean_weight = softmax(self.mean_weight) 130 | var_weight = softmax(self.var_weight) 131 | 132 | if self.using_bn: 133 | mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn 134 | var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn 135 | else: 136 | mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln 137 | var = var_weight[0] * var_in + var_weight[1] * var_ln 138 | 139 | x = (x-mean) / (var+self.eps).sqrt() 140 | x = x.view(N, C, H, W) 141 | return x * self.weight + self.bias 142 | 143 | 144 | class PartialConv(nn.Module): 145 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 146 | padding=0, dilation=1, groups=1, bias=True): 147 | super(PartialConv).__init__() 148 | self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size, 149 | stride, padding, dilation, groups, bias) 150 | self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, 151 | stride, padding, dilation, groups, False) 152 | 153 | #self.input_conv.apply(weights_init('kaiming')) 154 | 155 | torch.nn.init.constant_(self.mask_conv.weight, 1.0) 156 | 157 | # mask is not updated 158 | for param in self.mask_conv.parameters(): 
159 | param.requires_grad = False 160 | 161 | def forward(self, input, mask): 162 | 163 | output = self.input_conv(input * mask) 164 | if self.input_conv.bias is not None: 165 | output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as( 166 | output) 167 | else: 168 | output_bias = torch.zeros_like(output) 169 | 170 | with torch.no_grad(): 171 | output_mask = self.mask_conv(mask) 172 | 173 | no_update_holes = output_mask == 0 174 | mask_sum = output_mask.masked_fill_(no_update_holes, 1.0) 175 | 176 | output_pre = (output - output_bias) / mask_sum + output_bias 177 | output = output_pre.masked_fill_(no_update_holes, 0.0) 178 | 179 | new_mask = torch.ones_like(output) 180 | new_mask = new_mask.masked_fill_(no_update_holes, 0.0) 181 | 182 | return output, new_mask 183 | 184 | 185 | class ResnetBlock(nn.Module): 186 | def __init__(self, dim, padding_type, norm_layer, use_bias): 187 | super(ResnetBlock, self).__init__() 188 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_bias) 189 | 190 | def build_conv_block(self, dim, padding_type, norm_layer, use_bias): 191 | conv_block = [] 192 | p = 0 193 | if padding_type == 'reflect': 194 | conv_block += [nn.ReflectionPad2d(1)] 195 | elif padding_type == 'replicate': 196 | conv_block += [nn.ReplicationPad2d(1)] 197 | elif padding_type == 'zero': 198 | p = 1 199 | else: 200 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 201 | 202 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 203 | norm_layer(dim), 204 | nn.ReLU(True)] 205 | 206 | p = 0 207 | if padding_type == 'reflect': 208 | conv_block += [nn.ReflectionPad2d(1)] 209 | elif padding_type == 'replicate': 210 | conv_block += [nn.ReplicationPad2d(1)] 211 | elif padding_type == 'zero': 212 | p = 1 213 | else: 214 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 215 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 216 | norm_layer(dim)] 217 | 218 | return nn.Sequential(*conv_block) 219 | 220 | def forward(self, x): 221 | out = x + self.conv_block(x) 222 | return out 223 | -------------------------------------------------------------------------------- /models/modules/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .modules import spectral_norm 5 | 6 | # Defines the Unet generator. 7 | # |num_downs|: number of downsamplings in UNet. 
For example, 8 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 9 | # at the bottleneck 10 | class UnetGenerator(nn.Module): 11 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 12 | norm_layer=nn.BatchNorm2d, use_spectral_norm=False): 13 | super(UnetGenerator, self).__init__() 14 | 15 | # construct unet structure 16 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_spectral_norm=use_spectral_norm) 17 | for i in range(num_downs - 5): 18 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 19 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 20 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 21 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 22 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 23 | 24 | self.model = unet_block 25 | 26 | def forward(self, input): 27 | return self.model(input) 28 | 29 | # construct network from the inside to the outside. 30 | # Defines the submodule with skip connection. 31 | # X -------------------identity---------------------- X 32 | # |-- downsampling -- |submodule| -- upsampling --| 33 | class UnetSkipConnectionBlock(nn.Module): 34 | def __init__(self, outer_nc, inner_nc, input_nc, 35 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_spectral_norm=False): 36 | super(UnetSkipConnectionBlock, self).__init__() 37 | self.outermost = outermost 38 | 39 | if input_nc is None: 40 | input_nc = outer_nc 41 | 42 | downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4, 43 | stride=2, padding=1), use_spectral_norm) 44 | downrelu = nn.LeakyReLU(0.2, True) 45 | downnorm = norm_layer(inner_nc) 46 | uprelu = nn.ReLU(True) 47 | upnorm = norm_layer(outer_nc) 48 | 49 | # Different position only has differences in `upconv` 50 | # for the outermost, the special is `tanh` 51 | if outermost: 52 | upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc, 53 | kernel_size=4, stride=2, 54 | padding=1), use_spectral_norm) 55 | down = [downconv] 56 | up = [uprelu, upconv, nn.Tanh()] 57 | model = down + [submodule] + up 58 | # for the innermost, the special is `inner_nc` instead of `inner_nc*2` 59 | elif innermost: 60 | upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc, 61 | kernel_size=4, stride=2, 62 | padding=1), use_spectral_norm) 63 | down = [downrelu, downconv] # for the innermost, no submodule, and delete the bn 64 | up = [uprelu, upconv, upnorm] 65 | model = down + up 66 | # else, the normal 67 | else: 68 | upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc, 69 | kernel_size=4, stride=2, 70 | padding=1), use_spectral_norm) 71 | down = [downrelu, downconv, downnorm] 72 | up = [uprelu, upconv, upnorm] 73 | 74 | model = down + [submodule] + up 75 | 76 | self.model = nn.Sequential(*model) 77 | 78 | def forward(self, x): 79 | if self.outermost: # if it is the outermost, directly pass the input in. 
80 | return self.model(x) 81 | else: 82 | x_latter = self.model(x) 83 | _, _, h, w = x.size() 84 | if h != x_latter.size(2) or w != x_latter.size(3): 85 | x_latter = F.interpolate(x_latter, (h, w), mode='bilinear') 86 | return torch.cat([x_latter, x], 1) # cat in the C channel 87 | 88 | 89 | # It is an easy type of UNet, intead of constructing UNet with UnetSkipConnectionBlocks. 90 | # In this way, every thing is much clear and more flexible for extension. 91 | class EasyUnetGenerator(nn.Module): 92 | def __init__(self, input_nc, output_nc, ngf=64, 93 | norm_layer=nn.BatchNorm2d, use_spectral_norm=False): 94 | super(EasyUnetGenerator, self).__init__() 95 | 96 | # Encoder layers 97 | self.e1_c = spectral_norm(nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm) 98 | 99 | self.e2_c = spectral_norm(nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm) 100 | self.e2_norm = norm_layer(ngf*2) 101 | 102 | self.e3_c = spectral_norm(nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm) 103 | self.e3_norm = norm_layer(ngf*4) 104 | 105 | self.e4_c = spectral_norm(nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 106 | self.e4_norm = norm_layer(ngf*8) 107 | 108 | self.e5_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 109 | self.e5_norm = norm_layer(ngf*8) 110 | 111 | self.e6_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 112 | self.e6_norm = norm_layer(ngf*8) 113 | 114 | self.e7_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 115 | self.e7_norm = norm_layer(ngf*8) 116 | 117 | self.e8_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 118 | 119 | # Deocder layers 120 | self.d1_c = spectral_norm(nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 121 | self.d1_norm = norm_layer(ngf*8) 122 | 123 | self.d2_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2 , ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 124 | self.d2_norm = norm_layer(ngf*8) 125 | 126 | self.d3_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 127 | self.d3_norm = norm_layer(ngf*8) 128 | 129 | self.d4_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm) 130 | self.d4_norm = norm_layer(ngf*8) 131 | 132 | self.d5_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm) 133 | self.d5_norm = norm_layer(ngf*4) 134 | 135 | self.d6_c = spectral_norm(nn.ConvTranspose2d(ngf*4*2, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm) 136 | self.d6_norm = norm_layer(ngf*2) 137 | 138 | self.d7_c = spectral_norm(nn.ConvTranspose2d(ngf*2*2, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm) 139 | self.d7_norm = norm_layer(ngf) 140 | 141 | self.d8_c = spectral_norm(nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1), use_spectral_norm) 142 | 143 | 144 | # In this case, we have very flexible unet construction mode. 
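# Sketch of the data flow in forward() below (illustrative, for a 256x256 input): the encoder
# halves the resolution eight times (e1..e8, 256 -> 1), and each decoder stage is concatenated
# with the mirrored encoder feature before the next transposed convolution
# (d2 uses [d1, e7], d3 uses [d2, e6], ..., d8 uses [d7, e1]), ending in a tanh output.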
145 | def forward(self, input): 146 | # Encoder 147 | # No norm on the first layer 148 | e1 = self.e1_c(input) 149 | e2 = self.e2_norm(self.e2_c(F.leaky_relu_(e1, negative_slope=0.2))) 150 | e3 = self.e3_norm(self.e3_c(F.leaky_relu_(e2, negative_slope=0.2))) 151 | e4 = self.e4_norm(self.e4_c(F.leaky_relu_(e3, negative_slope=0.2))) 152 | e5 = self.e5_norm(self.e5_c(F.leaky_relu_(e4, negative_slope=0.2))) 153 | e6 = self.e6_norm(self.e6_c(F.leaky_relu_(e5, negative_slope=0.2))) 154 | e7 = self.e7_norm(self.e7_c(F.leaky_relu_(e6, negative_slope=0.2))) 155 | # No norm on the inner_most layer 156 | e8 = self.e8_c(F.leaky_relu_(e7, negative_slope=0.2)) 157 | 158 | # Decoder 159 | d1 = self.d1_norm(self.d1_c(F.relu_(e8))) 160 | d2 = self.d2_norm(self.d2_c(F.relu_(torch.cat([d1, e7], dim=1)))) 161 | d3 = self.d3_norm(self.d3_c(F.relu_(torch.cat([d2, e6], dim=1)))) 162 | d4 = self.d4_norm(self.d4_c(F.relu_(torch.cat([d3, e5], dim=1)))) 163 | d5 = self.d5_norm(self.d5_c(F.relu_(torch.cat([d4, e4], dim=1)))) 164 | d6 = self.d6_norm(self.d6_c(F.relu_(torch.cat([d5, e3], dim=1)))) 165 | d7 = self.d7_norm(self.d7_c(F.relu_(torch.cat([d6, e2], dim=1)))) 166 | # No norm on the last layer 167 | d8 = self.d8_c(F.relu_(torch.cat([d7, e1], 1))) 168 | 169 | d8 = torch.tanh(d8) 170 | 171 | return d8 172 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | from torch.nn import init 3 | from torch.optim import lr_scheduler 4 | from torchvision import models 5 | 6 | 7 | from .modules import * 8 | 9 | ############################################################################### 10 | # Functions 11 | ############################################################################### 12 | def get_norm_layer(norm_type='instance'): 13 | if norm_type == 'batch': 14 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 15 | elif norm_type == 'instance': 16 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=False) 17 | elif norm_type == 'switchable': 18 | norm_layer = functools.partial(SwitchNorm2d) 19 | elif norm_type == 'none': 20 | norm_layer = None 21 | else: 22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 23 | return norm_layer 24 | 25 | 26 | def get_scheduler(optimizer, opt): 27 | if opt.lr_policy == 'lambda': 28 | def lambda_rule(epoch): 29 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 30 | return lr_l 31 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 32 | elif opt.lr_policy == 'step': 33 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 34 | elif opt.lr_policy == 'plateau': 35 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 36 | elif opt.lr_policy == 'cosine': 37 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 38 | else: 39 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 40 | return scheduler 41 | 42 | 43 | def init_weights(net, init_type='normal', gain=0.02): 44 | def init_func(m): 45 | classname = m.__class__.__name__ 46 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 47 | if init_type == 'normal': 48 | init.normal_(m.weight.data, 0.0, gain) 49 | elif init_type == 
'xavier': 50 | init.xavier_normal_(m.weight.data, gain=gain) 51 | elif init_type == 'kaiming': 52 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 53 | elif init_type == 'orthogonal': 54 | init.orthogonal_(m.weight.data, gain=gain) 55 | else: 56 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 57 | if hasattr(m, 'bias') and m.bias is not None: 58 | init.constant_(m.bias.data, 0.0) 59 | elif classname.find('BatchNorm2d') != -1: 60 | init.normal_(m.weight.data, 1.0, gain) 61 | init.constant_(m.bias.data, 0.0) 62 | 63 | print('initialize network with %s' % init_type) 64 | net.apply(init_func) 65 | 66 | 67 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 68 | if len(gpu_ids) > 0: 69 | assert(torch.cuda.is_available()) 70 | net.to(gpu_ids[0]) 71 | net = torch.nn.DataParallel(net, gpu_ids) 72 | init_weights(net, init_type, gain=init_gain) 73 | return net 74 | 75 | 76 | # Note: Adding SN to G tends to give inferior results. Need more checking. 77 | def define_G(input_nc, output_nc, ngf, which_model_netG, opt, mask_global, norm='batch', use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02): 78 | netG = None 79 | norm_layer = get_norm_layer(norm_type=norm) 80 | 81 | innerCos_list = [] 82 | shift_list = [] 83 | 84 | print('input_nc {}'.format(input_nc)) 85 | print('output_nc {}'.format(output_nc)) 86 | print('which_model_netG {}'.format(which_model_netG)) 87 | 88 | # Here we need to initlize an artificial mask_global to construct the init model. 89 | # When training, we need to set mask for special layers(mostly for Shift layers) first. 90 | # If mask is fixed during training, we only need to set mask for these layers once, 91 | # else we need to set the masks each iteration, generating new random masks and mask the input 92 | # as well as setting masks for these special layers. 
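# The two lists returned alongside netG matter for training: the shift and guidance (innerCos)
# layers register themselves into shift_list / innerCos_list while the generator is built, and
# the model later calls set_mask() on every entry (see set_latent_mask in face_shiftnet_model.py),
# once for a fixed mask or every iteration when masks are generated randomly.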
93 | print('[CREATED] MODEL') 94 | if which_model_netG == 'unet_256': 95 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 96 | elif which_model_netG == 'easy_unet_256': 97 | netG = EasyUnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 98 | elif which_model_netG == 'face_unet_shift_triple': 99 | netG = FaceUnetGenerator(input_nc, output_nc, innerCos_list, shift_list, mask_global, opt, \ 100 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 101 | elif which_model_netG == 'unet_shift_triple': 102 | netG = UnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \ 103 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 104 | elif which_model_netG == 'res_unet_shift_triple': 105 | netG = ResUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \ 106 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 107 | elif which_model_netG == 'patch_soft_unet_shift_triple': 108 | netG = PatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \ 109 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 110 | elif which_model_netG == 'res_patch_soft_unet_shift_triple': 111 | netG = ResPatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \ 112 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm) 113 | else: 114 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) 115 | print('[CREATED] MODEL') 116 | print('Constraint in netG:') 117 | print(innerCos_list) 118 | 119 | print('Shift in netG:') 120 | print(shift_list) 121 | 122 | print('NetG:') 123 | print(netG) 124 | 125 | return init_net(netG, init_type, init_gain, gpu_ids), innerCos_list, shift_list 126 | 127 | 128 | def define_D(input_nc, ndf, which_model_netD, 129 | n_layers_D=3, norm='batch', use_sigmoid=False, use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02): 130 | netD = None 131 | norm_layer = get_norm_layer(norm_type=norm) 132 | 133 | if which_model_netD == 'basic': 134 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm) 135 | 136 | elif which_model_netD == 'n_layers': 137 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm) 138 | 139 | elif which_model_netD == 'densenet': 140 | netD = DenseNetDiscrimator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm) 141 | 142 | else: 143 | print('Discriminator model name [%s] is not recognized' % 144 | which_model_netD) 145 | 146 | print('NetD:') 147 | print(netD) 148 | return init_net(netD, init_type, init_gain, gpu_ids) 149 | 150 | -------------------------------------------------------------------------------- /models/patch_soft_shift/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/patch_soft_shift/__init__.py -------------------------------------------------------------------------------- /models/patch_soft_shift/innerPatchSoftShiftTriple.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import util.util as util 4 | from .innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule 5 | 6 | 7 | # TODO: Make it compatible for show_flow. 8 | # 9 | class InnerPatchSoftShiftTriple(nn.Module): 10 | def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3): 11 | super(InnerPatchSoftShiftTriple, self).__init__() 12 | 13 | self.shift_sz = shift_sz 14 | self.stride = stride 15 | self.mask_thred = mask_thred 16 | self.triple_weight = triple_weight 17 | self.show_flow = False # default false. Do not change it to be true, it is computation-heavy. 18 | self.flow_srcs = None # Indicating the flow src(pixles in non-masked region that will shift into the masked region) 19 | self.fuse = fuse 20 | self.layer_to_last = layer_to_last 21 | self.softShift = InnerPatchSoftShiftTripleModule() 22 | 23 | def set_mask(self, mask_global): 24 | mask = util.cal_feat_mask(mask_global, self.layer_to_last) 25 | self.mask = mask 26 | return self.mask 27 | 28 | # If mask changes, then need to set cal_fix_flag true each iteration. 29 | def forward(self, input): 30 | _, self.c, self.h, self.w = input.size() 31 | 32 | # Just pass self.mask in, instead of self.flag. 33 | final_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse) 34 | if self.show_flow: 35 | self.flow_srcs = self.softShift.get_flow_src() 36 | return final_out 37 | 38 | def get_flow(self): 39 | return self.flow_srcs 40 | 41 | def set_flow_true(self): 42 | self.show_flow = True 43 | 44 | def set_flow_false(self): 45 | self.show_flow = False 46 | 47 | def __repr__(self): 48 | return self.__class__.__name__+ '(' \ 49 | + ' ,triple_weight ' + str(self.triple_weight) + ')' 50 | -------------------------------------------------------------------------------- /models/patch_soft_shift/innerPatchSoftShiftTripleModule.py: -------------------------------------------------------------------------------- 1 | from util.NonparametricShift import Modified_NonparametricShift 2 | from torch.nn import functional as F 3 | import torch.nn as nn 4 | import torch 5 | import util.util as util 6 | 7 | 8 | class InnerPatchSoftShiftTripleModule(nn.Module): 9 | def forward(self, input, stride, triple_w, mask, mask_thred, shift_sz, show_flow, fuse=True): 10 | assert input.dim() == 4, "Input Dim has to be 4" 11 | assert mask.dim() == 4, "Mask Dim has to be 4" 12 | self.triple_w = triple_w 13 | self.mask = mask 14 | self.mask_thred = mask_thred 15 | self.show_flow = show_flow 16 | 17 | self.bz, self.c, self.h, self.w = input.size() 18 | 19 | self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available else torch.FloatTensor 20 | 21 | self.ind_lst = self.Tensor(self.bz, self.h * self.w, self.h * self.w).zero_() 22 | 23 | # former and latter are all tensors 24 | former_all = input.narrow(1, 0, self.c//2) ### decoder feature 25 | latter_all = input.narrow(1, self.c//2, self.c//2) ### encoder feature 26 | shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all) # addition feature 27 | 28 | self.mask = self.mask.to(input) 29 | 30 | # extract patches from latter. 
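        # Shape sketch for the unfolding below (values illustrative: a 32x32 feature map,
        # shift_sz=3, stride=1):
        #     latter_all_pad:      (bz, c/2, 34, 34)   after padding by shift_sz//2 on each side
        #     after two unfolds:   (bz, c/2, 32, 32, 3, 3)
        #     latter_all_windows:  (bz, 32*32, c/2, 3, 3), intended as one 3x3 patch per spatial position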
31 | latter_all_pad = F.pad(latter_all, [shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2], 'constant', 0) 32 | latter_all_windows = latter_all_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride) 33 | latter_all_windows = latter_all_windows.contiguous().view(self.bz, -1, self.c//2, shift_sz, shift_sz) 34 | 35 | # Extract patches from mask 36 | # Mention: mask here must be 1*1*H*W 37 | m_pad = F.pad(self.mask, (shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2), 'constant', 0) 38 | m = m_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride) 39 | m = m.contiguous().view(self.bz, 1, -1, shift_sz, shift_sz) 40 | 41 | # It implements the similar functionality as `cal_flag_given_mask_thred`. 42 | # However, it differs what `mm` means. 43 | # Here mm: the masked reigon is filled with 0, nonmasked region is filled with 1. 44 | # While mm in `cal_flag_given_mask_thred`, it is opposite. 45 | m = torch.mean(torch.mean(m, dim=3, keepdim=True), dim=4, keepdim=True) 46 | mm = m.le(self.mask_thred/(1.*shift_sz**2)).float() # bz*1*(32*32)*1*1 47 | 48 | fuse_weight = torch.eye(shift_sz).view(1, 1, shift_sz, shift_sz).type_as(input) 49 | 50 | self.shift_offsets = [] 51 | for idx in range(self.bz): 52 | mm_cur = mm[idx] 53 | # latter_win = latter_all_windows.narrow(0, idx, 1)[0] 54 | latter_win = latter_all_windows.narrow(0, idx, 1)[0] 55 | former = former_all.narrow(0, idx, 1) 56 | 57 | # normalize latter for each patch. 58 | latter_den = torch.sqrt(torch.einsum("bcij,bcij->b", [latter_win, latter_win])) 59 | latter_den = torch.max(latter_den, self.Tensor([1e-4])) 60 | 61 | latter_win_normed = latter_win/latter_den.view(-1, 1, 1, 1) 62 | y_i = F.conv2d(former, latter_win_normed, stride=1, padding=shift_sz//2) 63 | 64 | # conv implementation for fuse scores to encourage large patches 65 | if fuse: 66 | y_i = y_i.view(1, 1, self.h*self.w, self.h*self.w) # make all of depth of spatial resolution. 67 | y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1) 68 | 69 | y_i = y_i.contiguous().view(1, self.h, self.w, self.h, self.w) 70 | y_i = y_i.permute(0, 2, 1, 4, 3) 71 | y_i = y_i.contiguous().view(1, 1, self.h*self.w, self.h*self.w) 72 | 73 | y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1) 74 | y_i = y_i.contiguous().view(1, self.w, self.h, self.w, self.h) 75 | y_i = y_i.permute(0, 2, 1, 4, 3) 76 | 77 | y_i = y_i.contiguous().view(1, self.h*self.w, self.h, self.w) # 1*(32*32)*32*32 78 | 79 | # firstly, wash away the masked reigon. 80 | # multiply `mm` means (:, index_masked, :, :) will be 0. 81 | y_i = y_i * mm_cur 82 | 83 | # Then apply softmax to the nonmasked region. 84 | cosine = F.softmax(y_i*10, dim=1) 85 | 86 | # Finally, dummy parameters of masked reigon are filtered out. 87 | cosine = cosine * mm_cur 88 | 89 | # paste 90 | shift_i = F.conv_transpose2d(cosine, latter_win, stride=1, padding=shift_sz//2)/9. 91 | shift_masked_all[idx] = shift_i 92 | 93 | # Addition: show shift map 94 | # TODO: fix me. 95 | # cosine here is a full size of 32*32, not only the masked region in `shift_net`, 96 | # which results in non-direct reusing the code. 
97 | # torch.set_printoptions(threshold=2015) 98 | # if self.show_flow: 99 | # _, indexes = torch.max(cosine, dim=1) 100 | # # calculate self.flag from self.m 101 | # self.flag = (1 - mm).view(-1) 102 | # torch.set_printoptions(threshold=1025) 103 | # print(self.flag) 104 | # non_mask_indexes = (self.flag == 0.).nonzero() 105 | # non_mask_indexes = non_mask_indexes[indexes] 106 | # print('ll') 107 | # print(non_mask_indexes.size()) 108 | # print(non_mask_indexes) 109 | # # Here non_mask_index is too large, should be 192. 110 | # shift_offset = torch.stack([non_mask_indexes.squeeze() // self.w, non_mask_indexes.squeeze() % self.w], dim=-1) 111 | # print(shift_offset.size()) 112 | # self.shift_offsets.append(shift_offset) 113 | 114 | # print('cc') 115 | # if self.show_flow: 116 | # # Note: Here we assume that each mask is the same for the same batch image. 117 | # self.shift_offsets = torch.cat(self.shift_offsets, dim=0).float() # make it cudaFloatTensor 118 | # # Assume mask is the same for each image in a batch. 119 | # mask_nums = self.shift_offsets.size(0)//self.bz 120 | # self.flow_srcs = torch.zeros(self.bz, 3, self.h, self.w).type_as(input) 121 | 122 | # for idx in range(self.bz): 123 | # shift_offset = self.shift_offsets.narrow(0, idx*mask_nums, mask_nums) 124 | # # reconstruct the original shift_map. 125 | # shift_offsets_map = torch.zeros(1, self.h, self.w, 2).type_as(input) 126 | # print(shift_offsets_map.size()) 127 | # print(shift_offset.unsqueeze(0).size()) 128 | 129 | # print(shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :].size()) 130 | # shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :] = \ 131 | # shift_offset.unsqueeze(0) 132 | # # It is indicating the pixels(non-masked) that will shift the the masked region. 133 | # flow_src = util.highlight_flow(shift_offsets_map, self.flag.unsqueeze(0)) 134 | # self.flow_srcs[idx] = flow_src 135 | 136 | return torch.cat((former_all, latter_all, shift_masked_all), 1) 137 | 138 | def get_flow_src(self): 139 | return self.flow_srcs 140 | -------------------------------------------------------------------------------- /models/patch_soft_shift/patch_soft_shiftnet_model.py: -------------------------------------------------------------------------------- 1 | from models.shift_net.shiftnet_model import ShiftNetModel 2 | 3 | 4 | class PatchSoftShiftNetModel(ShiftNetModel): 5 | def name(self): 6 | return 'PatchSoftShiftNetModel' 7 | -------------------------------------------------------------------------------- /models/res_patch_soft_shift/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/res_patch_soft_shift/__init__.py -------------------------------------------------------------------------------- /models/res_patch_soft_shift/innerResPatchSoftShiftTriple.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import util.util as util 4 | 5 | from models.patch_soft_shift.innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule 6 | 7 | 8 | # TODO: Make it compatible for show_flow. 
9 | # 10 | class InnerResPatchSoftShiftTriple(nn.Module): 11 | def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3): 12 | super(InnerResPatchSoftShiftTriple, self).__init__() 13 | 14 | self.shift_sz = shift_sz 15 | self.stride = stride 16 | self.mask_thred = mask_thred 17 | self.triple_weight = triple_weight 18 | self.show_flow = False # default false. Do not change it to be true, it is computation-heavy. 19 | self.flow_srcs = None # Indicating the flow src(pixles in non-masked region that will shift into the masked region) 20 | self.fuse = fuse 21 | self.layer_to_last = layer_to_last 22 | self.softShift = InnerPatchSoftShiftTripleModule() 23 | 24 | # Additional for ResShift. 25 | self.inner_nc = inner_nc 26 | self.res_net = nn.Sequential( 27 | nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1), 28 | nn.InstanceNorm2d(inner_nc), 29 | nn.ReLU(True), 30 | nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1), 31 | nn.InstanceNorm2d(inner_nc) 32 | ) 33 | 34 | def set_mask(self, mask_global): 35 | mask = util.cal_feat_mask(mask_global, self.layer_to_last) 36 | self.mask = mask 37 | return self.mask 38 | 39 | # If mask changes, then need to set cal_fix_flag true each iteration. 40 | def forward(self, input): 41 | _, self.c, self.h, self.w = input.size() 42 | 43 | # Just pass self.mask in, instead of self.flag. 44 | # Try to making it faster by avoiding `cal_flag_given_mask_thread`. 45 | shift_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse) 46 | 47 | c_out = shift_out.size(1) 48 | # get F_c, F_s, F_shift 49 | F_c = shift_out.narrow(1, 0, c_out//3) 50 | F_s = shift_out.narrow(1, c_out//3, c_out//3) 51 | F_shift = shift_out.narrow(1, c_out*2//3, c_out//3) 52 | F_fuse = F_c * F_shift 53 | F_com = torch.cat([F_c, F_fuse], dim=1) 54 | 55 | res_out = self.res_net(F_com) 56 | F_c = F_c + res_out 57 | 58 | final_out = torch.cat([F_c, F_s], dim=1) 59 | 60 | if self.show_flow: 61 | self.flow_srcs = self.softShift.get_flow_src() 62 | return final_out 63 | 64 | def get_flow(self): 65 | return self.flow_srcs 66 | 67 | def set_flow_true(self): 68 | self.show_flow = True 69 | 70 | def set_flow_false(self): 71 | self.show_flow = False 72 | 73 | def __repr__(self): 74 | return self.__class__.__name__+ '(' \ 75 | + ' ,triple_weight ' + str(self.triple_weight) + ')' 76 | -------------------------------------------------------------------------------- /models/res_patch_soft_shift/res_patch_soft_shiftnet_model.py: -------------------------------------------------------------------------------- 1 | from models.shift_net.shiftnet_model import ShiftNetModel 2 | 3 | 4 | class ResPatchSoftShiftNetModel(ShiftNetModel): 5 | def name(self): 6 | return 'ResPatchSoftShiftNetModel' 7 | -------------------------------------------------------------------------------- /models/res_shift_net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/res_shift_net/__init__.py -------------------------------------------------------------------------------- /models/res_shift_net/innerResShiftTriple.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import util.util as util 4 | 5 | from models.shift_net.InnerShiftTripleFunction import 
InnerShiftTripleFunction 6 | 7 | class InnerResShiftTriple(nn.Module): 8 | def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3): 9 | super(InnerResShiftTriple, self).__init__() 10 | 11 | self.shift_sz = shift_sz 12 | self.stride = stride 13 | self.mask_thred = mask_thred 14 | self.triple_weight = triple_weight 15 | self.show_flow = False # default false. Do not change it to be true, it is computation-heavy. 16 | self.flow_srcs = None # Indicating the flow src(pixles in non-masked region that will shift into the masked region) 17 | self.layer_to_last = layer_to_last 18 | 19 | # Additional for ResShift. 20 | self.inner_nc = inner_nc 21 | self.res_net = nn.Sequential( 22 | nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1), 23 | nn.InstanceNorm2d(inner_nc), 24 | nn.ReLU(True), 25 | nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1), 26 | nn.InstanceNorm2d(inner_nc) 27 | ) 28 | 29 | 30 | def set_mask(self, mask_global): 31 | mask = util.cal_feat_mask(mask_global, self.layer_to_last) 32 | self.mask = mask.squeeze() 33 | return self.mask 34 | 35 | # If mask changes, then need to set cal_fix_flag true each iteration. 36 | def forward(self, input): 37 | #print(input.shape) 38 | _, self.c, self.h, self.w = input.size() 39 | self.flag = util.cal_flag_given_mask_thred(self.mask, self.shift_sz, self.stride, self.mask_thred) 40 | shift_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow) 41 | 42 | c_out = shift_out.size(1) 43 | # get F_c, F_s, F_shift 44 | F_c = shift_out.narrow(1, 0, c_out//3) 45 | F_s = shift_out.narrow(1, c_out//3, c_out//3) 46 | F_shift = shift_out.narrow(1, c_out*2//3, c_out//3) 47 | F_fuse = F_c * F_shift 48 | F_com = torch.cat([F_c, F_fuse], dim=1) 49 | 50 | res_out = self.res_net(F_com) 51 | F_c = F_c + res_out 52 | 53 | final_out = torch.cat([F_c, F_s], dim=1) 54 | 55 | if self.show_flow: 56 | self.flow_srcs = InnerShiftTripleFunction.get_flow_src() 57 | return final_out 58 | 59 | def get_flow(self): 60 | return self.flow_srcs 61 | 62 | def set_flow_true(self): 63 | self.show_flow = True 64 | 65 | def set_flow_false(self): 66 | self.show_flow = False 67 | 68 | def __repr__(self): 69 | return self.__class__.__name__+ '(' \ 70 | + ' ,triple_weight ' + str(self.triple_weight) + ')' 71 | -------------------------------------------------------------------------------- /models/res_shift_net/shiftnet_model.py: -------------------------------------------------------------------------------- 1 | from models.shift_net.shiftnet_model import ShiftNetModel 2 | 3 | 4 | class ResShiftNetModel(ShiftNetModel): 5 | def name(self): 6 | return 'ResShiftNetModel' -------------------------------------------------------------------------------- /models/shift_net/InnerCos.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import util.util as util 5 | from .InnerCosFunction import InnerCosFunction 6 | 7 | class InnerCos(nn.Module): 8 | def __init__(self, crit='MSE', strength=1, skip=0, layer_to_last=3, device='gpu'): 9 | super(InnerCos, self).__init__() 10 | self.crit = crit 11 | self.criterion = torch.nn.MSELoss() if self.crit == 'MSE' else torch.nn.L1Loss() 12 | self.strength = strength 13 | # To define whether this layer is skipped. 
14 | self.skip = skip 15 | self.layer_to_last = layer_to_last 16 | self.device = device 17 | # Init a dummy value is fine. 18 | self.target = torch.tensor(1.0) 19 | 20 | def set_mask(self, mask_global): 21 | mask_all = util.cal_feat_mask(mask_global, self.layer_to_last) 22 | self.mask_all = mask_all.float() 23 | 24 | 25 | def _split_mask(self, cur_bsize): 26 | # get the visible indexes of gpus and assign correct mask to set of images 27 | cur_device = torch.cuda.current_device() 28 | self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :] 29 | 30 | def forward(self, in_data): 31 | self.bz, self.c, _, _ = in_data.size() 32 | if self.device != 'cpu': 33 | self._split_mask(self.bz) 34 | else: 35 | self.cur_mask = self.mask_all 36 | self.cur_mask = self.cur_mask.to(in_data) 37 | if not self.skip: 38 | # It works like this: 39 | # Each iteration contains 2 forward passes, In the first forward pass, we input a GT image, just to get the target. 40 | # In the second forward pass, we input the corresponding corrupted image, then back-propagate the network, the guidance loss works as expected. 41 | self.output = InnerCosFunction.apply(in_data, self.criterion, self.strength, self.target, self.cur_mask) 42 | self.target = in_data.narrow(1, self.c // 2, self.c // 2).detach() # the latter part 43 | else: 44 | self.output = in_data 45 | return self.output 46 | 47 | 48 | def __repr__(self): 49 | skip_str = 'True' if not self.skip else 'False' 50 | return self.__class__.__name__+ '(' \ 51 | + 'skip: ' + skip_str \ 52 | + 'layer ' + str(self.layer_to_last) + ' to last' \ 53 | + ' ,strength: ' + str(self.strength) + ')' 54 | -------------------------------------------------------------------------------- /models/shift_net/InnerCosFunction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class InnerCosFunction(torch.autograd.Function): 5 | 6 | @staticmethod 7 | def forward(ctx, input, criterion, strength, target, mask): 8 | ctx.c = input.size(1) 9 | ctx.strength = strength 10 | ctx.criterion = criterion 11 | if len(target.size()) == 0: # For the first iteration. 
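            # Note: at this point `target` is still the 0-dim dummy tensor created in
            # InnerCos.__init__; expanding it to the shape of one channel-half of `input`
            # keeps the criterion in backward() shape-consistent on the very first pass.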
12 | target = target.expand_as(input.narrow(1, ctx.c // 2, ctx.c // 2)).type_as(input) 13 | 14 | ctx.save_for_backward(input, target, mask) 15 | return input 16 | 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | 21 | with torch.enable_grad(): 22 | input, target, mask = ctx.saved_tensors 23 | former = input.narrow(1, 0, ctx.c//2) 24 | former_in_mask = torch.mul(former, mask) 25 | if former_in_mask.size() != target.size(): # For the last iteration of one epoch 26 | target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask) 27 | 28 | former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True) 29 | ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength 30 | ctx.loss.backward() 31 | 32 | grad_output[:,0:ctx.c//2, :,:] += former_in_mask_clone.grad 33 | 34 | return grad_output, None, None, None, None -------------------------------------------------------------------------------- /models/shift_net/InnerShiftTriple.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import util.util as util 4 | from .InnerShiftTripleFunction import InnerShiftTripleFunction 5 | 6 | class InnerShiftTriple(nn.Module): 7 | def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3, device='gpu'): 8 | super(InnerShiftTriple, self).__init__() 9 | 10 | self.shift_sz = shift_sz 11 | self.stride = stride 12 | self.mask_thred = mask_thred 13 | self.triple_weight = triple_weight 14 | self.layer_to_last = layer_to_last 15 | self.device = device 16 | self.show_flow = False # default false. Do not change it to be true, it is computation-heavy. 17 | self.flow_srcs = None # Indicating the flow src(pixles in non-masked region that will shift into the masked region) 18 | 19 | 20 | def set_mask(self, mask_global): 21 | self.mask_all = util.cal_feat_mask(mask_global, self.layer_to_last) 22 | 23 | def _split_mask(self, cur_bsize): 24 | # get the visible indexes of gpus and assign correct mask to set of images 25 | cur_device = torch.cuda.current_device() 26 | self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :] 27 | 28 | # If mask changes, then need to set cal_fix_flag true each iteration. 
29 | def forward(self, input): 30 | self.bz, self.c, self.h, self.w = input.size() 31 | if self.device != 'cpu': 32 | self._split_mask(self.bz) 33 | else: 34 | self.cur_mask = self.mask_all 35 | self.flag = util.cal_flag_given_mask_thred(self.cur_mask, self.shift_sz, self.stride, self.mask_thred) 36 | final_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow) 37 | if self.show_flow: 38 | self.flow_srcs = InnerShiftTripleFunction.get_flow_src() 39 | return final_out 40 | 41 | def get_flow(self): 42 | return self.flow_srcs 43 | 44 | def set_flow_true(self): 45 | self.show_flow = True 46 | 47 | def set_flow_false(self): 48 | self.show_flow = False 49 | 50 | def __repr__(self): 51 | return self.__class__.__name__+ '(' \ 52 | + ' ,triple_weight ' + str(self.triple_weight) + ')' 53 | -------------------------------------------------------------------------------- /models/shift_net/InnerShiftTripleFunction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift 3 | import torch 4 | import util.util as util 5 | import time 6 | 7 | 8 | class InnerShiftTripleFunction(torch.autograd.Function): 9 | ctx = None 10 | 11 | @staticmethod 12 | def forward(ctx, input, shift_sz, stride, triple_w, flag, show_flow): 13 | InnerShiftTripleFunction.ctx = ctx 14 | assert input.dim() == 4, "Input Dim has to be 4" 15 | ctx.triple_w = triple_w 16 | ctx.flag = flag 17 | ctx.show_flow = show_flow 18 | 19 | ctx.bz, c_real, ctx.h, ctx.w = input.size() 20 | c = c_real 21 | 22 | ctx.ind_lst = torch.Tensor(ctx.bz, ctx.h * ctx.w, ctx.h * ctx.w).zero_().to(input) 23 | 24 | # former and latter are all tensors 25 | former_all = input.narrow(1, 0, c//2) ### decoder feature 26 | latter_all = input.narrow(1, c//2, c//2) ### encoder feature 27 | shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all).zero_() # addition feature 28 | 29 | ctx.flag = ctx.flag.to(input).long() 30 | 31 | # None batch version 32 | bNonparm = Batch_NonShift() 33 | ctx.shift_offsets = [] 34 | 35 | # batch version 36 | cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former_all.clone(), latter_all.clone(), 1, stride, flag) 37 | 38 | _, indexes = torch.max(cosine, dim=2) 39 | 40 | mask_indexes = (flag==1).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1) 41 | 42 | non_mask_indexes = (flag==0).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1).gather(1, indexes) 43 | 44 | idx_b = torch.arange(ctx.bz).long().unsqueeze(1).expand(ctx.bz, mask_indexes.size(1)) 45 | # set the elemnets of indexed by [mask_indexes, non_mask_indexes] to 1. 
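        # Concretely, ind_lst[b] is an (h*w) x (h*w) binary matrix: row i (a masked position)
        # gets a single 1 in column j (its most similar non-masked position), so multiplying
        # ind_lst[b] with the (h*w) x C matrix of latter features copies the matched source
        # feature into each hole position. For a toy 2x2 map with only position 3 masked and
        # its best match at position 0, row 3 would be [1, 0, 0, 0].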
46 | # It is a batch version 47 | ctx.ind_lst[(idx_b, mask_indexes, non_mask_indexes)] = 1 48 | 49 | shift_masked_all = bNonparm._paste(latter_windows, ctx.ind_lst, i_2, i_3, i_1) 50 | 51 | 52 | # --- Non-batch version ---- 53 | #for idx in range(ctx.bz): 54 | # flag_cur = ctx.flag[idx] 55 | # latter = latter_all.narrow(0, idx, 1) ### encoder feature 56 | # former = former_all.narrow(0, idx, 1) ### decoder feature 57 | 58 | # #GET COSINE, RESHAPED LATTER AND ITS INDEXES 59 | # cosine, latter_windows, i_2, i_3, i_1 = Nonparm.cosine_similarity(former.clone().squeeze(), latter.clone().squeeze(), 1, stride, flag_cur) 60 | 61 | # ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY 62 | # _, indexes = torch.max(cosine, dim=1) 63 | 64 | # # SET TRANSITION MATRIX 65 | # mask_indexes = (flag_cur == 1).nonzero() 66 | # non_mask_indexes = (flag_cur == 0).nonzero()[indexes] 67 | # ctx.ind_lst[idx][mask_indexes, non_mask_indexes] = 1 68 | 69 | # # GET FINAL SHIFT FEATURE 70 | # shift_masked_all[idx] = Nonparm._paste(latter_windows, ctx.ind_lst[idx], i_2, i_3, i_1) 71 | 72 | # if ctx.show_flow: 73 | # shift_offset = torch.stack([non_mask_indexes.squeeze() // ctx.w, non_mask_indexes.squeeze() % ctx.w], dim=-1) 74 | # ctx.shift_offsets.append(shift_offset) 75 | 76 | if ctx.show_flow: 77 | assert 1==2, "I do not want maintance the functionality of `show flow`... ^_^" 78 | ctx.shift_offsets = torch.cat(ctx.shift_offsets, dim=0).float() # make it cudaFloatTensor 79 | # Assume mask is the same for each image in a batch. 80 | mask_nums = ctx.shift_offsets.size(0)//ctx.bz 81 | ctx.flow_srcs = torch.zeros(ctx.bz, 3, ctx.h, ctx.w).type_as(input) 82 | 83 | for idx in range(ctx.bz): 84 | shift_offset = ctx.shift_offsets.narrow(0, idx*mask_nums, mask_nums) 85 | # reconstruct the original shift_map. 86 | shift_offsets_map = torch.zeros(1, ctx.h, ctx.w, 2).type_as(input) 87 | shift_offsets_map[:, (flag_cur == 1).nonzero(as_tuple=False).squeeze() // ctx.w, (flag_cur == 1).nonzero(as_tuple=False).squeeze() % ctx.w, :] = \ 88 | shift_offset.unsqueeze(0) 89 | # It is indicating the pixels(non-masked) that will shift the the masked region. 90 | flow_src = util.highlight_flow(shift_offsets_map, flag_cur.unsqueeze(0)) 91 | ctx.flow_srcs[idx] = flow_src 92 | 93 | return torch.cat((former_all, latter_all, shift_masked_all), 1) 94 | 95 | 96 | @staticmethod 97 | def get_flow_src(): 98 | return InnerShiftTripleFunction.ctx.flow_srcs 99 | 100 | @staticmethod 101 | def backward(ctx, grad_output): 102 | ind_lst = ctx.ind_lst 103 | 104 | c = grad_output.size(1) 105 | 106 | # # the former and the latter are keep original. Only the thrid part is shifted. 107 | # C: content, pixels in masked region of the former part. 108 | # S: style, pixels in the non-masked region of the latter part. 109 | # N: the shifted feature, the new feature that will be used as the third part of features maps. 110 | # W_mat: ind_lst[idx], shift matrix. 111 | # Note: **only the masked region in N has values**. 112 | 113 | # The gradient of shift feature should be added back to the latter part(to be precise: S). 114 | # `ind_lst[idx][i,j] = 1` means that the i_th pixel will **be replaced** by j_th pixel in the forward. 115 | # When applying `S mm W_mat`, then S will be transfer to N. 116 | # (pixels in non-masked region of the latter part will be shift to the masked region in the third part.) 117 | # However, we need to transfer back the gradient of the third part to S. 118 | # This means the graident in S will **`be replaced`(to be precise, enhanced)** by N. 
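        # In matrix form the shifted part is N = W S per feature channel, with W = ind_lst[idx],
        # so dL/dS accumulates W^T (dL/dN); the batched bmm with ind_lst.permute(0, 2, 1) below
        # computes exactly this, scaled by triple_w.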
119 | grad_former_all = grad_output[:, 0:c//3, :, :] 120 | grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone() 121 | grad_shifted_all = grad_output[:, c*2//3:c, :, :].clone() 122 | 123 | W_mat_t = ind_lst.permute(0, 2, 1).contiguous() 124 | grad = grad_shifted_all.view(ctx.bz, c//3, -1).permute(0, 2, 1) 125 | grad_shifted_weighted = torch.bmm(W_mat_t, grad) 126 | grad_shifted_weighted = grad_shifted_weighted.permute(0, 2, 1).contiguous().view(ctx.bz, c//3, ctx.h, ctx.w) 127 | grad_latter_all = torch.add(grad_latter_all, grad_shifted_weighted.mul(ctx.triple_w)) 128 | 129 | # ----- 'Non_batch version here' -------------------- 130 | # for idx in range(ctx.bz): 131 | # # So we need to transpose `W_mat` 132 | # W_mat_t = ind_lst[idx].t() 133 | 134 | # grad = grad_shifted_all[idx].view(c//3, -1).t() 135 | 136 | # grad_shifted_weighted = torch.mm(W_mat_t, grad) 137 | 138 | # # Then transpose it back 139 | # grad_shifted_weighted = grad_shifted_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w) 140 | # grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_shifted_weighted.mul(ctx.triple_w)) 141 | 142 | # note the input channel and the output channel are all c, as no mask input for now. 143 | grad_input = torch.cat([grad_former_all, grad_latter_all], 1) 144 | 145 | return grad_input, None, None, None, None, None, None 146 | -------------------------------------------------------------------------------- /models/shift_net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/shift_net/__init__.py -------------------------------------------------------------------------------- /models/shift_net/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | 5 | 6 | class BaseModel(): 7 | def name(self): 8 | return 'BaseModel' 9 | 10 | def initialize(self, opt): 11 | self.opt = opt 12 | self.gpu_ids = opt.gpu_ids 13 | self.isTrain = opt.isTrain 14 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 16 | if opt.resize_or_crop != 'scale_width': 17 | torch.backends.cudnn.benchmark = True 18 | self.loss_names = [] 19 | self.model_names = [] 20 | self.visual_names = [] 21 | self.image_paths = [] 22 | 23 | def set_input(self, input): 24 | self.input = input 25 | 26 | def forward(self): 27 | pass 28 | 29 | # used in test time, wrapping `forward` in no_grad() so we don't save 30 | # intermediate steps for backprop 31 | def test(self): 32 | with torch.no_grad(): 33 | self.forward() 34 | 35 | # get image paths 36 | def get_image_paths(self): 37 | return self.image_paths 38 | 39 | def optimize_parameters(self): 40 | pass 41 | 42 | # update learning rate (called once every epoch) 43 | def update_learning_rate(self): 44 | for scheduler in self.schedulers: 45 | scheduler.step() 46 | lr = self.optimizers[0].param_groups[0]['lr'] 47 | print('learning rate = %.7f' % lr) 48 | 49 | # return visualization images. train.py will display these images, and save the images to a html 50 | def get_current_visuals(self): 51 | visual_ret = OrderedDict() 52 | for name in self.visual_names: 53 | if isinstance(name, str): 54 | visual_ret[name] = getattr(self, name) 55 | return visual_ret 56 | 57 | # return traning losses/errors. 
train.py will print out these errors as debugging information 58 | def get_current_losses(self): 59 | errors_ret = OrderedDict() 60 | for name in self.loss_names: 61 | if isinstance(name, str): 62 | # float(...) works for both scalar tensor and float number 63 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 64 | return errors_ret 65 | 66 | # save models to the disk 67 | def save_networks(self, which_epoch): 68 | for name in self.model_names: 69 | if isinstance(name, str): 70 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 71 | save_path = os.path.join(self.save_dir, save_filename) 72 | net = getattr(self, 'net' + name) 73 | 74 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 75 | torch.save(net.module.cpu().state_dict(), save_path) 76 | net.cuda(self.gpu_ids[0]) 77 | else: 78 | torch.save(net.cpu().state_dict(), save_path) 79 | 80 | 81 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 82 | key = keys[i] 83 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 84 | if module.__class__.__name__.startswith('InstanceNorm') and \ 85 | (key == 'running_mean' or key == 'running_var'): 86 | if getattr(module, key) is None: 87 | state_dict.pop('.'.join(keys)) 88 | else: 89 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 90 | 91 | # load models from the disk 92 | def load_networks(self, which_epoch): 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | load_filename = '%s_net_%s.pth' % (which_epoch, name) 96 | load_path = os.path.join(self.save_dir, load_filename) 97 | net = getattr(self, 'net' + name) 98 | if isinstance(net, torch.nn.DataParallel): 99 | net = net.module 100 | # if you are using PyTorch newer than 0.4 (e.g., built from 101 | # GitHub source), you can remove str() on self.device 102 | state_dict = torch.load(load_path, map_location=str(self.device)) 103 | # patch InstanceNorm checkpoints prior to 0.4 104 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 105 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 106 | net.load_state_dict(state_dict) 107 | 108 | # print network information 109 | def print_networks(self, verbose): 110 | print('---------- Networks initialized -------------') 111 | for name in self.model_names: 112 | if isinstance(name, str): 113 | net = getattr(self, 'net' + name) 114 | num_params = 0 115 | for param in net.parameters(): 116 | num_params += param.numel() 117 | if verbose: 118 | print(net) 119 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 120 | print('-----------------------------------------------') 121 | 122 | def set_requires_grad(self, nets, requires_grad=False): 123 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 124 | Parameters: 125 | nets (network list) -- a list of networks 126 | requires_grad (bool) -- whether the networks require gradients or not 127 | """ 128 | if not isinstance(nets, list): 129 | nets = [nets] 130 | for net in nets: 131 | if net is not None: 132 | for param in net.parameters(): 133 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /models/shift_net/shiftnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import util.util as util 4 | from models import networks 5 | from models.shift_net.base_model 
import BaseModel 6 | import time 7 | import torchvision.transforms as transforms 8 | import os 9 | import numpy as np 10 | from PIL import Image 11 | 12 | class ShiftNetModel(BaseModel): 13 | def name(self): 14 | return 'ShiftNetModel' 15 | 16 | 17 | def create_random_mask(self): 18 | if self.opt.mask_type == 'random': 19 | if self.opt.mask_sub_type == 'fractal': 20 | assert 1==2, "It is broken somehow, use another mask_sub_type please" 21 | mask = util.create_walking_mask() # create an initial random mask. 22 | 23 | elif self.opt.mask_sub_type == 'rect': 24 | mask, rand_t, rand_l = util.create_rand_mask(self.opt) 25 | self.rand_t = rand_t 26 | self.rand_l = rand_l 27 | return mask 28 | 29 | elif self.opt.mask_sub_type == 'island': 30 | mask = util.wrapper_gmask(self.opt) 31 | return mask 32 | 33 | def initialize(self, opt): 34 | BaseModel.initialize(self, opt) 35 | self.opt = opt 36 | self.isTrain = opt.isTrain 37 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 38 | self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv'] 39 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 40 | if self.opt.show_flow: 41 | self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs'] 42 | else: 43 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 44 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 45 | if self.isTrain: 46 | self.model_names = ['G', 'D'] 47 | else: # during test time, only load Gs 48 | self.model_names = ['G'] 49 | 50 | 51 | # batchsize should be 1 for mask_global 52 | self.mask_global = torch.zeros((self.opt.batchSize, 1, \ 53 | opt.fineSize, opt.fineSize), dtype=torch.bool) 54 | 55 | # Here we need to set an artificial mask_global(center hole is ok.) 56 | self.mask_global.zero_() 57 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\ 58 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1 59 | 60 | if len(opt.gpu_ids) > 0: 61 | self.mask_global = self.mask_global.to(self.device) 62 | 63 | # load/define networks 64 | # self.ng_innerCos_list is the guidance loss list in netG inner layers. 65 | # self.ng_shift_list is the mask list constructing shift operation. 
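        # Both lists are passed into networks.define_G empty and are filled as a side effect
        # while the shift U-Net is built; set_latent_mask() later pushes the current mask
        # into these inner layers each time set_input() is called.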
66 | if opt.add_mask2input: 67 | input_nc = opt.input_nc + 1 68 | else: 69 | input_nc = opt.input_nc 70 | 71 | self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(input_nc, opt.output_nc, opt.ngf, 72 | opt.which_model_netG, opt, self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type, self.gpu_ids, opt.init_gain) 73 | 74 | if self.isTrain: 75 | use_sigmoid = False 76 | if opt.gan_type == 'vanilla': 77 | use_sigmoid = True # only vanilla GAN using BCECriterion 78 | # don't use cGAN 79 | self.netD = networks.define_D(opt.input_nc, opt.ndf, 80 | opt.which_model_netD, 81 | opt.n_layers_D, opt.norm, use_sigmoid, opt.use_spectral_norm_D, opt.init_type, self.gpu_ids, opt.init_gain) 82 | 83 | # add style extractor 84 | self.vgg16_extractor = util.VGG16FeatureExtractor() 85 | if len(opt.gpu_ids) > 0: 86 | self.vgg16_extractor = self.vgg16_extractor.to(self.gpu_ids[0]) 87 | self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor, self.gpu_ids) 88 | 89 | if self.isTrain: 90 | self.old_lr = opt.lr 91 | # define loss functions 92 | self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device) 93 | self.criterionL1 = torch.nn.L1Loss() 94 | self.criterionL1_mask = networks.Discounted_L1(opt).to(self.device) # make weights/buffers transfer to the correct device 95 | # VGG loss 96 | self.criterionL2_style_loss = torch.nn.MSELoss() 97 | self.criterionL2_content_loss = torch.nn.MSELoss() 98 | # TV loss 99 | self.tv_criterion = networks.TVLoss(self.opt.tv_weight) 100 | 101 | # initialize optimizers 102 | self.schedulers = [] 103 | self.optimizers = [] 104 | if self.opt.gan_type == 'wgan_gp': 105 | opt.beta1 = 0 106 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 107 | lr=opt.lr, betas=(opt.beta1, 0.9)) 108 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 109 | lr=opt.lr, betas=(opt.beta1, 0.9)) 110 | else: 111 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 112 | lr=opt.lr, betas=(opt.beta1, 0.999)) 113 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 114 | lr=opt.lr, betas=(opt.beta1, 0.999)) 115 | self.optimizers.append(self.optimizer_G) 116 | self.optimizers.append(self.optimizer_D) 117 | for optimizer in self.optimizers: 118 | self.schedulers.append(networks.get_scheduler(optimizer, opt)) 119 | 120 | if not self.isTrain or opt.continue_train: 121 | self.load_networks(opt.which_epoch) 122 | 123 | self.print_networks(opt.verbose) 124 | 125 | def set_input(self, input): 126 | self.image_paths = input['A_paths'] 127 | real_A = input['A'].to(self.device) 128 | real_B = input['B'].to(self.device) 129 | # directly load mask offline 130 | self.mask_global = input['M'].to(self.device).byte() 131 | self.mask_global = self.mask_global.narrow(1,0,1).bool() 132 | 133 | # create mask online 134 | if not self.opt.offline_loading_mask: 135 | if self.opt.mask_type == 'center': 136 | self.mask_global.zero_() 137 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\ 138 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1 139 | self.rand_t, self.rand_l = int(self.opt.fineSize/4) + self.opt.overlap, int(self.opt.fineSize/4) + self.opt.overlap 140 | elif self.opt.mask_type == 'random': 141 | self.mask_global = self.create_random_mask().type_as(self.mask_global).view(1, *self.mask_global.size()[-3:]) 142 | # As generating random masks online are 
computation-heavy 143 | # So just generate one ranodm mask for a batch images. 144 | self.mask_global = self.mask_global.expand(self.opt.batchSize, *self.mask_global.size()[-3:]) 145 | else: 146 | raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type) 147 | # For loading mask offline, we also need to change 'opt.mask_type' and 'opt.mask_sub_type' 148 | # to avoid forgetting such settings. 149 | else: 150 | self.opt.mask_type = 'random' 151 | self.opt.mask_sub_type = 'island' 152 | 153 | self.set_latent_mask(self.mask_global) 154 | 155 | real_A.narrow(1,0,1).masked_fill_(self.mask_global, 0.)#2*123.0/255.0 - 1.0 156 | real_A.narrow(1,1,1).masked_fill_(self.mask_global, 0.)#2*104.0/255.0 - 1.0 157 | real_A.narrow(1,2,1).masked_fill_(self.mask_global, 0.)#2*117.0/255.0 - 1.0 158 | 159 | if self.opt.add_mask2input: 160 | # make it 4 dimensions. 161 | # Mention: the extra dim, the masked part is filled with 0, non-mask part is filled with 1. 162 | real_A = torch.cat((real_A, (~self.mask_global).expand(real_A.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1) 163 | 164 | self.real_A = real_A 165 | self.real_B = real_B 166 | 167 | 168 | def set_latent_mask(self, mask_global): 169 | for ng_shift in self.ng_shift_list: # ITERATE OVER THE LIST OF ng_shift_list 170 | ng_shift.set_mask(mask_global) 171 | for ng_innerCos in self.ng_innerCos_list: # ITERATE OVER THE LIST OF ng_innerCos_list: 172 | ng_innerCos.set_mask(mask_global) 173 | 174 | def set_gt_latent(self): 175 | if not self.opt.skip: 176 | if self.opt.add_mask2input: 177 | # make it 4 dimensions. 178 | # Mention: the extra dim, the masked part is filled with 0, non-mask part is filled with 1. 179 | real_B = torch.cat([self.real_B, (~self.mask_global).expand(self.real_B.size(0), 1, self.real_B.size(2), self.real_B.size(3)).type_as(self.real_B)], dim=1) 180 | else: 181 | real_B = self.real_B 182 | self.netG(real_B) # input ground truth 183 | 184 | 185 | def forward(self): 186 | self.set_gt_latent() 187 | self.fake_B = self.netG(self.real_A) 188 | 189 | # Just assume one shift layer. 190 | def set_flow_src(self): 191 | self.flow_srcs = self.ng_shift_list[0].get_flow() 192 | self.flow_srcs = F.interpolate(self.flow_srcs, scale_factor=8, mode='nearest') 193 | # Just to avoid forgetting setting show_map_false 194 | self.set_show_map_false() 195 | 196 | # Just assume one shift layer. 197 | def set_show_map_true(self): 198 | self.ng_shift_list[0].set_flow_true() 199 | 200 | def set_show_map_false(self): 201 | self.ng_shift_list[0].set_flow_false() 202 | 203 | def get_image_paths(self): 204 | return self.image_paths 205 | 206 | def backward_D(self): 207 | fake_B = self.fake_B 208 | # Real 209 | real_B = self.real_B # GroundTruth 210 | 211 | # Has been verfied, for square mask, let D discrinate masked patch, improves the results. 212 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect': 213 | # Using the cropped fake_B as the input of D. 
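            # With the default center mask (fineSize=256, overlap=4) this crop selects
            # rows/cols 68:188, i.e. exactly the 120x120 hole region, so D only judges the
            # inpainted patch rather than the whole image.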
214 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 215 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 216 | 217 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 218 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 219 | 220 | self.pred_fake = self.netD(fake_B.detach()) 221 | self.pred_real = self.netD(real_B) 222 | 223 | if self.opt.gan_type == 'wgan_gp': 224 | gradient_penalty, _ = util.cal_gradient_penalty(self.netD, real_B, fake_B.detach(), self.device, constant=1, lambda_gp=self.opt.gp_lambda) 225 | self.loss_D_fake = torch.mean(self.pred_fake) 226 | self.loss_D_real = -torch.mean(self.pred_real) 227 | 228 | self.loss_D = self.loss_D_fake + self.loss_D_real + gradient_penalty 229 | else: 230 | if self.opt.gan_type in ['vanilla', 'lsgan']: 231 | self.loss_D_fake = self.criterionGAN(self.pred_fake, False) 232 | self.loss_D_real = self.criterionGAN (self.pred_real, True) 233 | 234 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 235 | 236 | elif self.opt.gan_type == 're_s_gan': 237 | self.loss_D = self.criterionGAN(self.pred_real - self.pred_fake, True) 238 | 239 | self.loss_D.backward() 240 | 241 | 242 | def backward_G(self): 243 | # First, G(A) should fake the discriminator 244 | fake_B = self.fake_B 245 | # Has been verfied, for square mask, let D discrinate masked patch, improves the results. 246 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect': 247 | # Using the cropped fake_B as the input of D. 248 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 249 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 250 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 251 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 252 | else: 253 | real_B = self.real_B 254 | 255 | pred_fake = self.netD(fake_B) 256 | 257 | 258 | if self.opt.gan_type == 'wgan_gp': 259 | self.loss_G_GAN = -torch.mean(pred_fake) 260 | else: 261 | if self.opt.gan_type in ['vanilla', 'lsgan']: 262 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.gan_weight 263 | 264 | elif self.opt.gan_type == 're_s_gan': 265 | pred_real = self.netD (real_B) 266 | self.loss_G_GAN = self.criterionGAN (pred_fake - pred_real, True) * self.opt.gan_weight 267 | 268 | elif self.opt.gan_type == 're_avg_gan': 269 | self.pred_real = self.netD(real_B) 270 | self.loss_G_GAN = (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), False) \ 271 | + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), True)) / 2. 272 | self.loss_G_GAN *= self.opt.gan_weight 273 | 274 | 275 | # If we change the mask as 'center with random position', then we can replacing loss_G_L1_m with 'Discounted L1'. 276 | self.loss_G_L1, self.loss_G_L1_m = 0, 0 277 | self.loss_G_L1 += self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A 278 | # calcuate mask construction loss 279 | # When mask_type is 'center' or 'random_with_rect', we can add additonal mask region construction loss (traditional L1). 280 | # Only when 'discounting_loss' is 1, then the mask region construction loss changes to 'discounting L1' instead of normal L1. 
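        # Note: loss_G_L1 above covers the full image and is weighted by lambda_A (default 100);
        # the block below adds a second reconstruction term restricted to the hole, weighted by
        # mask_weight_G (default 400) and computed with the Discounted_L1 criterion from networks.py.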
281 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect': 282 | mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 283 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 284 | mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \ 285 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap] 286 | # Using Discounting L1 loss 287 | self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight_G 288 | 289 | self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN 290 | 291 | # Then, add TV loss 292 | self.loss_tv = self.tv_criterion(self.fake_B*self.mask_global.float()) 293 | 294 | # Finally, add style loss 295 | vgg_ft_fakeB = self.vgg16_extractor(fake_B) 296 | vgg_ft_realB = self.vgg16_extractor(real_B) 297 | self.loss_style = 0 298 | self.loss_content = 0 299 | 300 | for i in range(3): 301 | self.loss_style += self.criterionL2_style_loss(util.gram_matrix(vgg_ft_fakeB[i]), util.gram_matrix(vgg_ft_realB[i])) 302 | self.loss_content += self.criterionL2_content_loss(vgg_ft_fakeB[i], vgg_ft_realB[i]) 303 | 304 | self.loss_style *= self.opt.style_weight 305 | self.loss_content *= self.opt.content_weight 306 | 307 | self.loss_G += (self.loss_style + self.loss_content + self.loss_tv) 308 | 309 | self.loss_G.backward() 310 | 311 | def optimize_parameters(self): 312 | self.forward() 313 | # update D 314 | self.set_requires_grad(self.netD, True) 315 | self.optimizer_D.zero_grad() 316 | self.backward_D() 317 | self.optimizer_D.step() 318 | 319 | # update G 320 | self.set_requires_grad(self.netD, False) 321 | self.optimizer_G.zero_grad() 322 | self.backward_G() 323 | self.optimizer_G.step() 324 | 325 | 326 | -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/notebooks/__init__.py -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | class BaseOptions(): 7 | def __init__(self): 8 | self.initialized = False 9 | 10 | def initialize(self, parser): 11 | parser.add_argument('--dataroot', default='./datasets/Paris/train', help='path to training/testing images') 12 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 13 | parser.add_argument('--loadSize', type=int, default=350, help='scale images to this size') 14 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 15 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 16 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 17 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 18 | parser.add_argument('--ndf', type=int, default=64, 
help='# of discrim filters in first conv layer') 19 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD, [basic|densenet]') 20 | parser.add_argument('--which_model_netG', type=str, default='unet_shift_triple', help='selects model to use for netG [unet_256| unet_shift_triple| \ 21 | res_unet_shift_triple|patch_soft_unet_shift_triple| \ 22 | res_patch_soft_unet_shift_triple| face_unet_shift_triple]') 23 | parser.add_argument('--model', type=str, default='shiftnet', \ 24 | help='chooses which model to use. [shiftnet|res_shiftnet|patch_soft_shiftnet|res_patch_soft_shiftnet|test]') 25 | parser.add_argument('--triple_weight', type=float, default=1, help='The weight on the gradient of skip connections from the gradient of shifted') 26 | parser.add_argument('--name', type=str, default='exp', help='name of the experiment. It decides where to store samples and models') 27 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 28 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, use \'-1 \' for cpu training/testing') 29 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [aligned | aligned_resized | single]') 30 | parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 31 | parser.add_argument('--checkpoints_dir', type=str, default='./log', help='models are saved here') 32 | parser.add_argument('--norm', type=str, default='instance', help='[instance|batch|switchable] normalization') 33 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 34 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 35 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}') 36 | 37 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 38 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width]') 39 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 40 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 41 | parser.add_argument('--show_flow', type=int, default=0, help='show the flow information. 
WARNING: set display_freq a large number as it is quite slow when showing flow') 42 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 43 | ## model specific 44 | parser.add_argument('--mask_type', type=str, default='center', 45 | help='the type of mask you want to apply, \'center\' or \'random\'') 46 | parser.add_argument('--mask_sub_type', type=str, default='island', 47 | help='the type of mask you want to apply, \'rect \' or \'fractal \' or \'island \'') 48 | parser.add_argument('--lambda_A', type=int, default=100, help='weight on L1 term in objective') 49 | parser.add_argument('--stride', type=int, default=1, help='should be dense, 1 is a good option.') 50 | parser.add_argument('--shift_sz', type=int, default=1, help='shift_sz>1 only for \'soft_shift_patch\'.') 51 | parser.add_argument('--mask_thred', type=int, default=1, help='number to decide whether a patch is masked') 52 | parser.add_argument('--overlap', type=int, default=4, help='the overlap for center mask') 53 | parser.add_argument('--bottleneck', type=int, default=512, help='neurals of fc') 54 | parser.add_argument('--gp_lambda', type=float, default=10.0, help='gradient penalty coefficient') 55 | parser.add_argument('--constrain', type=str, default='MSE', help='guidance loss type') 56 | parser.add_argument('--strength', type=float, default=1, help='the weight of guidance loss') 57 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 58 | parser.add_argument('--skip', type=int, default=0, help='Whether skip guidance loss, if skipped performance degrades with dozens of percents faster') 59 | parser.add_argument('--fuse', type=int, default=0, help='Fuse may encourage large patches shifting when using \'patch_soft_shift\'') 60 | parser.add_argument('--gan_type', type=str, default='vanilla', help='wgan_gp, ' 61 | 'lsgan, ' 62 | 'vanilla, ' 63 | 're_s_gan (Relativistic Standard GAN), ') 64 | parser.add_argument('--gan_weight', type=float, default=0.2, help='the weight of gan loss') 65 | # New added 66 | parser.add_argument('--style_weight', type=float, default=10.0, help='the weight of style loss') 67 | parser.add_argument('--content_weight', type=float, default=1.0, help='the weight of content loss') 68 | parser.add_argument('--tv_weight', type=float, default=0.0, help='the weight of tv loss, you can set a small value, such as 0.1/0.01') 69 | parser.add_argument('--offline_loading_mask', type=int, default=0, help='whether to load mask offline randomly') 70 | parser.add_argument('--mask_weight_G', type=float, default=400.0, help='the weight of mask part in ouput of G, you can try different mask_weight') 71 | parser.add_argument('--discounting', type=int, default=1, help='the loss type of mask part, whether using discounting l1 loss or normal l1') 72 | parser.add_argument('--use_spectral_norm_D', type=int, default=1, help='whether to add spectral norm to D, it helps improve results') 73 | parser.add_argument('--use_spectral_norm_G', type=int, default=0, help='whether to add spectral norm in G. 
Seems very bad when adding SN to G') 74 | parser.add_argument('--only_lastest', type=int, default=0, 75 | help='If True, it will save only the lastest weights') 76 | parser.add_argument('--add_mask2input', type=int, default=1, 77 | help='If True, It will add the mask as a fourth dimension over input space') 78 | 79 | self.initialized = True 80 | return parser 81 | 82 | def gather_options(self, options=None): 83 | # initialize parser with basic options 84 | if not self.initialized: 85 | parser = argparse.ArgumentParser( 86 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 87 | parser = self.initialize(parser) 88 | 89 | 90 | self.parser = parser 91 | if options == None: 92 | return parser.parse_args() 93 | else: 94 | return parser.parse_args(options) 95 | 96 | def print_options(self, opt): 97 | message = '' 98 | message += '----------------- Options ---------------\n' 99 | for k, v in sorted(vars(opt).items()): 100 | comment = '' 101 | default = self.parser.get_default(k) 102 | if v != default: 103 | comment = '\t[default: %s]' % str(default) 104 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 105 | message += '----------------- End -------------------' 106 | print(message) 107 | 108 | # save to the disk 109 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 110 | util.mkdirs(expr_dir) 111 | file_name = os.path.join(expr_dir, 'opt.txt') 112 | with open(file_name, 'wt') as opt_file: 113 | opt_file.write(message) 114 | opt_file.write('\n') 115 | 116 | def parse(self, options=None): 117 | 118 | opt = self.gather_options(options=options) 119 | opt.isTrain = self.isTrain # train or test 120 | 121 | # process opt.suffix 122 | if opt.suffix: 123 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 124 | opt.name = opt.name + suffix 125 | 126 | self.print_options(opt) 127 | 128 | # set gpu ids 129 | os.environ["CUDA_VISIBLE_DEVICES"]=opt.gpu_ids 130 | str_ids = opt.gpu_ids.split(',') 131 | opt.gpu_ids = [] 132 | for str_id in str_ids: 133 | id = int(str_id) 134 | if id >= 0: 135 | opt.gpu_ids.append(id) 136 | # re-order gpu ids 137 | opt.gpu_ids = [i.item() for i in torch.arange(len(opt.gpu_ids))] 138 | if len(opt.gpu_ids) > 0: 139 | torch.cuda.set_device(opt.gpu_ids[0]) 140 | 141 | self.opt = opt 142 | return self.opt 143 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | parser.add_argument('--which_epoch', type=str, default='20', help='which epoch to load? 
set to latest to use the latest cached model')
12 | parser.add_argument('--how_many', type=int, default=1000, help='how many test images to run')
13 | parser.add_argument('--testing_mask_folder', type=str, default='masks/testing_masks', help='prepared masks for testing')
14 | self.isTrain = False
15 | 
16 | return parser
17 | 
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | # These are the options specific to training
4 | 
5 | class TrainOptions(BaseOptions):
6 | def initialize(self, parser):
7 | parser = BaseOptions.initialize(self, parser)
8 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
9 | parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
10 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
11 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
12 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
13 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
14 | parser.add_argument('--print_freq', type=int, default=50, help='frequency of showing training results on console')
15 | parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
16 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
17 | parser.add_argument('--save_epoch_freq', type=int, default=2, help='frequency of saving checkpoints at the end of epochs')
18 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
19 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
20 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
21 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 22 | parser.add_argument('--niter', type=int, default=30, help='# of iter at starting learning rate') 23 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') 24 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 25 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 26 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 27 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 28 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 29 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 30 | parser.add_argument('--training_mask_folder', type=str, default='masks/training_masks', help='prepared masks for training') 31 | self.isTrain = True 32 | 33 | return parser 34 | -------------------------------------------------------------------------------- /shift_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/shift_layer.png -------------------------------------------------------------------------------- /show_map.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import torch 3 | import util.util as util 4 | from util.NonparametricShift import Modified_NonparametricShift 5 | from torch.nn import functional as F 6 | import numpy as numpy 7 | import matplotlib.pyplot as plt 8 | 9 | bz = 1 10 | c = 2 # at least 2 11 | w = 4 12 | h = 4 13 | 14 | feature_size = [bz, c, w, h] 15 | 16 | former = torch.rand(c*h*w).mul_(50).reshape(c, h, w).int().float() 17 | latter = torch.rand(c*h*w).mul_(50).reshape(c, h, w).int().float() 18 | 19 | 20 | flag = torch.zeros(h,w).byte() 21 | flag[h//4:h//2+1, h//4:h//2+1] = 1 22 | flag = flag.view(h*w) 23 | 24 | ind_lst = torch.FloatTensor(h*w, h*w).zero_() 25 | shift_offsets = [] 26 | 27 | Nonparm = Modified_NonparametricShift() 28 | cosine, latter_windows, i_2, i_3, i_1, i_4 = Nonparm.cosine_similarity(former, latter, 1, 1, flag) 29 | ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY 30 | 31 | _, indexes = torch.max(cosine, dim=1) 32 | 33 | 34 | # SET TRANSITION MATRIX 35 | mask_indexes = (flag == 1).nonzero() 36 | non_mask_indexes = (flag == 0).nonzero()[indexes] 37 | ind_lst[mask_indexes, non_mask_indexes] = 1 38 | 39 | 40 | # GET FINAL SHIFT FEATURE 41 | shift_masked_all = Nonparm._paste(latter_windows, ind_lst, i_2, i_3, i_1, i_4) 42 | 43 | print('flag') 44 | print(flag.reshape(h,w)) 45 | print('ind_lst') 46 | print(ind_lst) 47 | print('out') 48 | print(shift_masked_all) 49 | 50 | # get shift offset () 51 | shift_offset = torch.stack([non_mask_indexes.squeeze() // w, torch.fmod(non_mask_indexes.squeeze(), w)], dim=-1) 52 | 53 | 54 | shift_offsets.append(shift_offset) 55 | shift_offsets = torch.cat(shift_offsets, dim=0).float() 56 | print('shift_offset') 57 | print(shift_offset) 58 | print(shift_offset.size()) # (5*5)*2 (masked points) 59 | 60 | shift_offsets_cl = shift_offsets.clone() 61 | 62 | 63 | #visualize which pixels are attended 64 | print(flag.size()) # 256, (16*16) 65 | 
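# The block below scatters the per-pixel shift offsets back onto a full (bz, h, w, 2) map:
# entries inside the mask hold the (row, col) coordinates of the outside pixel that will
# replace them, and util.highlight_flow then renders that correspondence for inspection.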
66 | 
67 | # global and N*C*H*W
68 | # put shift_offsets_cl back to the global map.
69 | shift_offsets_map = torch.zeros(bz, h, w, 2).float()
70 | print(shift_offsets_map.size()) # (bz, h, w, 2)
71 | 
72 | # mask_indexes holds the positions of the points inside the mask region.
73 | # shift_offsets holds the positions of the outside points that will be shifted into the mask region.
74 | shift_offsets_map[:, mask_indexes.squeeze() // w, mask_indexes.squeeze() % w, :] = shift_offsets_cl.unsqueeze(0)
75 | # At this point shift_offsets_map is complete: only entries inside the mask are non-zero, meaning that pixel will be replaced by some outside pixel, and the coordinates of that outside pixel are stored in the two channels of the entry.
76 | print('global shift_offsets_map')
77 | print(shift_offsets_map)
78 | print(shift_offsets_map.size())
79 | print(shift_offsets_map.type())
80 | 
81 | flow2 = util.highlight_flow(shift_offsets_map, flag.unsqueeze(0))
82 | print('flow2 size')
83 | print(flow2.size())
84 | 
85 | # upflow = F.interpolate(flow, scale_factor=4, mode='nearest')
86 | upflow2 = F.interpolate(flow2, scale_factor=1, mode='nearest')
87 | 
88 | print('**After upsample flow2 size**')
89 | print(upflow2.size())
90 | 
91 | # upflow = upflow.squeeze().permute(1,2,0)
92 | upflow2 = upflow2.squeeze().permute(1,2,0)
93 | print(upflow2.size())
94 | 
95 | # print('flow 1')
96 | # print(upflow)
97 | # print(upflow.size())
98 | 
99 | # print('flow 2')
100 | # print(upflow2)
101 | # print(upflow2.size())
102 | plt.imshow(upflow2/255.)
103 | # # axs[0].imshow(upflow)
104 | # axs[1].imshow(upflow2)
105 | 
106 | plt.show()
107 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from options.test_options import TestOptions
4 | from data.data_loader import CreateDataLoader
5 | from models import create_model
6 | from util.visualizer import save_images
7 | from util import html
8 | 
9 | if __name__ == "__main__":
10 | opt = TestOptions().parse()
11 | opt.nThreads = 1 # test code only supports nThreads = 1
12 | opt.batchSize = 1 # test code only supports batchSize = 1
13 | opt.serial_batches = True # no shuffle
14 | opt.no_flip = True # no flip
15 | opt.display_id = -1 # no visdom display
16 | opt.loadSize = opt.fineSize # Do not scale!
17 | 
18 | data_loader = CreateDataLoader(opt)
19 | dataset = data_loader.load_data()
20 | model = create_model(opt)
21 | 
22 | # create website
23 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
24 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
25 | # test
26 | for i, data in enumerate(dataset):
27 | if i >= opt.how_many:
28 | break
29 | t1 = time.time()
30 | model.set_input(data)
31 | model.test()
32 | t2 = time.time()
33 | print(t2-t1)
34 | visuals = model.get_current_visuals()
35 | img_path = model.get_image_paths()
36 | print('process image... 
%s' % img_path) 37 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 38 | webpage.save() 39 | -------------------------------------------------------------------------------- /test_acc_shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import util.util as util 3 | from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift 4 | from torch.nn import functional as F 5 | import numpy as numpy 6 | import matplotlib.pyplot as plt 7 | 8 | bz = 2 9 | c = 3 # at least 2 10 | w = 16 11 | h = 16 12 | 13 | feature_size = [bz, c, w, h] 14 | 15 | former = torch.rand(bz*c*h*w).mul_(50).reshape(bz, c, h, w).int().float() 16 | latter = torch.rand(bz*c*h*w).mul_(50).reshape(bz, c, h, w).int().float() 17 | 18 | 19 | flag = torch.zeros(bz, h, w).byte() 20 | flag[:, h//4:h//2+1, h//4:h//2+1] = 1 21 | flag = flag.view(bz, h*w) 22 | 23 | ind_lst = torch.FloatTensor(bz, h*w, h*w).zero_() 24 | shift_offsets = [] 25 | 26 | #Nonparm = Modified_NonparametricShift() 27 | bNonparm = Batch_NonShift() 28 | cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former.clone(), latter.clone(), 1, 1, flag) 29 | print(cosine.size()) 30 | print(latter_windows.size()) 31 | ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY 32 | 33 | _, indexes = torch.max(cosine, dim=2) 34 | print('indexes dim') 35 | print(indexes.size()) 36 | 37 | 38 | # SET TRANSITION MATRIX 39 | mask_indexes = (flag == 1).nonzero() 40 | mask_indexes = mask_indexes[:,1] # remove indexes that indicates the batch dim 41 | mask_indexes = mask_indexes.view(bz, -1) 42 | 43 | # Also remove indexes of batch 44 | tmp = (flag==0).nonzero()[:,1] 45 | tmp = tmp.view(bz, -1) 46 | print('tmp size') 47 | print(tmp.size()) 48 | 49 | idx_tmp = indexes + torch.arange(indexes.size(0)).view(-1,1) * tmp.size(1) 50 | non_mask_indexes = tmp.view(-1)[idx_tmp] 51 | 52 | # Original method 53 | non_mask_indexes_2 = [] 54 | for i in range(bz): 55 | non_mask_indexes_tmp = tmp[i][indexes[i]] 56 | non_mask_indexes_2.append(non_mask_indexes_tmp) 57 | 58 | non_mask_indexes_2 = torch.stack(non_mask_indexes_2, dim=0) 59 | 60 | print('These two methods should be the same, as the error is 0!') 61 | print(torch.sum(non_mask_indexes-non_mask_indexes_2)) 62 | 63 | ind_lst2 = ind_lst.clone() 64 | for i in range(bz): 65 | ind_lst[i][mask_indexes[i], non_mask_indexes[i]] = 1 66 | 67 | print(ind_lst.sum()) 68 | print(ind_lst) 69 | 70 | for i in range(bz): 71 | for mi, nmi in zip(mask_indexes[i], non_mask_indexes[i]): 72 | print('The %d\t-th pixel in the %d-th tensor will shift to %d\t-th coordinate' %(nmi, i, mi)) 73 | print('~~~') 74 | 75 | # GET FINAL SHIFT FEATURE 76 | shift_masked_all = bNonparm._paste(latter_windows, ind_lst, i_2, i_3, i_1) 77 | print(shift_masked_all.size()) 78 | 79 | assert 1==2 80 | # print('flag') 81 | # print(flag.reshape(h,w)) 82 | # print('ind_lst') 83 | # print(ind_lst) 84 | # print('out') 85 | # print(shift_masked_all) 86 | 87 | # get shift offset () 88 | shift_offset = torch.stack([non_mask_indexes.squeeze() // w, torch.fmod(non_mask_indexes.squeeze(), w)], dim=-1) 89 | print('shift_offset') 90 | print(shift_offset) 91 | print(shift_offset.size()) 92 | 93 | shift_offsets.append(shift_offset) 94 | shift_offsets = torch.cat(shift_offsets, dim=0).float() 95 | print(shift_offsets.size()) 96 | print(shift_offsets) 97 | 98 | shift_offsets_cl = shift_offsets.clone() 99 | 100 | lt = (flag==1).nonzero()[0] 101 | rb = (flag==1).nonzero()[-1] 
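# Note: everything from here on sits below the 'assert 1==2' stop above, so it is not reached
# when the script runs. lt / rb are meant to be the first and last masked entries, i.e. the
# top-left and bottom-right corners of the rectangular mask; the lines that follow derive
# mask_h and mask_w from them so the offsets can be viewed as a (bz, 2, mask_h, mask_w) grid.
# (Since flag is batched with shape (bz, h*w), nonzero() actually returns (batch, index) pairs.)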
102 | 
103 | mask_h = rb//w+1 - lt//w
104 | mask_w = rb%w+1 - lt%w
105 | 
106 | shift_offsets = shift_offsets.view([bz] + [2] + [mask_h, mask_w]) # So it is only appropriate for a square mask.
107 | print(shift_offsets.size())
108 | print(shift_offsets)
109 | 
110 | h_add = torch.arange(0, float(h)).view([1, 1, h, 1]).float()
111 | h_add = h_add.expand(bz, 1, h, w)
112 | w_add = torch.arange(0, float(w)).view([1, 1, 1, w]).float()
113 | w_add = w_add.expand(bz, 1, h, w)
114 | 
115 | com_map = torch.cat([h_add, w_add], dim=1)
116 | print('com_map')
117 | print(com_map)
118 | 
119 | com_map_crop = com_map[:, :, lt//w:rb//w+1, lt%w:rb%w+1]
120 | print('com_map crop')
121 | print(com_map_crop)
122 | 
123 | shift_offsets = shift_offsets - com_map_crop
124 | print('final shift_offsets')
125 | print(shift_offsets)
126 | 
127 | 
128 | # to flow image
129 | flow = torch.from_numpy(util.flow_to_image(shift_offsets.permute(0,2,3,1).cpu().data.numpy()))
130 | flow = flow.permute(0,3,1,2)
131 | 
132 | # visualize which pixels are attended
133 | print(flag.size())
134 | print(shift_offsets.size())
135 | 
136 | # global and N*C*H*W
137 | # put shift_offsets_cl back to the global map.
138 | shift_offsets_map = flag.clone().view(-1)
139 | shift_offsets_map[indexes] = shift_offsets_cl.view(-1)
140 | print(shift_offsets_map)
141 | assert 1==2 # intentional hard stop while debugging
142 | flow2 = torch.from_numpy(util.highlight_flow((shift_offsets_cl).numpy()))
143 | 
144 | upflow = F.interpolate(flow, scale_factor=4, mode='nearest')
145 | upflow2 = F.interpolate(flow2, scale_factor=4, mode='nearest')
146 | 
147 | 
148 | upflow = upflow.squeeze().permute(1,2,0)
149 | upflow2 = upflow2.squeeze().permute(1,2,0)
150 | 
151 | print('flow 1')
152 | print(upflow)
153 | print(upflow.size())
154 | 
155 | print('flow 2')
156 | print(upflow2)
157 | print(upflow2.size())
158 | 
159 | fig, axs = plt.subplots(ncols=2)
160 | axs[0].imshow(upflow)
161 | axs[1].imshow(upflow2)
162 | 
163 | plt.show()
164 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import CreateDataLoader
4 | from models import create_model
5 | from util.visualizer import Visualizer
6 | 
7 | if __name__ == "__main__":
8 | opt = TrainOptions().parse()
9 | data_loader = CreateDataLoader(opt)
10 | dataset = data_loader.load_data()
11 | dataset_size = len(data_loader)
12 | print('#training images = %d' % dataset_size)
13 | 
14 | model = create_model(opt)
15 | visualizer = Visualizer(opt)
16 | 
17 | total_steps = 0
18 | 
19 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
20 | epoch_start_time = time.time()
21 | iter_data_time = time.time()
22 | epoch_iter = 0
23 | 
24 | for i, data in enumerate(dataset):
25 | iter_start_time = time.time()
26 | if total_steps % opt.print_freq == 0:
27 | t_data = iter_start_time - iter_data_time
28 | visualizer.reset()
29 | total_steps += opt.batchSize
30 | epoch_iter += opt.batchSize
31 | 
32 | model.set_input(data) # it not only sets the masked input data, but also sets the latent mask.
33 | 
34 | # Additionally, this should be set before 'optimize_parameters()'. 
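# When --show_flow is enabled, set_show_map_true() asks the model to record the shift
# offsets for this iteration so that set_flow_src() further down can expose them as a flow
# visualization; recording is slow, so it is only requested on display iterations.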
35 | if total_steps % opt.display_freq == 0: 36 | if opt.show_flow: 37 | model.set_show_map_true() 38 | 39 | model.optimize_parameters() 40 | 41 | if total_steps % opt.display_freq == 0: 42 | save_result = total_steps % opt.update_html_freq == 0 43 | if opt.show_flow: 44 | model.set_flow_src() 45 | model.set_show_map_false() 46 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 47 | 48 | if total_steps % opt.print_freq == 0: 49 | losses = model.get_current_losses() 50 | t = (time.time() - iter_start_time) / opt.batchSize 51 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 52 | if opt.display_id > 0: 53 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 54 | 55 | if total_steps % opt.save_latest_freq == 0: 56 | print('saving the latest model (epoch %d, total_steps %d)' % 57 | (epoch, total_steps)) 58 | model.save_networks('latest') 59 | 60 | iter_data_time = time.time() 61 | if epoch % opt.save_epoch_freq == 0: 62 | print('saving the model at the end of epoch %d, iters %d' % 63 | (epoch, total_steps)) 64 | model.save_networks('latest') 65 | if not opt.only_lastest: 66 | model.save_networks(epoch) 67 | 68 | print('End of epoch %d / %d \t Time Taken: %d sec' % 69 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 70 | model.update_learning_rate() 71 | -------------------------------------------------------------------------------- /util/NonparametricShift.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from time import time 7 | 8 | # These three functions only work when patch_size = 1x1 9 | class Modified_NonparametricShift(object): 10 | 11 | def _extract_patches_from_flag(self, img, patch_size, stride, flag, value): 12 | input_windows = self._unfold(img, patch_size, stride) 13 | 14 | input_windows = self._filter(input_windows, flag, value) 15 | 16 | return self._norm(input_windows) 17 | 18 | # former: content, to be replaced. 19 | # latter: style, source pixels. 20 | def cosine_similarity(self, former, latter, patch_size, stride, flag, with_former=False): 21 | former_windows = self._unfold(former, patch_size, stride) 22 | former = self._filter(former_windows, flag, 1) 23 | 24 | latter_windows, i_2, i_3, i_1 = self._unfold(latter, patch_size, stride, with_indexes=True) 25 | latter = self._filter(latter_windows, flag, 0) 26 | 27 | num = torch.einsum('ik,jk->ij', [former, latter]) 28 | norm_latter = torch.einsum("ij,ij->i", [latter, latter]) 29 | norm_former = torch.einsum("ij,ij->i", [former, former]) 30 | den = torch.sqrt(torch.einsum('i,j->ij', [norm_former, norm_latter])) 31 | if not with_former: 32 | return num / den, latter_windows, i_2, i_3, i_1 33 | else: 34 | return num / den, latter_windows, former_windows, i_2, i_3, i_1 35 | 36 | 37 | def _paste(self, input_windows, transition_matrix, i_2, i_3, i_1): 38 | ## TRANSPOSE FEATURES NEW FEATURES 39 | input_windows = torch.mm(transition_matrix, input_windows) 40 | 41 | ## RESIZE TO CORRET CONV FEATURES FORMAT 42 | input_windows = input_windows.view(i_2, i_3, i_1) 43 | input_windows = input_windows.permute(2, 0, 1).unsqueeze(0) 44 | return input_windows 45 | 46 | def _unfold(self, img, patch_size, stride, with_indexes=False): 47 | n_dim = 3 48 | assert img.dim() == n_dim, 'image must be of dimension 3.' 
49 | 50 | kH, kW = patch_size, patch_size 51 | dH, dW = stride, stride 52 | input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW) 53 | 54 | i_1, i_2, i_3, i_4, i_5 = input_windows.size() 55 | 56 | if with_indexes: 57 | input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1) 58 | return input_windows, i_2, i_3, i_1 59 | else: 60 | input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1, i_4, i_5) 61 | return input_windows 62 | 63 | def _filter(self, input_windows, flag, value): 64 | ## EXTRACT MASK OR NOT DEPENDING ON VALUE 65 | input_window = input_windows[flag == value] 66 | return input_window.view(input_window.size(0), -1) 67 | 68 | 69 | def _norm(self, input_window): 70 | # This norm is incorrect. 71 | #return torch.norm(input_window, dim=1, keepdim=True) 72 | for i in range(input_window.size(0)): 73 | input_window[i] = input_window[i]*(1/(input_window[i].norm(2)+1e-8)) 74 | 75 | return input_window 76 | 77 | class Batch_NonShift(object): 78 | 79 | def _extract_patches_from_flag(self, img, patch_size, stride, flag, value): 80 | input_windows = self._unfold(img, patch_size, stride) 81 | 82 | input_windows = self._filter(input_windows, flag, value) 83 | 84 | return self._norm(input_windows) 85 | 86 | # former: content, to be replaced. 87 | # latter: style, source pixels. 88 | def cosine_similarity(self, former, latter, patch_size, stride, flag, with_former=False): 89 | former_windows = self._unfold(former, patch_size, stride) 90 | former = self._filter(former_windows, flag, 1) 91 | 92 | latter_windows, i_2, i_3, i_1 = self._unfold(latter, patch_size, stride, with_indexes=True) 93 | latter = self._filter(latter_windows, flag, 0) 94 | 95 | num = torch.einsum('bik,bjk->bij', [former, latter]) 96 | norm_latter = torch.einsum("bij,bij->bi", [latter, latter]) 97 | norm_former = torch.einsum("bij,bij->bi", [former, former]) 98 | den = torch.sqrt(torch.einsum('bi,bj->bij', [norm_former, norm_latter])) 99 | if not with_former: 100 | return num / den, latter_windows, i_2, i_3, i_1 101 | else: 102 | return num / den, latter_windows, former_windows, i_2, i_3, i_1 103 | 104 | 105 | # delete i_4, as i_4 is 1 106 | def _paste(self, input_windows, transition_matrix, i_2, i_3, i_1): 107 | ## TRANSPOSE FEATURES NEW FEATURES 108 | bz = input_windows.size(0) 109 | input_windows = torch.bmm(transition_matrix, input_windows) 110 | 111 | ## RESIZE TO CORRET CONV FEATURES FORMAT 112 | input_windows = input_windows.view(bz, i_2, i_3, i_1) 113 | input_windows = input_windows.permute(0, 3, 1, 2) 114 | return input_windows 115 | 116 | def _unfold(self, img, patch_size, stride, with_indexes=False): 117 | n_dim = 4 118 | assert img.dim() == n_dim, 'image must be of dimension 4.' 
119 | 120 | kH, kW = patch_size, patch_size 121 | dH, dW = stride, stride 122 | input_windows = img.unfold(2, kH, dH).unfold(3, kW, dW) 123 | 124 | i_0, i_1, i_2, i_3, i_4, i_5 = input_windows.size() 125 | 126 | if with_indexes: 127 | input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1) 128 | return input_windows, i_2, i_3, i_1 129 | else: 130 | input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1, i_4, i_5) 131 | return input_windows 132 | 133 | def _filter(self, input_windows, flag, value): 134 | ## EXTRACT MASK OR NOT DEPENDING ON VALUE 135 | assert flag.dim() == 2, "flag should be batch version" 136 | input_window = input_windows[flag == value] 137 | bz = flag.size(0) 138 | return input_window.view(bz, input_window.size(0)//bz, -1) 139 | 140 | 141 | # Deprecated code 142 | class NonparametricShift(object): 143 | def buildAutoencoder(self, target_img, normalize, interpolate, nonmask_point_idx, patch_size=1, stride=1): 144 | nDim = 3 145 | assert target_img.dim() == nDim, 'target image must be of dimension 3.' 146 | C = target_img.size(0) 147 | 148 | self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available else torch.Tensor 149 | 150 | patches_all, patches_part = self._extract_patches(target_img, patch_size, stride, nonmask_point_idx) 151 | npatches_part = patches_part.size(0) 152 | npatches_all = patches_all.size(0) 153 | 154 | 155 | conv_enc_non_mask, conv_dec_non_mask = self._build(patch_size, stride, C, patches_part, npatches_part, normalize, interpolate) 156 | conv_enc_all, conv_dec_all = self._build(patch_size, stride, C, patches_all, npatches_all, normalize, interpolate) 157 | 158 | return conv_enc_all, conv_enc_non_mask, conv_dec_all, conv_dec_non_mask 159 | 160 | def _build(self, patch_size, stride, C, target_patches, npatches, normalize, interpolate): 161 | # for each patch, divide by its L2 norm. 162 | enc_patches = target_patches.clone() 163 | for i in range(npatches): 164 | enc_patches[i] = enc_patches[i]*(1/(enc_patches[i].norm(2)+1e-8)) 165 | conv_enc = nn.Conv2d(C, npatches, kernel_size=patch_size, stride=stride, bias=False) 166 | conv_enc.weight.data = enc_patches 167 | 168 | # normalize is not needed, it doesn't change the result! 169 | if normalize: 170 | raise NotImplementedError 171 | 172 | if interpolate: 173 | raise NotImplementedError 174 | 175 | conv_dec = nn.ConvTranspose2d(npatches, C, kernel_size=patch_size, stride=stride, bias=False) 176 | conv_dec.weight.data = target_patches 177 | 178 | return conv_enc, conv_dec 179 | 180 | def _extract_patches(self, img, patch_size, stride, nonmask_point_idx): 181 | n_dim = 3 182 | assert img.dim() == n_dim, 'image must be of dimension 3.' 183 | 184 | kH, kW = patch_size, patch_size 185 | dH, dW = stride, stride 186 | input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW) 187 | 188 | i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(2), input_windows.size(3), input_windows.size(4) 189 | input_windows = input_windows.permute(1,2,0,3,4).contiguous().view(i_2*i_3, i_1, i_4, i_5) 190 | 191 | patches_all = input_windows 192 | patches = input_windows.index_select(0, nonmask_point_idx) #It returns a new tensor, representing patches extracted from non-masked region! 
193 | return patches_all, patches 194 | 195 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if refresh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(refresh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/png.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import zlib 3 | 4 | def encode(buf, width, height): 5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... 
""" 6 | assert (width * height * 3 == len(buf)) 7 | bpp = 3 8 | 9 | def raw_data(): 10 | # reverse the vertical line order and add null bytes at the start 11 | row_bytes = width * bpp 12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes): 13 | yield b'\x00' 14 | yield buf[row_start:row_start + row_bytes] 15 | 16 | def chunk(tag, data): 17 | return [ 18 | struct.pack("!I", len(data)), 19 | tag, 20 | data, 21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) 22 | ] 23 | 24 | SIGNATURE = b'\x89PNG\r\n\x1a\n' 25 | COLOR_TYPE_RGB = 2 26 | COLOR_TYPE_RGBA = 6 27 | bit_depth = 8 28 | return b''.join( 29 | [ SIGNATURE ] + 30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + 31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + 32 | chunk(b'IEND', b'') 33 | ) 34 | -------------------------------------------------------------------------------- /util/poisson_blending.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse 3 | import cv2 4 | import pyamg 5 | 6 | # pre-process the mask array so that uint64 types from opencv.imread can be adapted 7 | def prepare_mask(mask): 8 | if type(mask[0][0]) is np.ndarray: 9 | result = np.ndarray((mask.shape[0], mask.shape[1]), dtype=np.uint8) 10 | for i in range(mask.shape[0]): 11 | for j in range(mask.shape[1]): 12 | if sum(mask[i][j]) > 0: 13 | result[i][j] = 1 14 | else: 15 | result[i][j] = 0 16 | mask = result 17 | return mask 18 | 19 | def blend(img_target, img_source, img_mask, offset=(0, 0)): 20 | # compute regions to be blended 21 | region_source = ( 22 | max(-offset[0], 0), 23 | max(-offset[1], 0), 24 | min(img_target.shape[0]-offset[0], img_source.shape[0]), 25 | min(img_target.shape[1]-offset[1], img_source.shape[1])) 26 | region_target = ( 27 | max(offset[0], 0), 28 | max(offset[1], 0), 29 | min(img_target.shape[0], img_source.shape[0]+offset[0]), 30 | min(img_target.shape[1], img_source.shape[1]+offset[1])) 31 | region_size = (region_source[2]-region_source[0], region_source[3]-region_source[1]) 32 | 33 | # clip and normalize mask image 34 | img_mask = img_mask[region_source[0]:region_source[2], region_source[1]:region_source[3]] 35 | img_mask = prepare_mask(img_mask) 36 | img_mask[img_mask==0] = False 37 | img_mask[img_mask!=False] = True 38 | 39 | # create coefficient matrix 40 | A = scipy.sparse.identity(np.prod(region_size), format='lil') 41 | for y in range(region_size[0]): 42 | for x in range(region_size[1]): 43 | if img_mask[y,x]: 44 | index = x+y*region_size[1] 45 | A[index, index] = 4 46 | if index+1 < np.prod(region_size): 47 | A[index, index+1] = -1 48 | if index-1 >= 0: 49 | A[index, index-1] = -1 50 | if index+region_size[1] < np.prod(region_size): 51 | A[index, index+region_size[1]] = -1 52 | if index-region_size[1] >= 0: 53 | A[index, index-region_size[1]] = -1 54 | A = A.tocsr() 55 | 56 | # create poisson matrix for b 57 | P = pyamg.gallery.poisson(img_mask.shape) 58 | 59 | # for each layer (ex. 
RGB) 60 | for num_layer in range(img_target.shape[2]): 61 | # get subimages 62 | t = img_target[region_target[0]:region_target[2], region_target[1]:region_target[3],num_layer] 63 | s = img_source[region_source[0]:region_source[2], region_source[1]:region_source[3],num_layer] 64 | t = t.flatten() 65 | s = s.flatten() 66 | 67 | # create b 68 | b = P * s 69 | for y in range(region_size[0]): 70 | for x in range(region_size[1]): 71 | if not img_mask[y,x]: 72 | index = x+y*region_size[1] 73 | b[index] = t[index] 74 | 75 | # solve Ax = b 76 | x = pyamg.solve(A,b,verb=False,tol=1e-10) 77 | 78 | # assign x to target image 79 | x = np.reshape(x, region_size) 80 | x[x>255] = 255 81 | x[x<0] = 0 82 | x = np.array(x, img_target.dtype) 83 | img_target[region_target[0]:region_target[2],region_target[1]:region_target[3],num_layer] = x 84 | 85 | return img_target -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | import sys 6 | from subprocess import Popen, PIPE 7 | from . import util, html 8 | from scipy.misc import imresize 9 | 10 | if sys.version_info[0] == 2: 11 | VisdomExceptionBase = Exception 12 | else: 13 | VisdomExceptionBase = ConnectionError 14 | 15 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 16 | """Save images to the disk. 17 | Parameters: 18 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 19 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 20 | image_path (str) -- the string is used to create image paths 21 | aspect_ratio (float) -- the aspect ratio of saved images 22 | width (int) -- the images will be resized to width x width 23 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
24 | """ 25 | image_dir = webpage.get_image_dir() 26 | short_path = ntpath.basename(image_path[0]) 27 | name = os.path.splitext(short_path)[0] 28 | 29 | webpage.add_header(name) 30 | ims, txts, links = [], [], [] 31 | 32 | for label, im_data in visuals.items(): 33 | im = util.tensor2im(im_data) 34 | image_name = '%s_%s.png' % (name, label) 35 | save_path = os.path.join(image_dir, image_name) 36 | h, w, _ = im.shape 37 | if aspect_ratio > 1.0: 38 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 39 | if aspect_ratio < 1.0: 40 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 41 | util.save_image(im, save_path) 42 | 43 | ims.append(image_name) 44 | txts.append(label) 45 | links.append(image_name) 46 | webpage.add_images(ims, txts, links, width=width) 47 | 48 | class Visualizer(): 49 | def __init__(self, opt): 50 | self.display_id = opt.display_id 51 | self.use_html = opt.isTrain and not opt.no_html 52 | self.win_size = opt.display_winsize 53 | self.name = opt.name 54 | self.port = opt.display_port 55 | self.opt = opt 56 | self.saved = False 57 | if self.display_id > 0: 58 | import visdom 59 | self.ncols = opt.display_ncols 60 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 61 | if not self.vis.check_connection(): 62 | self.create_visdom_connections() 63 | 64 | if self.use_html: 65 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 66 | self.img_dir = os.path.join(self.web_dir, 'images') 67 | print('create web directory %s...' % self.web_dir) 68 | util.mkdirs([self.web_dir, self.img_dir]) 69 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 70 | with open(self.log_name, "a") as log_file: 71 | now = time.strftime("%c") 72 | log_file.write('================ Training Loss (%s) ================\n' % now) 73 | 74 | def reset(self): 75 | self.saved = False 76 | 77 | def create_visdom_connections(self): 78 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 79 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 80 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 81 | print('Command: %s' % cmd) 82 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 83 | 84 | # |visuals|: dictionary of images to display or save 85 | def display_current_results(self, visuals, epoch, save_result): 86 | if self.display_id > 0: # show images in the browser 87 | ncols = self.ncols 88 | if ncols > 0: 89 | ncols = min(ncols, len(visuals)) 90 | h, w = next(iter(visuals.values())).shape[:2] 91 | table_css = """""" % (w, h) 95 | title = self.name 96 | label_html = '' 97 | label_html_row = '' 98 | images = [] 99 | idx = 0 100 | for label, image in visuals.items(): 101 | image = util.rm_extra_dim(image) # remove the dummy dim 102 | image_numpy = util.tensor2im(image) 103 | label_html_row += '%s' % label 104 | images.append(image_numpy.transpose([2, 0, 1])) 105 | idx += 1 106 | if idx % ncols == 0: 107 | label_html += '%s' % label_html_row 108 | label_html_row = '' 109 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 110 | while idx % ncols != 0: 111 | images.append(white_image) 112 | label_html_row += '' 113 | idx += 1 114 | if label_html_row != '': 115 | label_html += '%s' % label_html_row 116 | try: 117 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 118 | padding=2, opts=dict(title=title + ' images')) 119 | label_html = '%s
' % label_html 120 | self.vis.text(table_css + label_html, win=self.display_id + 2, 121 | opts=dict(title=title + ' labels')) 122 | except VisdomExceptionBase: 123 | self.create_visdom_connections() 124 | else: 125 | idx = 1 126 | for label, image in visuals.items(): 127 | image_numpy = util.tensor2im(image) 128 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 129 | win=self.display_id + idx) 130 | idx += 1 131 | 132 | if self.use_html and (save_result or not self.saved): # save images to a html file 133 | self.saved = True 134 | for label, image in visuals.items(): 135 | image_numpy = util.tensor2im(image) 136 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 137 | util.save_image(image_numpy, img_path) 138 | # update website 139 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 140 | for n in range(epoch, 0, -1): 141 | webpage.add_header('epoch [%d]' % n) 142 | ims, txts, links = [], [], [] 143 | 144 | for label, image_numpy in visuals.items(): 145 | image_numpy = util.tensor2im(image) 146 | img_path = 'epoch%.3d_%s.png' % (n, label) 147 | ims.append(img_path) 148 | txts.append(label) 149 | links.append(img_path) 150 | webpage.add_images(ims, txts, links, width=self.win_size) 151 | webpage.save() 152 | 153 | # losses: dictionary of error labels and values 154 | def plot_current_losses(self, epoch, counter_ratio, opt, losses): 155 | if not hasattr(self, 'plot_data'): 156 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 157 | self.plot_data['X'].append(epoch + counter_ratio) 158 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 159 | self.vis.line( 160 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 161 | Y=np.array(self.plot_data['Y']), 162 | opts={ 163 | 'title': self.name + ' loss over time', 164 | 'legend': self.plot_data['legend'], 165 | 'xlabel': 'epoch', 166 | 'ylabel': 'loss'}, 167 | win=self.display_id) 168 | 169 | # losses: same format as |losses| of plot_current_losses 170 | def print_current_losses(self, epoch, i, losses, t, t_data): 171 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 172 | for k, v in losses.items(): 173 | message += '%s: %.3f ' % (k, v) 174 | 175 | print(message) 176 | with open(self.log_name, "a") as log_file: 177 | log_file.write('%s\n' % message) 178 | 179 | 180 | --------------------------------------------------------------------------------
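Usage note: a minimal sketch of how util/poisson_blending.py can be applied to paste a completed output back into the original image under a binary mask. The image file names below are placeholders, cv2 and pyamg must be installed, and blend() returns the target image with the masked region Poisson-blended in.

import cv2
from util.poisson_blending import blend

img_target = cv2.imread('real_A.png')  # original image containing the corrupted region
img_source = cv2.imread('fake_B.png')  # network output to paste into that region
img_mask = cv2.imread('mask.png')      # non-zero where pixels come from img_source

# blend() solves the Poisson equation inside the mask and writes the result into img_target.
blended = blend(img_target, img_source, img_mask, offset=(0, 0))
cv2.imwrite('blended.png', blended)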