├── .gitignore
├── LICENSE
├── README.md
├── architecture.png
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── aligned_dataset_resized.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   ├── image_folder.py
│   └── single_dataset.py
├── download_models.sh
├── generate_masks.py
├── imgs
│   ├── compare
│   │   ├── 13_fake_B.png
│   │   ├── 13_fake_B_flip.png
│   │   ├── 13_real_A.png
│   │   ├── 13_real_B.png
│   │   ├── 18_fake_B.png
│   │   ├── 18_fake_B_flip.png
│   │   ├── 18_real_A.png
│   │   └── 18_real_B.png
│   ├── face_center
│   │   ├── 0_fake_B.png
│   │   ├── 0_real_A.png
│   │   ├── 0_real_B.png
│   │   ├── 106_fake_B.png
│   │   ├── 106_real_A.png
│   │   ├── 106_real_B.png
│   │   ├── 111_fake_B.png
│   │   ├── 111_real_A.png
│   │   ├── 111_real_B.png
│   │   ├── 11_fake_B.png
│   │   ├── 11_real_A.png
│   │   ├── 11_real_B.png
│   │   ├── 14_fake_B.png
│   │   ├── 14_real_A.png
│   │   ├── 14_real_B.png
│   │   ├── 1_fake_B.png
│   │   ├── 1_real_A.png
│   │   └── 1_real_B.png
│   ├── face_random
│   │   ├── 0_fake_B.png
│   │   ├── 0_real_A.png
│   │   ├── 0_real_B.png
│   │   ├── 1_fake_B.png
│   │   ├── 1_real_A.png
│   │   └── 1_real_B.png
│   ├── paris_center
│   │   ├── 003_im_fake_B.png
│   │   ├── 003_im_real_A.png
│   │   ├── 003_im_real_B.png
│   │   ├── 004_im_fake_B.png
│   │   ├── 004_im_real_A.png
│   │   ├── 004_im_real_B.png
│   │   ├── 048_im_fake_B.png
│   │   ├── 048_im_real_A.png
│   │   └── 048_im_real_B.png
│   └── paris_random
│       ├── 006_im_fake_B.png
│       ├── 006_im_real_A.png
│       ├── 006_im_real_B.png
│       ├── 055_im_fake_B.png
│       ├── 055_im_real_A.png
│       ├── 055_im_real_B.png
│       ├── 073_im_fake_B.png
│       ├── 073_im_real_A.png
│       └── 073_im_real_B.png
├── models
│   ├── __init__.py
│   ├── face_shift_net
│   │   ├── InnerFaceShiftTriple.py
│   │   ├── InnerFaceShiftTripleFunction.py
│   │   ├── __init__.py
│   │   └── face_shiftnet_model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── denset_net.py
│   │   ├── discrimators.py
│   │   ├── losses.py
│   │   ├── modules.py
│   │   ├── shift_unet.py
│   │   └── unet.py
│   ├── networks.py
│   ├── patch_soft_shift
│   │   ├── __init__.py
│   │   ├── innerPatchSoftShiftTriple.py
│   │   ├── innerPatchSoftShiftTripleModule.py
│   │   └── patch_soft_shiftnet_model.py
│   ├── res_patch_soft_shift
│   │   ├── __init__.py
│   │   ├── innerResPatchSoftShiftTriple.py
│   │   └── res_patch_soft_shiftnet_model.py
│   ├── res_shift_net
│   │   ├── __init__.py
│   │   ├── innerResShiftTriple.py
│   │   └── shiftnet_model.py
│   └── shift_net
│       ├── InnerCos.py
│       ├── InnerCosFunction.py
│       ├── InnerShiftTriple.py
│       ├── InnerShiftTripleFunction.py
│       ├── __init__.py
│       ├── base_model.py
│       └── shiftnet_model.py
├── notebooks
│   ├── NewModule.ipynb
│   ├── OptimizingShift.ipynb
│   └── __init__.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── shift_layer.png
├── show_map.py
├── test.py
├── test_acc_shift.py
├── train.py
└── util
    ├── NonparametricShift.py
    ├── __init__.py
    ├── html.py
    ├── png.py
    ├── poisson_blending.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.JPG
2 | *.jpg
3 | *.txt
4 | datasets/
5 | checkpoints/
6 | output/
7 | results/
8 | .vscode/
9 | log/
10 | logs/
11 | *.swp
12 | *.pth
13 | *.pyc
14 | .idea/
15 | *-checkpoint.py
16 | *.ipynb_checkpoints/
17 | masks/
18 | resized_paris/
19 | fakeB/
20 | shifted/
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Zhaoyi-Yan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # New training strategy
2 | I have released a new training strategy that helps with random-mask training by reducing color shifting, at the cost of roughly 30% extra training time. It is quite useful for face inpainting.
3 | Set `--which_model_netG='face_unet_shift_triple'`, `--model='face_shiftnet'` and `--batchSize=1` to use this strategy.
4 |
5 | See some examples below; many approaches suffer from this kind of `color shifting` when trained with random masks on face datasets.
6 |
7 |
8 | | Input | Naive Shift | Flip Shift | Ground-truth |
9 | | ---- | ---- | ---- | ---- |
10 | | ![input](imgs/compare/13_real_A.png) | ![naive shift](imgs/compare/13_fake_B.png) | ![flip shift](imgs/compare/13_fake_B_flip.png) | ![ground-truth](imgs/compare/13_real_B.png) |
11 | | ![input](imgs/compare/18_real_A.png) | ![naive shift](imgs/compare/18_fake_B.png) | ![flip shift](imgs/compare/18_fake_B_flip.png) | ![ground-truth](imgs/compare/18_real_B.png) |
12 |
47 | Note: the `face_flip` training strategy has some minor drawbacks:
48 | 1. It is not fully parallel compared with the original shift.
49 | 2. It can only be trained on the CPU or on a single GPU, and the batch size must be 1; otherwise an error occurs.
50 |
51 | If you want to overcome these drawbacks, you can optimize it by referring to the original shift. It is not difficult; however, I do not have time to do it.
52 |
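For intuition, here is a minimal, single-image sketch of the core matching idea (the repository's optimized implementation lives in `models/face_shift_net/InnerFaceShiftTripleFunction.py`; the helper name and the 1x1-patch simplification below are mine): each masked decoder location is matched by cosine similarity against both the original and the horizontally flipped encoder features, and whichever candidate matches better is shifted in.

```python
import torch
import torch.nn.functional as F

def flip_shift_sketch(decoder_feat, encoder_feat, mask):
    """Toy illustration of flip shift with 1x1 'patches' on a single (C, H, W) image.

    mask: (H, W) tensor, 1 = missing region, 0 = known region.
    """
    C, H, W = encoder_feat.shape
    dec = decoder_feat.reshape(C, -1)
    enc = encoder_feat.reshape(C, -1)
    enc_flip = torch.flip(encoder_feat, dims=[2]).reshape(C, -1)

    hole = mask.reshape(-1).bool()
    hole_flip = torch.flip(mask, dims=[1]).reshape(-1).bool()

    # Candidate sources: known positions from the original and from the flipped encoder features.
    cand = torch.cat([enc[:, ~hole], enc_flip[:, ~hole_flip]], dim=1)

    # Cosine similarity between every masked decoder location and every candidate.
    sim = F.normalize(dec[:, hole], dim=0).t() @ F.normalize(cand, dim=0)
    best = sim.argmax(dim=1)  # per masked location: best source, flipped or not

    shifted = dec.clone()
    shifted[:, hole] = cand[:, best]  # paste the chosen features into the hole
    return shifted.reshape(C, H, W)
```

In the real layer, the shifted features are concatenated as a third group of channels rather than pasted in place, and the selection is recorded in a transition matrix so that `backward` can route the gradient back to the chosen (original or flipped) source positions.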
53 | # Architecture
54 | ![architecture](architecture.png)
55 |
56 | # Shift layer
57 | ![shift layer](shift_layer.png)
58 |
59 | ## Prerequisites
60 | - Linux or Windows.
61 | - Python 2 or Python 3.
62 | - CPU or NVIDIA GPU + CUDA CuDNN.
63 | - Tested on pytorch >= **1.2**
64 |
65 | ## Getting Started
66 | ### Installation
67 | - Install PyTorch and dependencies from http://pytorch.org/
68 | - Install python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate).
69 |
70 | ```bash
71 | pip install visdom
72 | pip install dominate
73 | ```
74 | - Clone this repo:
75 | ```bash
76 | git clone https://github.com/Zhaoyi-Yan/Shift-Net_pytorch
77 | cd Shift-Net_pytorch
78 |
79 | ```
80 |
81 | # Trained models
82 | Usually, I suggest you just pull the latest code and train by following the instructions.
83 |
84 | However, for now, several models have been trained and uploaded.
85 |
86 | | Mask | Paris | CelebaHQ_256 |
87 | | ---- | ---- | ---- |
88 | | center-mask | ok | ok |
89 | | random mask(from **partial conv**)| ok | ok |
90 |
91 | For the CelebaHQ_256 dataset:
92 | I select the first 2k images of CelebaHQ_256 for testing; the rest are for training.
93 | ```
94 | python train.py --loadSize=256 --batchSize=1 --model='face_shiftnet' --name='celeb256' --which_model_netG='face_unet_shift_triple' --niter=30 --dataroot='./datasets/celeba-256/train'
95 | ```
96 | Note: **`loadSize` should be `256` for face datasets, meaning the input image is directly resized to `256x256`.**
97 |
98 | The following are some results on CelebaHQ-256 and Paris.
99 |
100 | Specifically, for training the random-mask models, we adopt the masks of **partial conv** (only masks whose masked-region ratio is 20~30% are used).
101 |
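A minimal sketch of that filtering step (the folder name is hypothetical; it assumes the masks are binary PNGs where white marks the hole, as in `generate_masks.py`):

```python
import glob
import numpy as np
from PIL import Image

kept = []
for path in glob.glob('masks/*.png'):
    hole = np.array(Image.open(path).convert('L')) > 127  # True = masked pixel
    if 0.2 <= hole.mean() <= 0.3:  # keep only masks covering 20~30% of the image
        kept.append(path)
print('%d masks fall in the 20~30%% range' % len(kept))
```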
102 |
103 | Input | Results | Ground-truth
104 |
105 | (The example images are the `*_real_A.png` (input), `*_fake_B.png` (result) and `*_real_B.png` (ground-truth) triplets in `imgs/face_center`, `imgs/face_random`, `imgs/paris_center` and `imgs/paris_random`.)
106 |
205 | For testing, please read the document carefully.
206 |
207 | A pretrained model for face center inpainting is available:
208 | ```bash
209 | bash download_models.sh
210 | ```
211 | Rename `face_center_mask.pth` to `30_net_G.pth`, and put it in the folder `./log/face_center_mask_20_30` (create it if it does not exist).
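If you prefer to do this step in Python, a minimal sketch (using the same file and folder names as above):

```python
import os
import shutil

ckpt_dir = os.path.join('log', 'face_center_mask_20_30')
os.makedirs(ckpt_dir, exist_ok=True)  # create ./log/face_center_mask_20_30 if missing
shutil.move('face_center_mask.pth', os.path.join(ckpt_dir, '30_net_G.pth'))
```

Then run: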
212 | ```bash
213 | python test.py --which_model_netG='unet_shift_triple' --model='shiftnet' --name='face_center_mask_20_30' --which_epoch=30 --dataroot='./datasets/celeba-256/test'
214 | ```
215 |
216 | For face random inpainting, the model is trained with `--which_model_netG='face_unet_shift_triple'` and `--model='face_shiftnet'`. Rename `face_flip_random.pth` to `30_net_G.pth`, and set `--which_model_netG='face_unet_shift_triple'` and `--model='face_shiftnet'` when testing.
217 |
218 | Similarly, for Paris random inpainting, rename `paris_random_mask_20_30.pth` to `30_net_G.pth` and put it in the folder `./log/paris_random_mask_20_30` (create it if it does not exist).
219 | Then test the model:
220 | ```
221 | python test.py --which_epoch=30 --name='paris_random_mask_20_30' --offline_loading_mask=1 --testing_mask_folder='masks' --dataroot='./datasets/Paris/test' --norm='instance'
222 | ```
223 | Note: your own masks should be prepared in advance in the folder given by `--testing_mask_folder`.
224 |
225 | For other models, I think you know how to evaluate them.
226 | For models trained with center mask, make sure `--mask_type='center' --offline_loading_mask=0`.
227 |
228 |
229 | ## Train models
230 | - Download your own inpainting datasets. Just put all the train/test images in some folder (e.g., ./xx/train/, ./xx/test/) and change `dataroot` in `options/base_options.py` to that path; that is all.
231 |
232 | - Train a model:
233 | Please read this paragraph carefully before running the code.
234 |
235 | Usually, we train/test the `naive shift-net` with a `center` mask.
236 |
237 | ```bash
238 | python train.py --batchSize=1 --use_spectral_norm_D=1 --which_model_netD='basic' --mask_type='center' --which_model_netG='unet_shift_triple' --model='shiftnet' --shift_sz=1 --mask_thred=1
239 | ```
240 |
241 | For some datasets, such as `CelebA`, some images are smaller than `256*256`, so you need to add `--loadSize=256` when training; **this is important**.
242 |
243 | - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. The checkpoints will be saved in `./log` by default.
244 |
245 |
246 | **DO NOT** set the batch size larger than 1 for `square` mask training; the performance degrades a lot (I don't know why...).
247 | For `random mask` (when `mask_sub_type` is NOT `rect`, or your own random masks), the training batch size can be larger than 1 without hurting performance.
248 |
249 | Random mask training (both online and offline) is also supported.
250 |
251 | Personally, I suggest loading the masks offline (similar to **partial conv**). Please refer to the **Masks** section.
252 |
253 | ## Test the model
254 |
255 | **Keep the same settings as those used during the training phase to avoid errors or bad performance.**
256 |
257 | For example, if you train `patch soft shift-net`, then the following testing command is appropriate.
258 | ```bash
259 | python test.py --fuse=1/0 --which_model_netG='patch_soft_unet_shift_triple' --model='patch_soft_shiftnet' --shift_sz=3 --mask_thred=4
260 | ```
261 | The test results will be saved to an HTML file under `./results/`.
262 |
263 |
264 | ## Masks
265 | Usually, **keep the same mask settings between training and testing.**
266 | This is because performance is highly related to the masks applied during training.
267 | The consistency of training and testing masks is crucial for good performance.
268 |
269 | | training | testing |
270 | | ---- | ---- |
271 | | center-mask | center-mask |
272 | | random-square| All |
273 | | random | All|
274 | | your own masks| your own masks|
275 |
276 | It means that if you train a model with `center-mask`, then you need to test it using `center-mask` (without even a one-pixel offset). For more info, you may refer to https://github.com/Zhaoyi-Yan/Shift-Net_pytorch/issues/125
277 | ### Training with online-generated masks
278 | We offer three types of online-generated masks: `center-mask`, `random_square` and `random_mask`.
279 | If you want to train on your own masks, similar to **partial conv**, refer to **Training on your own masks**.
280 |
281 |
282 | ### Training on your own masks
283 | Both online generation and offline loading of masks are supported for training and testing.
284 | Masks are generated online by default; set `--offline_loading_mask=1` when you want to train/test with your own prepared masks.
285 | **The prepared masks should be put in the folders given by `--training_mask_folder` and `--testing_mask_folder`.**
286 |
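For reference, a minimal sketch of writing one such mask (the single-channel-PNG, 255-marks-the-hole convention is an assumption consistent with `generate_masks.py` and with how `data/aligned_dataset.py` loads masks; the file name is arbitrary):

```python
import os
import numpy as np
from PIL import Image

os.makedirs('masks', exist_ok=True)
size, hole = 256, 128
mask = np.zeros((size, size), dtype=np.uint8)
top = left = (size - hole) // 2
mask[top:top + hole, left:left + hole] = 255  # 255 = region to be inpainted
Image.fromarray(mask).save('masks/000_mask.png')
```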
287 | ### Masks when training
288 | For each batch:
289 | - Generating online: the mask is the same for every image in the batch (to save computation).
290 | - Loading offline: a mask is loaded randomly for each image in the batch (see the sketch below).
291 |
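A minimal sketch of the offline-loading behaviour described above (assumption: it mirrors the corresponding branch in `data/aligned_dataset.py`):

```python
import random

def pick_mask_path(mask_paths, index, is_train):
    # Training: draw a random mask for every image; testing: tie the mask to the image index.
    if is_train:
        return mask_paths[random.randint(0, len(mask_paths) - 1)]
    return mask_paths[index % len(mask_paths)]
```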
292 | ## Using Switchable Norm instead of Instance/Batch Norm
293 | For fixed-mask training, `Switchable Norm` delivers better stability when batchSize > 1. **Please use switchable norm when you want to train with a large batch size; it is much more stable than instance norm or batch norm!**
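For context, a sketch of how a `--norm` flag is typically mapped to a normalization layer in pix2pix-style code (the exact helper in `models/networks.py` may differ; a `switchable` branch would return the repository's SwitchNorm module instead):

```python
import functools
import torch.nn as nn

def get_norm_layer(norm_type='instance'):
    # Map the command-line choice onto a layer constructor used when building the networks.
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
```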
294 |
295 | ### Extra variants
296 |
297 | **These 3 models are just for fun**
298 |
299 | For `res patch soft shift-net`:
300 | ```bash
301 | python train.py --batchSize=1 --which_model_netG='res_patch_soft_unet_shift_triple' --model='res_patch_soft_shiftnet' --shift_sz=3 --mask_thred=4
302 | ```
303 |
304 | For `res navie shift-net`:
305 | ```bash
306 | python train.py --which_model_netG='res_unet_shift_triple' --model='res_shiftnet' --shift_sz=1 --mask_thred=1
307 | ```
308 |
309 | For `patch soft shift-net`:
310 | ```bash
311 | python train.py --which_model_netG='patch_soft_unet_shift_triple' --model='patch_soft_shiftnet' --shift_sz=3 --mask_thred=4
312 | ```
313 |
314 | DO NOT change `shift_sz` and `mask_thred`. Otherwise, it will error out with high probability.
315 |
316 | For `patch soft shift-net` or `res patch soft shift-net`, you may set `fuse=1` to see whether it delivers better results (note: you need to keep the same setting between training and testing).
317 |
318 |
319 | ## New things that I want to add
320 | - [x] Make U-Net handle inputs of any size.
321 | - [x] Add more GANs, like spectral norm and relativistic GAN.
322 | - [x] Boost the efficiency of the shift layer.
323 | - [x] Directly resize the global_mask to get the mask in feature space.
324 | - [x] Visualization of flow. It is still experimental now.
325 | - [x] Extensions of Shift-Net. Still active in absorbing new features.
326 | - [x] Fix a bug in the guidance loss when adopting it on multiple GPUs.
327 | - [x] Add a composite L1 loss combining mask loss and non-mask loss.
328 | - [x] Finish optimizing soft-shift.
329 | - [x] Add mask variation within a batch.
330 | - [x] Support online-generating/offline-loading of prepared masks for training/testing.
331 | - [x] Add VGG loss and TV loss.
332 | - [x] Fix performance degradation when the batch size is larger than 1.
333 | - [x] Make it compatible with PyTorch 1.2.
334 | - [ ] Training with mixed types of masks.
335 | - [ ] Try AMP training.
336 | - [ ] Try a self-attention discriminator (maybe it helps).
337 |
338 | ## Some extra things I have tried
339 | **Gated Conv**: I have tried gated conv (by replacing the normal convs of U-Net with gated conv, except the innermost/outermost layers).
340 | However, I obtained no benefit. Maybe I should try replacing all layers with gated conv. I will try again when I am free.
341 |
342 | **Non-local block**: I added it, but it seems worse. Maybe I haven't added the blocks at the proper position. (It also increases the training time a lot, so I am not in favor of it.)
343 |
344 | ## Citation
345 | If you find this work useful or it gives you some insights, please cite:
346 | ```
347 | @InProceedings{Yan_2018_Shift,
348 | author = {Yan, Zhaoyi and Li, Xiaoming and Li, Mu and Zuo, Wangmeng and Shan, Shiguang},
349 | title = {Shift-Net: Image Inpainting via Deep Feature Rearrangement},
350 | booktitle = {The European Conference on Computer Vision (ECCV)},
351 | month = {September},
352 | year = {2018}
353 | }
354 | ```
355 |
356 | ## Acknowledgments
357 | We benefit a lot from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
358 |
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/architecture.png
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/data/__init__.py
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 | import os.path
3 | import random
4 | import torchvision.transforms as transforms
5 | import torch
7 | from data.base_dataset import BaseDataset
8 | from data.image_folder import make_dataset
9 | from PIL import Image
10 |
11 | class AlignedDataset(BaseDataset):
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.dir_A = opt.dataroot
15 | self.A_paths = sorted(make_dataset(self.dir_A))
16 | if self.opt.offline_loading_mask:
17 | self.mask_folder = self.opt.training_mask_folder if self.opt.isTrain else self.opt.testing_mask_folder
18 | self.mask_paths = sorted(make_dataset(self.mask_folder))
19 |
20 | assert(opt.resize_or_crop == 'resize_and_crop')
21 |
22 | transform_list = [transforms.ToTensor(),
23 | transforms.Normalize((0.5, 0.5, 0.5),
24 | (0.5, 0.5, 0.5))]
25 |
26 | self.transform = transforms.Compose(transform_list)
27 |
28 | def __getitem__(self, index):
29 | A_path = self.A_paths[index]
30 | A = Image.open(A_path).convert('RGB')
31 | w, h = A.size
32 |
33 | if w < h:
34 | ht_1 = self.opt.loadSize * h // w
35 | wd_1 = self.opt.loadSize
36 | A = A.resize((wd_1, ht_1), Image.BICUBIC)
37 | else:
38 | wd_1 = self.opt.loadSize * w // h
39 | ht_1 = self.opt.loadSize
40 | A = A.resize((wd_1, ht_1), Image.BICUBIC)
41 |
42 | A = self.transform(A)
43 | h = A.size(1)
44 | w = A.size(2)
45 | w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
46 | h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
47 |
48 | A = A[:, h_offset:h_offset + self.opt.fineSize,
49 | w_offset:w_offset + self.opt.fineSize]
50 |
51 | if (not self.opt.no_flip) and random.random() < 0.5:
52 | A = torch.flip(A, [2])
53 |
54 | # let B directly equals to A
55 | B = A.clone()
56 | A_flip = torch.flip(A, [2])
57 | B_flip = A_flip.clone()
58 |
59 | # Just zero the mask is fine if not offline_loading_mask.
60 | mask = A.clone().zero_()
61 | if self.opt.offline_loading_mask:
62 | if self.opt.isTrain:
63 | mask = Image.open(self.mask_paths[random.randint(0, len(self.mask_paths)-1)])
64 | else:
65 | mask = Image.open(self.mask_paths[index % len(self.mask_paths)])
66 | mask = mask.resize((self.opt.fineSize, self.opt.fineSize), Image.NEAREST)
67 | mask = transforms.ToTensor()(mask)
68 |
69 | return {'A': A, 'B': B, 'A_F': A_flip, 'B_F': B_flip, 'M': mask,
70 | 'A_paths': A_path}
71 |
72 | def __len__(self):
73 | return len(self.A_paths)
74 |
75 | def name(self):
76 | return 'AlignedDataset'
77 |
--------------------------------------------------------------------------------
/data/aligned_dataset_resized.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 | import os.path
3 | import random
4 | import torchvision.transforms as transforms
5 | import torch
6 | from data.base_dataset import BaseDataset
7 | from data.image_folder import make_dataset
8 | from PIL import Image
9 |
10 | class AlignedDatasetResized(BaseDataset):
11 | def initialize(self, opt):
12 | self.opt = opt
13 | self.root = opt.dataroot
14 | self.dir_A = opt.dataroot # More Flexible for users
15 |
16 | self.A_paths = sorted(make_dataset(self.dir_A))
17 |
18 | assert(opt.resize_or_crop == 'resize_and_crop')
19 |
20 | transform_list = [transforms.ToTensor(),
21 | transforms.Normalize((0.5, 0.5, 0.5),
22 | (0.5, 0.5, 0.5))]
23 |
24 | self.transform = transforms.Compose(transform_list)
25 |
26 | def __getitem__(self, index):
27 | A_path = self.A_paths[index]
28 | A = Image.open(A_path).convert('RGB')
29 |
30 | A = A.resize ((self.opt.fineSize, self.opt.fineSize), Image.BICUBIC)
31 |
32 | A = self.transform(A)
33 |
34 | #if (not self.opt.no_flip) and random.random() < 0.5:
35 | # idx = [i for i in range(A.size(2) - 1, -1, -1)] # size(2)-1, size(2)-2, ... , 0
36 | # idx = torch.LongTensor(idx)
37 | # A = A.index_select(2, idx)
38 |
39 | # let B directly equals A
40 | B = A.clone()
41 | return {'A': A, 'B': B,
42 | 'A_paths': A_path}
43 |
44 | def __len__(self):
45 | return len(self.A_paths)
46 |
47 | def name(self):
48 | return 'AlignedDatasetResized'
49 |
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 |     def load_data(self):
11 |         return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | class BaseDataset(data.Dataset):
4 | def __init__(self):
5 | super(BaseDataset, self).__init__()
6 |
7 | def name(self):
8 | return 'BaseDataset'
9 |
10 | def initialize(self, opt):
11 | pass
12 |
13 |
--------------------------------------------------------------------------------
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | if opt.dataset_mode == 'aligned':
8 | from data.aligned_dataset import AlignedDataset
9 | dataset = AlignedDataset()
10 |
11 | elif opt.dataset_mode == 'aligned_resized':
12 | from data.aligned_dataset_resized import AlignedDatasetResized
13 | dataset = AlignedDatasetResized()
14 |
15 | elif opt.dataset_mode == 'single':
16 | from data.single_dataset import SingleDataset
17 | dataset = SingleDataset()
18 |
19 | else:
20 | raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
21 |
22 | print("dataset [%s] was created" % (dataset.name()))
23 | dataset.initialize(opt)
24 | return dataset
25 |
26 |
27 | class CustomDatasetDataLoader(BaseDataLoader):
28 | def name(self):
29 | return 'CustomDatasetDataLoader'
30 |
31 | def initialize(self, opt):
32 | BaseDataLoader.initialize(self, opt)
33 | self.dataset = CreateDataset(opt)
34 |
35 | self.dataloader = torch.utils.data.DataLoader(
36 | self.dataset,
37 | batch_size=opt.batchSize,
38 | shuffle=not opt.serial_batches,
39 | num_workers=int(opt.nThreads))
40 |
41 | def load_data(self):
42 | return self
43 |
44 | def __len__(self):
45 | return min(len(self.dataset), self.opt.max_dataset_size)
46 |
47 | def __iter__(self):
48 | for i, data in enumerate(self.dataloader):
49 | if i*self.opt.batchSize >= self.opt.max_dataset_size:
50 | break
51 | yield data
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.base_dataset import BaseDataset
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 |
7 |
8 | class SingleDataset(BaseDataset):
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.root = opt.dataroot
12 | self.dir_A = os.path.join(opt.dataroot)
13 |
14 | # make_dataset returns paths of all images in one folder
15 | self.A_paths = make_dataset(self.dir_A)
16 |
17 | self.A_paths = sorted(self.A_paths)
18 |
19 | transform_list = []
20 | if opt.resize_or_crop == 'resize_and_crop':
21 |             transform_list.append(transforms.Resize(opt.loadSize))
22 |
23 | if opt.isTrain and not opt.no_flip:
24 | transform_list.append(transforms.RandomHorizontalFlip())
25 |
26 | if opt.resize_or_crop != 'no_resize':
27 | transform_list.append(transforms.RandomCrop(opt.fineSize))
28 |
29 |         # Make it between [-1, 1], because [(0-0.5)/0.5, (1-0.5)/0.5]
30 | transform_list += [transforms.ToTensor(),
31 | transforms.Normalize((0.5, 0.5, 0.5),
32 | (0.5, 0.5, 0.5))]
33 |
34 | self.transform = transforms.Compose(transform_list)
35 |
36 | def __getitem__(self, index):
37 | A_path = self.A_paths[index]
38 |
39 | A_img = Image.open(A_path).convert('RGB')
40 |
41 | A = self.transform(A_img)
42 | if self.opt.which_direction == 'BtoA':
43 | input_nc = self.opt.output_nc
44 | else:
45 | input_nc = self.opt.input_nc
46 |
47 | return {'A': A, 'A_paths': A_path}
48 |
49 | def __len__(self):
50 | return len(self.A_paths)
51 |
52 | def name(self):
53 | return 'SingleImageDataset'
54 |
--------------------------------------------------------------------------------
/download_models.sh:
--------------------------------------------------------------------------------
1 | # face model (Trained on CelebaHQ-256, the first 2k images are for testing, the rest are for training.)
2 | wget -c https://drive.google.com/open?id=1qvsWHVO9iXpEAPtwyRB25mklTmD0jgPV
3 | # face random mask model
4 | wget -c https://drive.google.com/open?id=1Pz9gkm2VYaEK3qMXnszJufsvRqbcXrjS
5 | # paris random mask model
6 | wget -c https://drive.google.com/open?id=14MzixaqYUdJNL5xGdVhSKI9jOfvGdr3M
7 | # paris center mask model
8 | wget -c https://drive.google.com/open?id=1nDkCdsqUdiEXfSjZ_P915gWeZELK0fo_
9 |
--------------------------------------------------------------------------------
/generate_masks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import numpy as np
3 | from options.train_options import TrainOptions
4 | import util.util as util
5 | import os
6 | from PIL import Image
7 | import glob
8 |
9 | mask_folder = 'masks/testing_masks'
10 | test_folder = './datasets/Paris/test'
11 | util.mkdir(mask_folder)
12 |
13 | opt = TrainOptions().parse()
14 |
15 | f = glob.glob(test_folder+'/*.png')
16 | print(f)
17 |
18 | for fl in f:
19 | mask = torch.zeros(opt.fineSize, opt.fineSize)
20 | if opt.mask_sub_type == 'fractal':
21 | assert 1==2, "It is broken now..."
22 | mask = util.create_walking_mask() # create an initial random mask.
23 |
24 | elif opt.mask_sub_type == 'rect':
25 | mask, rand_t, rand_l = util.create_rand_mask(opt)
26 |
27 | elif opt.mask_sub_type == 'island':
28 | mask = util.wrapper_gmask(opt)
29 |
30 | print('Generating mask for test image: '+os.path.basename(fl))
31 | util.save_image(mask.squeeze().numpy()*255, os.path.join(mask_folder, os.path.splitext(os.path.basename(fl))[0]+'_mask.png'))
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/imgs/compare/13_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_fake_B.png
--------------------------------------------------------------------------------
/imgs/compare/13_fake_B_flip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_fake_B_flip.png
--------------------------------------------------------------------------------
/imgs/compare/13_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_real_A.png
--------------------------------------------------------------------------------
/imgs/compare/13_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/13_real_B.png
--------------------------------------------------------------------------------
/imgs/compare/18_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_fake_B.png
--------------------------------------------------------------------------------
/imgs/compare/18_fake_B_flip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_fake_B_flip.png
--------------------------------------------------------------------------------
/imgs/compare/18_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_real_A.png
--------------------------------------------------------------------------------
/imgs/compare/18_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/compare/18_real_B.png
--------------------------------------------------------------------------------
/imgs/face_center/0_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/0_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_center/0_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/0_real_A.png
--------------------------------------------------------------------------------
/imgs/face_center/0_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/0_real_B.png
--------------------------------------------------------------------------------
/imgs/face_center/106_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/106_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_center/106_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/106_real_A.png
--------------------------------------------------------------------------------
/imgs/face_center/106_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/106_real_B.png
--------------------------------------------------------------------------------
/imgs/face_center/111_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/111_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_center/111_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/111_real_A.png
--------------------------------------------------------------------------------
/imgs/face_center/111_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/111_real_B.png
--------------------------------------------------------------------------------
/imgs/face_center/11_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/11_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_center/11_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/11_real_A.png
--------------------------------------------------------------------------------
/imgs/face_center/11_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/11_real_B.png
--------------------------------------------------------------------------------
/imgs/face_center/14_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/14_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_center/14_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/14_real_A.png
--------------------------------------------------------------------------------
/imgs/face_center/14_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/14_real_B.png
--------------------------------------------------------------------------------
/imgs/face_center/1_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/1_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_center/1_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/1_real_A.png
--------------------------------------------------------------------------------
/imgs/face_center/1_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_center/1_real_B.png
--------------------------------------------------------------------------------
/imgs/face_random/0_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/0_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_random/0_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/0_real_A.png
--------------------------------------------------------------------------------
/imgs/face_random/0_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/0_real_B.png
--------------------------------------------------------------------------------
/imgs/face_random/1_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/1_fake_B.png
--------------------------------------------------------------------------------
/imgs/face_random/1_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/1_real_A.png
--------------------------------------------------------------------------------
/imgs/face_random/1_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/face_random/1_real_B.png
--------------------------------------------------------------------------------
/imgs/paris_center/003_im_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/003_im_fake_B.png
--------------------------------------------------------------------------------
/imgs/paris_center/003_im_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/003_im_real_A.png
--------------------------------------------------------------------------------
/imgs/paris_center/003_im_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/003_im_real_B.png
--------------------------------------------------------------------------------
/imgs/paris_center/004_im_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/004_im_fake_B.png
--------------------------------------------------------------------------------
/imgs/paris_center/004_im_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/004_im_real_A.png
--------------------------------------------------------------------------------
/imgs/paris_center/004_im_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/004_im_real_B.png
--------------------------------------------------------------------------------
/imgs/paris_center/048_im_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/048_im_fake_B.png
--------------------------------------------------------------------------------
/imgs/paris_center/048_im_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/048_im_real_A.png
--------------------------------------------------------------------------------
/imgs/paris_center/048_im_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_center/048_im_real_B.png
--------------------------------------------------------------------------------
/imgs/paris_random/006_im_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/006_im_fake_B.png
--------------------------------------------------------------------------------
/imgs/paris_random/006_im_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/006_im_real_A.png
--------------------------------------------------------------------------------
/imgs/paris_random/006_im_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/006_im_real_B.png
--------------------------------------------------------------------------------
/imgs/paris_random/055_im_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/055_im_fake_B.png
--------------------------------------------------------------------------------
/imgs/paris_random/055_im_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/055_im_real_A.png
--------------------------------------------------------------------------------
/imgs/paris_random/055_im_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/055_im_real_B.png
--------------------------------------------------------------------------------
/imgs/paris_random/073_im_fake_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/073_im_fake_B.png
--------------------------------------------------------------------------------
/imgs/paris_random/073_im_real_A.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/073_im_real_A.png
--------------------------------------------------------------------------------
/imgs/paris_random/073_im_real_B.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/imgs/paris_random/073_im_real_B.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | def create_model(opt):
2 | model = None
3 | print(opt.model)
4 | if opt.model == 'shiftnet':
5 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
6 | from models.shift_net.shiftnet_model import ShiftNetModel
7 | model = ShiftNetModel()
8 |
9 | elif opt.model == 'res_shiftnet':
10 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
11 | from models.res_shift_net.shiftnet_model import ResShiftNetModel
12 | model = ResShiftNetModel()
13 |
14 | elif opt.model == 'patch_soft_shiftnet':
15 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
16 | from models.patch_soft_shift.patch_soft_shiftnet_model import PatchSoftShiftNetModel
17 | model = PatchSoftShiftNetModel()
18 |
19 | elif opt.model == 'res_patch_soft_shiftnet':
20 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
21 | from models.res_patch_soft_shift.res_patch_soft_shiftnet_model import ResPatchSoftShiftNetModel
22 | model = ResPatchSoftShiftNetModel()
23 |
24 | elif opt.model == 'face_shiftnet':
25 | assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
26 | from models.face_shift_net.face_shiftnet_model import FaceShiftNetModel
27 | model = FaceShiftNetModel()
28 | else:
29 | raise ValueError("Model [%s] not recognized." % opt.model)
30 | model.initialize(opt)
31 | print("model [%s] was created" % (model.name()))
32 | return model
33 |
--------------------------------------------------------------------------------
/models/face_shift_net/InnerFaceShiftTriple.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import util.util as util
4 | from .InnerFaceShiftTripleFunction import InnerFaceShiftTripleFunction
5 |
6 |
7 | class InnerFaceShiftTriple(nn.Module):
8 | def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3, device='gpu'):
9 | super(InnerFaceShiftTriple, self).__init__()
10 |
11 | self.shift_sz = shift_sz
12 | self.stride = stride
13 | self.mask_thred = mask_thred
14 | self.triple_weight = triple_weight
15 | self.layer_to_last = layer_to_last
16 | self.device = device
17 |         self.show_flow = False # Default False. Do not change it to True; it is computation-heavy.
18 |         self.flow_srcs = None # Indicates the flow src (pixels in the non-masked region that will shift into the masked region).
19 |
20 |
21 | def set_mask(self, mask_global):
22 | self.mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
23 |
24 | def _split_mask(self, cur_bsize):
25 | # get the visible indexes of gpus and assign correct mask to set of images
26 | cur_device = torch.cuda.current_device()
27 | self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
28 |
29 |
30 | # If mask changes, then need to set cal_fix_flag true each iteration.
31 | def forward(self, input, flip_feat=None):
32 | self.bz, self.c, self.h, self.w = input.size()
33 | if self.device != 'cpu':
34 | self._split_mask(self.bz)
35 | else:
36 | self.cur_mask = self.mask_all
37 | self.mask = self.cur_mask
38 | self.mask_flip = torch.flip(self.mask, [3])
39 |
40 | self.flag = util.cal_flag_given_mask_thred(self.mask, self.shift_sz, self.stride, self.mask_thred)
41 | self.flag_flip = util.cal_flag_given_mask_thred(self.mask_flip, self.shift_sz, self.stride, self.mask_thred)
42 |
43 | final_out = InnerFaceShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.flag_flip, self.show_flow, flip_feat)
44 | if self.show_flow:
45 | self.flow_srcs = InnerFaceShiftTripleFunction.get_flow_src()
46 |
47 | innerFeat = input.clone().narrow(1, self.c//2, self.c//2)
48 | return final_out, innerFeat
49 |
50 | def get_flow(self):
51 | return self.flow_srcs
52 |
53 | def set_flow_true(self):
54 | self.show_flow = True
55 |
56 | def set_flow_false(self):
57 | self.show_flow = False
58 |
59 | def __repr__(self):
60 | return self.__class__.__name__+ '(' \
61 | + ' ,triple_weight ' + str(self.triple_weight) + ')'
62 |
--------------------------------------------------------------------------------
/models/face_shift_net/InnerFaceShiftTripleFunction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from util.NonparametricShift import Modified_NonparametricShift
3 | import torch
4 | import util.util as util
5 | import time
6 |
7 | # This script offers a version of shift from multi-references.
8 | class InnerFaceShiftTripleFunction(torch.autograd.Function):
9 | ctx = None
10 |
11 | @staticmethod
12 | def forward(ctx, input, shift_sz, stride, triple_w, flag, flag_flip, show_flow, flip_feat=None):
13 | InnerFaceShiftTripleFunction.ctx = ctx
14 | assert input.dim() == 4, "Input Dim has to be 4"
15 | ctx.triple_w = triple_w
16 | ctx.flag = flag
17 | ctx.flag_flip = flag_flip
18 | ctx.show_flow = show_flow
19 |
20 | ctx.bz, c_real, ctx.h, ctx.w = input.size()
21 | c = c_real
22 |
23 | ctx.ind_lst = torch.Tensor(ctx.bz, ctx.h * ctx.w, ctx.h * ctx.w).zero_().to(input)
24 | ctx.ind_lst_flip = ctx.ind_lst.clone()
25 |
26 | # former and latter are all tensors
27 | former_all = input.narrow(1, 0, c//2) ### decoder feature
28 | latter_all = input.narrow(1, c//2, c//2) ### encoder feature
29 | shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all).zero_() # addition feature
30 |
31 | if not flip_feat is None:
32 | assert flip_feat.size() == former_all.size(), "flip_feat size should be equal to former size"
33 |
34 | ctx.flag = ctx.flag.to(input).long()
35 | ctx.flag_flip = ctx.flag_flip.to(input).long()
36 |
37 | # None batch version
38 | Nonparm = Modified_NonparametricShift()
39 | ctx.shift_offsets = []
40 |
41 | for idx in range(ctx.bz):
42 | flag_cur = ctx.flag[idx]
43 | flag_cur_flip = ctx.flag_flip[idx]
44 | latter = latter_all.narrow(0, idx, 1) ### encoder feature
45 | former = former_all.narrow(0, idx, 1) ### decoder feature
46 |
47 | #GET COSINE, RESHAPED LATTER AND ITS INDEXES
48 | cosine, latter_windows, i_2, i_3, i_1 = Nonparm.cosine_similarity(former.clone().squeeze(), latter.clone().squeeze(), 1, stride, flag_cur)
49 | cosine_flip, latter_windows_flip, _, _, _ = Nonparm.cosine_similarity(former.clone().squeeze(), flip_feat.clone().squeeze(), 1, stride, flag_cur_flip)
50 |
51 | # compare which is the bigger one.
52 | cosine_con = torch.cat([cosine, cosine_flip], dim=1)
53 | _, indexes_con = torch.max(cosine_con, dim=1)
54 | # then ori_larger is (non_mask_count*1),
55 | # 1:indicating the original feat is better for shift.
56 |             # 0:indicating the flipped feat is a better one.
57 | ori_larger = (indexes_con < cosine.size(1)).long().view(-1,1)
58 |
59 |
60 | ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
61 | _, indexes = torch.max(cosine, dim=1)
62 | _, indexes_flip = torch.max(cosine_flip, dim=1)
63 |
64 | # SET TRANSITION MATRIX
65 | mask_indexes = (flag_cur == 1).nonzero(as_tuple=False)
66 | non_mask_indexes = (flag_cur == 0).nonzero(as_tuple=False)[indexes]
67 | # then remove some indexes where we should select flip feat according to ori_larger
68 | mask_indexes_select_index = (mask_indexes.squeeze() * ori_larger.squeeze()).nonzero(as_tuple=False)
69 | mask_indexes_select = mask_indexes[mask_indexes_select_index].squeeze()
70 | ctx.ind_lst[idx][mask_indexes_select, non_mask_indexes] = 1
71 |
72 |
73 |
74 | non_mask_indexes_flip = (flag_cur_flip == 0).nonzero(as_tuple=False)[indexes_flip]
75 | # then remove some indexes where we should select ori feat according to 1-ori_larger
76 | mask_indexes_flip_select_index = (mask_indexes.squeeze() * (1 - ori_larger.squeeze())).nonzero(as_tuple=False)
77 | mask_indexes_flip_select = mask_indexes[mask_indexes_flip_select_index].squeeze()
78 | ctx.ind_lst_flip[idx][mask_indexes_flip_select, non_mask_indexes_flip] = 1
79 |
80 |
81 | # GET FINAL SHIFT FEATURE
82 | ori_tmp = Nonparm._paste(latter_windows, ctx.ind_lst[idx], i_2, i_3, i_1)
83 | ori_tmp_flip = Nonparm._paste(latter_windows_flip, ctx.ind_lst_flip[idx], i_2, i_3, i_1)
84 |
85 | # combine the two features by directly adding, it is ok.
86 | shift_masked_all[idx] = ori_tmp + ori_tmp_flip
87 |
88 | if ctx.show_flow:
89 | shift_offset = torch.stack([non_mask_indexes.squeeze() // ctx.w, non_mask_indexes.squeeze() % ctx.w], dim=-1)
90 | ctx.shift_offsets.append(shift_offset)
91 |
92 | if ctx.show_flow:
93 | # Note: Here we assume that each mask is the same for the same batch image.
94 | ctx.shift_offsets = torch.cat(ctx.shift_offsets, dim=0).float() # make it cudaFloatTensor
95 | # Assume mask is the same for each image in a batch.
96 | mask_nums = ctx.shift_offsets.size(0)//ctx.bz
97 | ctx.flow_srcs = torch.zeros(ctx.bz, 3, ctx.h, ctx.w).type_as(input)
98 |
99 | for idx in range(ctx.bz):
100 | shift_offset = ctx.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
101 | # reconstruct the original shift_map.
102 | shift_offsets_map = torch.zeros(1, ctx.h, ctx.w, 2).type_as(input)
103 | shift_offsets_map[:, (flag_cur == 1).nonzero(as_tuple=False).squeeze() // ctx.w, (flag_cur == 1).nonzero(as_tuple=False).squeeze() % ctx.w, :] = \
104 | shift_offset.unsqueeze(0)
105 |                 # It indicates the pixels (non-masked) that will shift to the masked region.
106 | flow_src = util.highlight_flow(shift_offsets_map, flag_cur.unsqueeze(0))
107 | ctx.flow_srcs[idx] = flow_src
108 |
109 | return torch.cat((former_all, latter_all, shift_masked_all), 1)
110 |
111 |
112 | @staticmethod
113 | def get_flow_src():
114 | return InnerFaceShiftTripleFunction.ctx.flow_srcs
115 |
116 |
117 |     # How it works: the extra grad from feat_flip enhances the grad of the second part of the input (the encoder features).
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | ind_lst = ctx.ind_lst
121 | ind_lst_flip = ctx.ind_lst_flip
122 |
123 | c = grad_output.size(1)
124 |
125 | # The former and the latter parts are kept as they are. Only the third part is shifted.
126 | grad_former_all = grad_output[:, 0:c//3, :, :]
127 | grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
128 | grad_shifted_all = grad_output[:, c*2//3:c, :, :].clone()
129 |
130 | for idx in range(ctx.bz):
131 |
132 | # C: content, pixels in masked region of the former part.
133 | # S: style, pixels in the non-masked region of the latter part.
134 | # N: the shifted feature, the new feature that will be used as the third part of features maps.
135 | # W_mat: ind_lst[idx], shift matrix.
136 | # Note: **only the masked region in N has values**.
137 |
138 | # The gradient of the shifted feature should be added back to the latter part (to be precise: S).
139 | # `ind_lst[idx][i,j] = 1` means that the i_th pixel **is replaced** by the j_th pixel in the forward pass.
140 | # When applying `S mm W_mat`, S is transferred to N.
141 | # (Pixels in the non-masked region of the latter part are shifted to the masked region of the third part.)
142 | # However, we need to transfer the gradient of the third part back to S.
143 | # This means the gradient in S will **`be replaced` (to be precise, enhanced)** by N.
144 |
145 | # So we need to transpose `W_mat`
146 | W_mat_t = ind_lst[idx].t()
147 | W_mat_t_flip = ind_lst_flip[idx].t()
148 |
149 | grad = grad_shifted_all[idx].view(c//3, -1).t()
150 |
151 | # only the gradient at the (non-masked) locations that contributed in the forward pass is kept
152 | grad_shifted_weighted = torch.mm(W_mat_t, grad)
153 | # only the gradient at the (non-masked, flipped) locations that contributed in the forward pass is kept
154 | grad_shifted_weighted_flip = torch.mm(W_mat_t_flip, grad)
155 |
156 | # Then transpose it back
157 | grad_shifted_weighted = grad_shifted_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
158 | grad_shifted_weighted_flip = grad_shifted_weighted_flip.t().contiguous().view(1, c//3, ctx.h, ctx.w)
159 |
160 | grad_shifted_weighted_all = grad_shifted_weighted + grad_shifted_weighted_flip
161 |
162 | grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_shifted_weighted_all.mul(ctx.triple_w))
163 |
164 | # Note: the input and output channel counts are both c, as there is no mask input for now.
165 | grad_input = torch.cat([grad_former_all, grad_latter_all], 1)
166 |
167 | return grad_input, None, None, None, None, None, None, None
168 |
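To make the forward/backward bookkeeping above concrete, here is a minimal, self-contained sketch (illustrative only, not part of the repository; all sizes and values are made up) of how a one-hot shift matrix copies features in the forward pass and routes gradients back through its transpose:

    import torch

    # 4 spatial locations, 2 channels; locations 0 and 1 are "masked", 2 and 3 are "non-masked".
    # W_mat[i, j] = 1 means masked location i takes its shifted feature from location j.
    W_mat = torch.tensor([[0., 0., 1., 0.],
                          [0., 0., 0., 1.],
                          [0., 0., 0., 0.],
                          [0., 0., 0., 0.]])
    feat = torch.arange(8, dtype=torch.float32).view(4, 2)   # (H*W) x C layout, as in the backward pass above

    shifted = torch.mm(W_mat, feat)                     # forward: rows 0/1 now hold copies of rows 2/3
    grad_shifted = torch.ones(4, 2)                     # pretend upstream gradient on the shifted (third) part
    grad_to_latter = torch.mm(W_mat.t(), grad_shifted)  # backward: gradient is routed back to the source rows

    print(shifted)
    print(grad_to_latter)   # only rows 2 and 3 (the copied-from locations) receive gradient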
--------------------------------------------------------------------------------
/models/face_shift_net/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/face_shift_net/face_shiftnet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | import util.util as util
4 | from models import networks
5 | from models.shift_net.base_model import BaseModel
6 | import time
7 | import torchvision.transforms as transforms
8 | import os
9 | import numpy as np
10 | from PIL import Image
11 |
12 | class FaceShiftNetModel(BaseModel):
13 | def name(self):
14 | return 'FaceShiftNetModel'
15 |
16 |
17 | def create_random_mask(self):
18 | if self.opt.mask_type == 'random':
19 | if self.opt.mask_sub_type == 'fractal':
20 | raise NotImplementedError("The 'fractal' mask_sub_type is broken somehow, please use another mask_sub_type")
21 | mask = util.create_walking_mask() # create an initial random mask.
22 |
23 | elif self.opt.mask_sub_type == 'rect':
24 | mask, rand_t, rand_l = util.create_rand_mask(self.opt)
25 | self.rand_t = rand_t
26 | self.rand_l = rand_l
27 | return mask
28 |
29 | elif self.opt.mask_sub_type == 'island':
30 | mask = util.wrapper_gmask(self.opt)
31 | return mask
32 |
33 | def initialize(self, opt):
34 | BaseModel.initialize(self, opt)
35 | self.opt = opt
36 | self.isTrain = opt.isTrain
37 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
38 | self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv']
39 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
40 | if self.opt.show_flow:
41 | self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
42 | else:
43 | self.visual_names = ['real_A', 'fake_B', 'real_B']
44 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
45 | if self.isTrain:
46 | self.model_names = ['G', 'D']
47 | else: # during test time, only load Gs
48 | self.model_names = ['G']
49 |
50 |
51 | # batchsize should be 1 for mask_global
52 | self.mask_global = torch.zeros(self.opt.batchSize, 1, \
53 | opt.fineSize, opt.fineSize, dtype=torch.bool)
54 |
55 | # Here we need to set an artificial mask_global (a center hole is fine).
56 | self.mask_global.zero_()
57 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
58 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
59 |
60 | if len(opt.gpu_ids) > 0:
61 | self.use_gpu = True
62 | self.mask_global = self.mask_global.to(self.device)
63 |
64 | # load/define networks
65 | # self.ng_innerCos_list is the guidance loss list in netG inner layers.
66 | # self.ng_shift_list is the mask list constructing shift operation.
67 | if opt.add_mask2input:
68 | input_nc = opt.input_nc + 1
69 | else:
70 | input_nc = opt.input_nc
71 |
72 | self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(input_nc, opt.output_nc, opt.ngf,
73 | opt.which_model_netG, opt, self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type, self.gpu_ids, opt.init_gain)
74 |
75 | if self.isTrain:
76 | use_sigmoid = False
77 | if opt.gan_type == 'vanilla':
78 | use_sigmoid = True # only the vanilla GAN uses the BCE criterion
79 | # don't use cGAN
80 | self.netD = networks.define_D(opt.input_nc, opt.ndf,
81 | opt.which_model_netD,
82 | opt.n_layers_D, opt.norm, use_sigmoid, opt.use_spectral_norm_D, opt.init_type, self.gpu_ids, opt.init_gain)
83 |
84 | # add style extractor
85 | self.vgg16_extractor = util.VGG16FeatureExtractor()
86 | if len(opt.gpu_ids) > 0:
87 | self.vgg16_extractor = self.vgg16_extractor.to(self.gpu_ids[0])
88 | self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor, self.gpu_ids)
89 |
90 |
91 | if self.isTrain:
92 | self.old_lr = opt.lr
93 | # define loss functions
94 | self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
95 | self.criterionL1 = torch.nn.L1Loss()
96 | self.criterionL1_mask = networks.Discounted_L1(opt).to(self.device) # make weights/buffers transfer to the correct device
97 | # VGG loss
98 | self.criterionL2_style_loss = torch.nn.MSELoss()
99 | self.criterionL2_content_loss = torch.nn.MSELoss()
100 | # TV loss
101 | self.tv_criterion = networks.TVLoss(self.opt.tv_weight)
102 |
103 | # initialize optimizers
104 | self.schedulers = []
105 | self.optimizers = []
106 | if self.opt.gan_type == 'wgan_gp':
107 | opt.beta1 = 0
108 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
109 | lr=opt.lr, betas=(opt.beta1, 0.9))
110 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
111 | lr=opt.lr, betas=(opt.beta1, 0.9))
112 | else:
113 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
114 | lr=opt.lr, betas=(opt.beta1, 0.999))
115 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
116 | lr=opt.lr, betas=(opt.beta1, 0.999))
117 | self.optimizers.append(self.optimizer_G)
118 | self.optimizers.append(self.optimizer_D)
119 | for optimizer in self.optimizers:
120 | self.schedulers.append(networks.get_scheduler(optimizer, opt))
121 |
122 | if not self.isTrain or opt.continue_train:
123 | self.load_networks(opt.which_epoch)
124 |
125 | self.print_networks(opt.verbose)
126 |
127 | def set_input(self, input):
128 | self.image_paths = input['A_paths']
129 | real_A = input['A'].to(self.device)
130 | real_B = input['B'].to(self.device)
131 | real_A_flip = input['A_F'].to(self.device)
132 | # directly load mask offline
133 | self.mask_global = input['M'].to(self.device).byte()
134 | self.mask_global = self.mask_global.narrow(1,0,1).bool()
135 |
136 | # create mask online
137 | if not self.opt.offline_loading_mask:
138 | if self.opt.mask_type == 'center':
139 | self.mask_global.zero_()
140 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
141 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
142 | self.rand_t, self.rand_l = int(self.opt.fineSize/4) + self.opt.overlap, int(self.opt.fineSize/4) + self.opt.overlap
143 | elif self.opt.mask_type == 'random':
144 | self.mask_global = self.create_random_mask().type_as(self.mask_global).view(1, *self.mask_global.size()[-3:])
145 | # As generating random masks online is computation-heavy,
146 | # we just generate one random mask for the whole batch of images.
147 | self.mask_global = self.mask_global.expand(self.opt.batchSize, *self.mask_global.size()[-3:])
148 | else:
149 | raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type)
150 | # For loading mask offline, we also need to change 'opt.mask_type' and 'opt.mask_sub_type'
151 | # to avoid forgetting such settings.
152 | else:
153 | self.opt.mask_type = 'random'
154 | self.opt.mask_sub_type = 'island'
155 |
156 | self.set_latent_mask(self.mask_global)
157 |
158 | real_A.narrow(1,0,1).masked_fill_(self.mask_global, 0.)#2*123.0/255.0 - 1.0
159 | real_A.narrow(1,1,1).masked_fill_(self.mask_global, 0.)#2*104.0/255.0 - 1.0
160 | real_A.narrow(1,2,1).masked_fill_(self.mask_global, 0.)#2*117.0/255.0 - 1.0
161 |
162 | self.mask_global_flip = torch.flip(self.mask_global.float(), [3]).bool()
163 | real_A_flip.narrow(1,0,1).masked_fill_(self.mask_global_flip, 0.)#2*123.0/255.0 - 1.0
164 | real_A_flip.narrow(1,1,1).masked_fill_(self.mask_global_flip, 0.)#2*104.0/255.0 - 1.0
165 | real_A_flip.narrow(1,2,1).masked_fill_(self.mask_global_flip, 0.)#2*117.0/255.0 - 1.0
166 |
167 |
168 | if self.opt.add_mask2input:
169 | # Append the mask as an extra (4th) channel.
170 | # Note: in the extra channel, the masked part is filled with 0 and the non-masked part with 1.
171 | real_A = torch.cat((real_A, (~self.mask_global).expand(real_A.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1)
172 | real_A_flip = torch.cat((real_A_flip, (~self.mask_global_flip).expand(real_A_flip.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1)
173 |
174 | self.real_A = real_A
175 | self.real_B = real_B
176 | self.real_A_flip = real_A_flip
177 |
178 |
179 | def set_latent_mask(self, mask_global):
180 | for ng_shift in self.ng_shift_list: # ITERATE OVER THE LIST OF ng_shift_list
181 | ng_shift.set_mask(mask_global)
182 | for ng_innerCos in self.ng_innerCos_list: # ITERATE OVER THE LIST OF ng_innerCos_list:
183 | ng_innerCos.set_mask(mask_global)
184 |
185 | def set_gt_latent(self):
186 | if not self.opt.skip:
187 | if self.opt.add_mask2input:
188 | # Append the mask as an extra (4th) channel.
189 | # Note: in the extra channel, the masked part is filled with 0 and the non-masked part with 1.
190 | real_B = torch.cat([self.real_B, (~self.mask_global).expand(self.real_B.size(0), 1, self.real_B.size(2), self.real_B.size(3)).type_as(self.real_B)], dim=1)
191 | else:
192 | real_B = self.real_B
193 | self.netG(real_B) # input ground truth
194 |
195 |
196 | def forward(self):
197 | _, flip_feat = self.netG(self.real_A_flip)
198 | # set guidance here
199 | self.set_gt_latent()
200 | self.fake_B, _ = self.netG(self.real_A, flip_feat)
201 |
202 | # Just assume one shift layer.
203 | def set_flow_src(self):
204 | self.flow_srcs = self.ng_shift_list[0].get_flow()
205 | self.flow_srcs = F.interpolate(self.flow_srcs, scale_factor=8, mode='nearest')
206 | # Just to avoid forgetting to reset show_map to False
207 | self.set_show_map_false()
208 |
209 | # Just assume one shift layer.
210 | def set_show_map_true(self):
211 | self.ng_shift_list[0].set_flow_true()
212 |
213 | def set_show_map_false(self):
214 | self.ng_shift_list[0].set_flow_false()
215 |
216 | def get_image_paths(self):
217 | return self.image_paths
218 |
219 | def backward_D(self):
220 | fake_B = self.fake_B
221 | # Real
222 | real_B = self.real_B # GroundTruth
223 |
224 | # It has been verified that, for square masks, letting D discriminate only the masked patch improves the results.
225 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
226 | # Using the cropped fake_B as the input of D.
227 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
228 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
229 |
230 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
231 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
232 |
233 | self.pred_fake = self.netD(fake_B.detach())
234 | self.pred_real = self.netD(real_B)
235 |
236 | if self.opt.gan_type == 'wgan_gp':
237 | gradient_penalty, _ = util.cal_gradient_penalty(self.netD, real_B, fake_B.detach(), self.device, constant=1, lambda_gp=self.opt.gp_lambda)
238 | self.loss_D_fake = torch.mean(self.pred_fake)
239 | self.loss_D_real = -torch.mean(self.pred_real)
240 |
241 | self.loss_D = self.loss_D_fake + self.loss_D_real + gradient_penalty
242 | else:
243 | if self.opt.gan_type in ['vanilla', 'lsgan']:
244 | self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
245 | self.loss_D_real = self.criterionGAN (self.pred_real, True)
246 |
247 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
248 |
249 | elif self.opt.gan_type == 're_s_gan':
250 | self.loss_D = self.criterionGAN(self.pred_real - self.pred_fake, True)
251 |
252 | elif self.opt.gan_type == 're_avg_gan':
253 | self.loss_D = (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), True) \
254 | + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), False)) / 2.
255 | # for `re_avg_gan`, need to retain graph of D.
256 | if self.opt.gan_type == 're_avg_gan':
257 | self.loss_D.backward(retain_graph=True)
258 | else:
259 | self.loss_D.backward()
260 |
261 |
262 | def backward_G(self):
263 | # First, G(A) should fake the discriminator
264 | fake_B = self.fake_B
265 | # It has been verified that, for square masks, letting D discriminate only the masked patch improves the results.
266 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
267 | # Using the cropped fake_B as the input of D.
268 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
269 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
270 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
271 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
272 | else:
273 | real_B = self.real_B
274 |
275 | pred_fake = self.netD(fake_B)
276 |
277 |
278 | if self.opt.gan_type == 'wgan_gp':
279 | self.loss_G_GAN = -torch.mean(pred_fake)
280 | else:
281 | if self.opt.gan_type in ['vanilla', 'lsgan']:
282 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.gan_weight
283 |
284 | elif self.opt.gan_type == 're_s_gan':
285 | pred_real = self.netD (real_B)
286 | self.loss_G_GAN = self.criterionGAN (pred_fake - pred_real, True) * self.opt.gan_weight
287 |
288 | elif self.opt.gan_type == 're_avg_gan':
289 | self.pred_real = self.netD(real_B)
290 | self.loss_G_GAN = (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), False) \
291 | + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), True)) / 2.
292 | self.loss_G_GAN *= self.opt.gan_weight
293 |
294 |
295 | # If we change the mask to 'center with random position', then we can replace loss_G_L1_m with 'Discounted L1'.
296 | self.loss_G_L1, self.loss_G_L1_m = 0, 0
297 | self.loss_G_L1 += self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
298 | # Calculate the mask reconstruction loss.
299 | # When mask_type is 'center' or 'random_with_rect', we can add an additional mask-region reconstruction loss (traditional L1).
300 | # Only when 'discounting_loss' is 1 does the mask-region reconstruction loss change to 'discounting L1' instead of normal L1.
301 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
302 | mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
303 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
304 | mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
305 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
306 | # Using Discounting L1 loss
307 | self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight_G
308 |
309 | self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN
310 |
311 | # Then, add TV loss
312 | self.loss_tv = self.tv_criterion(self.fake_B*self.mask_global.float())
313 |
314 | # Finally, add the style and content (perceptual) losses
315 | vgg_ft_fakeB = self.vgg16_extractor(fake_B)
316 | vgg_ft_realB = self.vgg16_extractor(real_B)
317 | self.loss_style = 0
318 | self.loss_content = 0
319 |
320 | for i in range(3):
321 | self.loss_style += self.criterionL2_style_loss(util.gram_matrix(vgg_ft_fakeB[i]), util.gram_matrix(vgg_ft_realB[i]))
322 | self.loss_content += self.criterionL2_content_loss(vgg_ft_fakeB[i], vgg_ft_realB[i])
323 |
324 | self.loss_style *= self.opt.style_weight
325 | self.loss_content *= self.opt.content_weight
326 |
327 | self.loss_G += (self.loss_style + self.loss_content + self.loss_tv)
328 |
329 | self.loss_G.backward()
330 |
331 | def optimize_parameters(self):
332 | self.forward()
333 | # update D
334 | self.set_requires_grad(self.netD, True)
335 | self.optimizer_D.zero_grad()
336 | self.backward_D()
337 | self.optimizer_D.step()
338 |
339 | # update G
340 | self.set_requires_grad(self.netD, False)
341 | self.optimizer_G.zero_grad()
342 | self.backward_G()
343 | self.optimizer_G.step()
344 |
345 |
346 |
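A small worked example (illustrative only; fineSize=256 and overlap=4 are assumed values) of the center-mask arithmetic used in initialize() and set_input() above:

    fineSize, overlap = 256, 4                       # assumed option values for illustration
    top = fineSize // 4 + overlap                    # 68
    bot = fineSize // 2 + fineSize // 4 - overlap    # 188
    print(top, bot, bot - top)                       # 68 188 120
    # The hole is a 120x120 square; note that fineSize//2 - 2*overlap == 120,
    # which is exactly the crop size used when feeding the masked patch to D in backward_D/backward_G.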
--------------------------------------------------------------------------------
/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .discrimators import *
2 | from .losses import *
3 | from .modules import *
4 | from .shift_unet import *
5 | from .unet import *
--------------------------------------------------------------------------------
/models/modules/denset_net.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .modules import *
6 | import torch.utils.model_zoo as model_zoo
7 | from collections import OrderedDict
8 |
9 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
10 |
11 |
12 | model_urls = {
13 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
14 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
15 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
16 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
17 | }
18 |
19 |
20 | def densenet121(pretrained=False, use_spectral_norm=True, **kwargs):
21 | r"""Densenet-121 model from
22 | `"Densely Connected Convolutional Networks" `_
23 |
24 | Args:
25 | pretrained (bool): If True, returns a model pre-trained on ImageNet
26 | """
27 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), use_spectral_norm=use_spectral_norm,
28 | **kwargs)
29 | if pretrained:
30 | # '.'s are no longer allowed in module names, but previous _DenseLayer
31 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
32 | # They are also in the checkpoints in model_urls. This pattern is used
33 | # to find such keys.
34 | pattern = re.compile(
35 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
36 | state_dict = model_zoo.load_url(model_urls['densenet121'])
37 | for key in list(state_dict.keys()):
38 | res = pattern.match(key)
39 | if res:
40 | new_key = res.group(1) + res.group(2)
41 | state_dict[new_key] = state_dict[key]
42 | del state_dict[key]
43 | model.load_state_dict(state_dict, strict=False)
44 | return model
45 |
46 |
47 |
48 | def densenet169(pretrained=False, **kwargs):
49 | r"""Densenet-169 model from
50 | `"Densely Connected Convolutional Networks" `_
51 |
52 | Args:
53 | pretrained (bool): If True, returns a model pre-trained on ImageNet
54 | """
55 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
56 | **kwargs)
57 | if pretrained:
58 | # '.'s are no longer allowed in module names, but previous _DenseLayer
59 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
60 | # They are also in the checkpoints in model_urls. This pattern is used
61 | # to find such keys.
62 | pattern = re.compile(
63 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
64 | state_dict = model_zoo.load_url(model_urls['densenet169'])
65 | for key in list(state_dict.keys()):
66 | res = pattern.match(key)
67 | if res:
68 | new_key = res.group(1) + res.group(2)
69 | state_dict[new_key] = state_dict[key]
70 | del state_dict[key]
71 | model.load_state_dict(state_dict)
72 | return model
73 |
74 |
75 |
76 | def densenet201(pretrained=False, **kwargs):
77 | r"""Densenet-201 model from
78 | `"Densely Connected Convolutional Networks" `_
79 |
80 | Args:
81 | pretrained (bool): If True, returns a model pre-trained on ImageNet
82 | """
83 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
84 | **kwargs)
85 | if pretrained:
86 | # '.'s are no longer allowed in module names, but previous _DenseLayer
87 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
88 | # They are also in the checkpoints in model_urls. This pattern is used
89 | # to find such keys.
90 | pattern = re.compile(
91 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
92 | state_dict = model_zoo.load_url(model_urls['densenet201'])
93 | for key in list(state_dict.keys()):
94 | res = pattern.match(key)
95 | if res:
96 | new_key = res.group(1) + res.group(2)
97 | state_dict[new_key] = state_dict[key]
98 | del state_dict[key]
99 | model.load_state_dict(state_dict)
100 | return model
101 |
102 |
103 |
104 | def densenet161(pretrained=False, **kwargs):
105 | r"""Densenet-161 model from
106 | `"Densely Connected Convolutional Networks" `_
107 |
108 | Args:
109 | pretrained (bool): If True, returns a model pre-trained on ImageNet
110 | """
111 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
112 | **kwargs)
113 | if pretrained:
114 | # '.'s are no longer allowed in module names, but previous _DenseLayer
115 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
116 | # They are also in the checkpoints in model_urls. This pattern is used
117 | # to find such keys.
118 | pattern = re.compile(
119 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
120 | state_dict = model_zoo.load_url(model_urls['densenet161'])
121 | for key in list(state_dict.keys()):
122 | res = pattern.match(key)
123 | if res:
124 | new_key = res.group(1) + res.group(2)
125 | state_dict[new_key] = state_dict[key]
126 | del state_dict[key]
127 | model.load_state_dict(state_dict)
128 | return model
129 |
130 |
131 |
132 | class _DenseLayer(nn.Sequential):
133 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, use_spectral_norm):
134 | super(_DenseLayer, self).__init__()
135 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
136 | self.add_module('relu1', nn.ReLU()),
137 | self.add_module('conv1', spectral_norm(nn.Conv2d(num_input_features, bn_size *
138 | growth_rate, kernel_size=1, stride=1, bias=False), use_spectral_norm)),
139 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
140 | self.add_module('relu2', nn.ReLU()),
141 | self.add_module('conv2', spectral_norm(nn.Conv2d(bn_size * growth_rate, growth_rate,
142 | kernel_size=3, stride=1, padding=1, bias=False), use_spectral_norm)),
143 | self.drop_rate = drop_rate
144 |
145 | def forward(self, x):
146 | new_features = super(_DenseLayer, self).forward(x)
147 | if self.drop_rate > 0:
148 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
149 | return torch.cat([x, new_features], 1)
150 |
151 |
152 | class _DenseBlock(nn.Sequential):
153 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, use_spectral_norm):
154 | super(_DenseBlock, self).__init__()
155 | for i in range(num_layers):
156 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, use_spectral_norm)
157 | self.add_module('denselayer%d' % (i + 1), layer)
158 |
159 |
160 | class _Transition(nn.Sequential):
161 | def __init__(self, num_input_features, num_output_features, use_spectral_norm):
162 | super(_Transition, self).__init__()
163 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
164 | self.add_module('relu', nn.ReLU())
165 | self.add_module('conv', spectral_norm(nn.Conv2d(num_input_features, num_output_features,
166 | kernel_size=1, stride=1, bias=False), use_spectral_norm))
167 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
168 |
169 |
170 | class DenseNet(nn.Module):
171 | r"""Densenet-BC model class, based on
172 | `"Densely Connected Convolutional Networks" `_
173 |
174 | Args:
175 | growth_rate (int) - how many filters to add each layer (`k` in paper)
176 | block_config (list of 4 ints) - how many layers in each pooling block
177 | num_init_features (int) - the number of filters to learn in the first convolution layer
178 | bn_size (int) - multiplicative factor for number of bottle neck layers
179 | (i.e. bn_size * k features in the bottleneck layer)
180 | drop_rate (float) - dropout rate after each dense layer
181 | num_classes (int) - number of classification classes
182 | """
183 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), use_spectral_norm=True,
184 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
185 |
186 | super(DenseNet, self).__init__()
187 |
188 | # First convolution
189 | self.features = nn.Sequential(OrderedDict([
190 | ('conv0', spectral_norm(nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False), use_spectral_norm)),
191 | ('norm0', nn.BatchNorm2d(num_init_features)),
192 | ('relu0', nn.ReLU()),
193 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
194 | ]))
195 |
196 | # Each denseblock
197 | num_features = num_init_features
198 | for i, num_layers in enumerate(block_config):
199 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
200 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, use_spectral_norm=use_spectral_norm)
201 | self.features.add_module('denseblock%d' % (i + 1), block)
202 | num_features = num_features + num_layers * growth_rate
203 | if i != len(block_config) - 1:
204 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, use_spectral_norm=use_spectral_norm)
205 | self.features.add_module('transition%d' % (i + 1), trans)
206 | num_features = num_features // 2
207 |
208 | # Final batch norm
209 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
210 |
211 | self.conv_last = spectral_norm(nn.Conv2d(num_features, 256, kernel_size=3), use_spectral_norm)
212 |
213 | # Linear layer
214 | # self.classifier = nn.Linear(num_features, num_classes)
215 |
216 | # Official init from torch repo.
217 | for m in self.modules():
218 | if isinstance(m, nn.Conv2d):
219 | nn.init.kaiming_normal_(m.weight.data)
220 | elif isinstance(m, nn.BatchNorm2d):
221 | m.weight.data.fill_(1)
222 | m.bias.data.zero_()
223 | elif isinstance(m, nn.Linear):
224 | m.bias.data.zero_()
225 |
226 | def forward(self, x):
227 | features = self.features(x)
228 | features = self.conv_last(features)
229 | return features
230 |
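A standalone sketch (illustrative only; the example key is assumed, not taken from a real checkpoint) of what the key-renaming pattern used in the densenet* constructors above does to an old-style state-dict key:

    import re

    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    old_key = 'features.denseblock1.denselayer1.norm.1.weight'   # assumed example key
    m = pattern.match(old_key)
    new_key = m.group(1) + m.group(2)
    print(new_key)   # features.denseblock1.denselayer1.norm1.weight  (the dot before '1' is dropped)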
--------------------------------------------------------------------------------
/models/modules/discrimators.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 | from .denset_net import *
4 |
5 | from .modules import *
6 | ################################### This is for D ###################################
7 | # Defines the PatchGAN discriminator with the specified arguments.
8 | class NLayerDiscriminator(nn.Module):
9 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True):
10 | super(NLayerDiscriminator, self).__init__()
11 | if type(norm_layer) == functools.partial:
12 | use_bias = norm_layer.func == nn.InstanceNorm2d
13 | else:
14 | use_bias = norm_layer == nn.InstanceNorm2d
15 |
16 | kw = 4
17 | padw = 1
18 | sequence = [
19 | spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), use_spectral_norm),
20 | nn.LeakyReLU(0.2, True)
21 | ]
22 |
23 | nf_mult = 1
24 | nf_mult_prev = 1
25 | for n in range(1, n_layers):
26 | nf_mult_prev = nf_mult
27 | nf_mult = min(2**n, 8)
28 | sequence += [
29 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
30 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), use_spectral_norm),
31 | norm_layer(ndf * nf_mult),
32 | nn.LeakyReLU(0.2, True)
33 | ]
34 |
35 | nf_mult_prev = nf_mult
36 | nf_mult = min(2**n_layers, 8)
37 | sequence += [
38 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
39 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), use_spectral_norm),
40 | norm_layer(ndf * nf_mult),
41 | nn.LeakyReLU(0.2, True)
42 | ]
43 | sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), use_spectral_norm)]
44 |
45 | if use_sigmoid:
46 | sequence += [nn.Sigmoid()]
47 |
48 | self.model = nn.Sequential(*sequence)
49 |
50 | def forward(self, input):
51 | return self.model(input)
52 |
53 |
54 | # Defines a DenseNet-inspired discriminator (should improve its ability to learn stronger representations)
55 | class DenseNetDiscrimator(nn.Module):
56 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True):
57 | super(DenseNetDiscrimator, self).__init__()
58 | self.model = densenet121(pretrained=True, use_spectral_norm=use_spectral_norm)
59 | self.use_sigmoid = use_sigmoid
60 | if self.use_sigmoid:
61 | self.sigmoid = nn.Sigmoid()
62 |
63 | def forward(self, input):
64 | if self.use_sigmoid:
65 | return self.sigmoid(self.model(input))
66 | else:
67 | return self.model(input)
68 |
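A minimal usage sketch (illustrative only; it assumes the NLayerDiscriminator class above is in scope) of the patch-level output produced by the default n_layers=3 discriminator on a 256x256 input:

    import torch

    netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    out = netD(torch.randn(1, 3, 256, 256))
    print(out.shape)   # expected: torch.Size([1, 1, 30, 30]) -- one real/fake score per overlapping patch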
--------------------------------------------------------------------------------
/models/modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
6 | # When LSGAN is used, it is basically the same as MSELoss,
7 | # but it abstracts away the need to create the target label tensor
8 | # that has the same size as the input
9 | class GANLoss(nn.Module):
10 | def __init__(self, gan_type='wgan_gp', target_real_label=1.0, target_fake_label=0.0):
11 | super(GANLoss, self).__init__()
12 | self.register_buffer('real_label', torch.tensor(target_real_label))
13 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
14 | self.gan_type = gan_type
15 | if gan_type == 'wgan_gp':
16 | self.loss = nn.MSELoss()
17 | elif gan_type == 'lsgan':
18 | self.loss = nn.MSELoss()
19 | elif gan_type == 'vanilla':
20 | self.loss = nn.BCELoss()
21 | #######################################################################
22 | ### Relativistic GAN - https://github.com/AlexiaJM/RelativisticGAN ###
23 | #######################################################################
24 | # When Using `BCEWithLogitsLoss()`, remove the sigmoid layer in D.
25 | elif gan_type == 're_s_gan':
26 | self.loss = nn.BCEWithLogitsLoss()
27 | elif gan_type == 're_avg_gan':
28 | self.loss = nn.BCEWithLogitsLoss()
29 | else:
30 | raise ValueError("GAN type [%s] not recognized." % gan_type)
31 |
32 | def get_target_tensor(self, prediction, target_is_real):
33 | if target_is_real:
34 | target_tensor = self.real_label
35 | else:
36 | target_tensor = self.fake_label
37 | return target_tensor.expand_as(prediction)
38 |
39 | def __call__(self, prediction, target_is_real):
40 | if self.gan_type == 'wgan_gp':
41 | if target_is_real:
42 | loss = -prediction.mean()
43 | else:
44 | loss = prediction.mean()
45 | else:
46 | target_tensor = self.get_target_tensor(prediction, target_is_real)
47 | loss = self.loss(prediction, target_tensor)
48 | return loss
49 |
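A minimal usage sketch of GANLoss (illustrative only; it assumes the GANLoss class above is in scope, and the prediction shape merely mimics a PatchGAN output):

    import torch

    crit = GANLoss(gan_type='lsgan')
    pred = torch.randn(4, 1, 30, 30)   # made-up patch scores, just for illustration
    loss_real = crit(pred, True)       # MSE against an all-ones target of the same shape
    loss_fake = crit(pred, False)      # MSE against an all-zeros target
    print(loss_real.item(), loss_fake.item())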
50 | ################# Discounting loss #########################
51 | ######################################################
52 | class Discounted_L1(nn.Module):
53 | def __init__(self, opt):
54 | super(Discounted_L1, self).__init__()
55 | # Register discounting template as a buffer
56 | self.register_buffer('discounting_mask', torch.tensor(spatial_discounting_mask(opt.fineSize//2 - opt.overlap * 2, opt.fineSize//2 - opt.overlap * 2, 0.9, opt.discounting)))
57 | self.L1 = nn.L1Loss()
58 |
59 | def forward(self, input, target):
60 | self._assert_no_grad(target)
61 | input_tmp = input * self.discounting_mask
62 | target_tmp = target * self.discounting_mask
63 | return self.L1(input_tmp, target_tmp)
64 |
65 |
66 | def _assert_no_grad(self, variable):
67 | assert not variable.requires_grad, \
68 | "nn criterions don't compute the gradient w.r.t. targets - please " \
69 | "mark these variables as volatile or not requiring gradients"
70 |
71 |
72 | def spatial_discounting_mask(mask_width, mask_height, discounting_gamma, discounting=1):
73 | """Generate spatial discounting mask constant.
74 | Spatial discounting mask is first introduced in publication:
75 | Generative Image Inpainting with Contextual Attention, Yu et al.
76 | Returns:
77 | np.ndarray: spatial discounting mask
78 | """
79 | gamma = discounting_gamma
80 | shape = [1, 1, mask_width, mask_height]
81 | if discounting:
82 | print('Use spatial discounting l1 loss.')
83 | mask_values = np.ones((mask_width, mask_height), dtype='float32')
84 | for i in range(mask_width):
85 | for j in range(mask_height):
86 | mask_values[i, j] = max(
87 | gamma**min(i, mask_width-i),
88 | gamma**min(j, mask_height-j))
89 | mask_values = np.expand_dims(mask_values, 0)
90 | mask_values = np.expand_dims(mask_values, 1)
91 | mask_values = mask_values
92 | else:
93 | mask_values = np.ones(shape, dtype='float32')
94 |
95 | return mask_values
96 |
97 | class TVLoss(nn.Module):
98 | def __init__(self, tv_loss_weight=1):
99 | super(TVLoss, self).__init__()
100 | self.tv_loss_weight = tv_loss_weight
101 |
102 | def forward(self, x):
103 | bz, _, h, w = x.size()
104 | count_h = self._tensor_size(x[:, :, 1:, :])
105 | count_w = self._tensor_size(x[:, :, :, 1:])
106 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h - 1, :]), 2).sum()
107 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w - 1]), 2).sum()
108 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / bz
109 |
110 | @staticmethod
111 | def _tensor_size(t):
112 | return t.size(1) * t.size(2) * t.size(3)
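A small sketch (illustrative only; it assumes the spatial_discounting_mask function above is in scope) showing that the discounting weights are close to 1 near the hole border and smallest at the hole center, so pixels far from any known context contribute less to the L1 loss:

    import numpy as np

    m = spatial_discounting_mask(6, 6, 0.9, discounting=1)   # assumed example sizes and gamma
    print(m.shape)                # (1, 1, 6, 6)
    print(np.round(m[0, 0], 3))   # ~1.0 along the border rows/columns, smallest in the middle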
--------------------------------------------------------------------------------
/models/modules/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import Parameter
6 |
7 |
8 | class Self_Attn (nn.Module):
9 | """ Self attention Layer"""
10 |
11 | '''
12 | https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
13 | '''
14 |
15 | def __init__(self, in_dim, activation, with_attention=False):
16 | super (Self_Attn, self).__init__ ()
17 | self.chanel_in = in_dim
18 | self.activation = activation
19 | self.with_attention = with_attention
20 |
21 | self.query_conv = nn.Conv2d (in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
22 | self.key_conv = nn.Conv2d (in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
23 | self.value_conv = nn.Conv2d (in_channels=in_dim, out_channels=in_dim, kernel_size=1)
24 | self.gamma = nn.Parameter (torch.zeros (1))
25 |
26 | self.softmax = nn.Softmax (dim=-1) #
27 |
28 | def forward(self, x):
29 | """
30 | inputs :
31 | x : input feature maps( B X C X W X H)
32 | returns :
33 | out : self attention value + input feature
34 | attention: B X N X N (N is Width*Height)
35 | """
36 | m_batchsize, C, width, height = x.size ()
37 | proj_query = self.query_conv (x).view (m_batchsize, -1, width * height).permute (0, 2, 1) # B X CX(N)
38 | proj_key = self.key_conv (x).view (m_batchsize, -1, width * height) # B X C x (*W*H)
39 | energy = torch.bmm (proj_query, proj_key) # transpose check
40 | attention = self.softmax (energy) # BX (N) X (N)
41 | proj_value = self.value_conv (x).view (m_batchsize, -1, width * height) # B X C X N
42 |
43 | out = torch.bmm (proj_value, attention.permute (0, 2, 1))
44 | out = out.view (m_batchsize, C, width, height)
45 |
46 | out = self.gamma * out + x
47 |
48 | if self.with_attention:
49 | return out, attention
50 | else:
51 | return out
52 |
53 | def l2normalize(v, eps=1e-12):
54 | return v / (v.norm() + eps)
55 |
56 | def spectral_norm(module, mode=True):
57 | if mode:
58 | return nn.utils.spectral_norm(module)
59 |
60 | return module
61 |
62 |
63 | class SwitchNorm2d(nn.Module):
64 | def __init__(self, num_features, eps=1e-5, momentum=0.9, using_moving_average=True, using_bn=True,
65 | last_gamma=False):
66 | super(SwitchNorm2d, self).__init__()
67 | self.eps = eps
68 | self.momentum = momentum
69 | self.using_moving_average = using_moving_average
70 | self.using_bn = using_bn
71 | self.last_gamma = last_gamma
72 | self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
73 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
74 | if self.using_bn:
75 | self.mean_weight = nn.Parameter(torch.ones(3))
76 | self.var_weight = nn.Parameter(torch.ones(3))
77 | else:
78 | self.mean_weight = nn.Parameter(torch.ones(2))
79 | self.var_weight = nn.Parameter(torch.ones(2))
80 | if self.using_bn:
81 | self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
82 | self.register_buffer('running_var', torch.zeros(1, num_features, 1))
83 |
84 | self.reset_parameters()
85 |
86 | def reset_parameters(self):
87 | if self.using_bn:
88 | self.running_mean.zero_()
89 | self.running_var.zero_()
90 | if self.last_gamma:
91 | self.weight.data.fill_(0)
92 | else:
93 | self.weight.data.fill_(1)
94 | self.bias.data.zero_()
95 |
96 | def _check_input_dim(self, input):
97 | if input.dim() != 4:
98 | raise ValueError('expected 4D input (got {}D input)'
99 | .format(input.dim()))
100 |
101 | def forward(self, x):
102 | self._check_input_dim(x)
103 | N, C, H, W = x.size()
104 | x = x.view(N, C, -1)
105 | mean_in = x.mean(-1, keepdim=True)
106 | var_in = x.var(-1, keepdim=True)
107 |
108 | mean_ln = mean_in.mean(1, keepdim=True)
109 | temp = var_in + mean_in ** 2
110 | var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2
111 |
112 | if self.using_bn:
113 | if self.training:
114 | mean_bn = mean_in.mean(0, keepdim=True)
115 | var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
116 | if self.using_moving_average:
117 | self.running_mean.mul_(self.momentum)
118 | self.running_mean.add_((1 - self.momentum) * mean_bn.data)
119 | self.running_var.mul_(self.momentum)
120 | self.running_var.add_((1 - self.momentum) * var_bn.data)
121 | else:
122 | self.running_mean.add_(mean_bn.data)
123 | self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
124 | else:
125 | mean_bn = torch.autograd.Variable(self.running_mean)
126 | var_bn = torch.autograd.Variable(self.running_var)
127 |
128 | softmax = nn.Softmax(0)
129 | mean_weight = softmax(self.mean_weight)
130 | var_weight = softmax(self.var_weight)
131 |
132 | if self.using_bn:
133 | mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
134 | var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
135 | else:
136 | mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
137 | var = var_weight[0] * var_in + var_weight[1] * var_ln
138 |
139 | x = (x-mean) / (var+self.eps).sqrt()
140 | x = x.view(N, C, H, W)
141 | return x * self.weight + self.bias
142 |
143 |
144 | class PartialConv(nn.Module):
145 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
146 | padding=0, dilation=1, groups=1, bias=True):
147 | super(PartialConv, self).__init__()
148 | self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
149 | stride, padding, dilation, groups, bias)
150 | self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
151 | stride, padding, dilation, groups, False)
152 |
153 | #self.input_conv.apply(weights_init('kaiming'))
154 |
155 | torch.nn.init.constant_(self.mask_conv.weight, 1.0)
156 |
157 | # mask is not updated
158 | for param in self.mask_conv.parameters():
159 | param.requires_grad = False
160 |
161 | def forward(self, input, mask):
162 |
163 | output = self.input_conv(input * mask)
164 | if self.input_conv.bias is not None:
165 | output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
166 | output)
167 | else:
168 | output_bias = torch.zeros_like(output)
169 |
170 | with torch.no_grad():
171 | output_mask = self.mask_conv(mask)
172 |
173 | no_update_holes = output_mask == 0
174 | mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
175 |
176 | output_pre = (output - output_bias) / mask_sum + output_bias
177 | output = output_pre.masked_fill_(no_update_holes, 0.0)
178 |
179 | new_mask = torch.ones_like(output)
180 | new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
181 |
182 | return output, new_mask
183 |
184 |
185 | class ResnetBlock(nn.Module):
186 | def __init__(self, dim, padding_type, norm_layer, use_bias):
187 | super(ResnetBlock, self).__init__()
188 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_bias)
189 |
190 | def build_conv_block(self, dim, padding_type, norm_layer, use_bias):
191 | conv_block = []
192 | p = 0
193 | if padding_type == 'reflect':
194 | conv_block += [nn.ReflectionPad2d(1)]
195 | elif padding_type == 'replicate':
196 | conv_block += [nn.ReplicationPad2d(1)]
197 | elif padding_type == 'zero':
198 | p = 1
199 | else:
200 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
201 |
202 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
203 | norm_layer(dim),
204 | nn.ReLU(True)]
205 |
206 | p = 0
207 | if padding_type == 'reflect':
208 | conv_block += [nn.ReflectionPad2d(1)]
209 | elif padding_type == 'replicate':
210 | conv_block += [nn.ReplicationPad2d(1)]
211 | elif padding_type == 'zero':
212 | p = 1
213 | else:
214 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
215 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
216 | norm_layer(dim)]
217 |
218 | return nn.Sequential(*conv_block)
219 |
220 | def forward(self, x):
221 | out = x + self.conv_block(x)
222 | return out
223 |
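A minimal shape check (illustrative only; it assumes the Self_Attn class above is in scope) showing that the attention map is N x N over the flattened spatial positions:

    import torch

    attn = Self_Attn(in_dim=64, activation=None, with_attention=True)
    x = torch.randn(2, 64, 16, 16)
    out, attention = attn(x)
    print(out.shape)         # torch.Size([2, 64, 16, 16]) -- same shape as the input
    print(attention.shape)   # torch.Size([2, 256, 256])   -- N x N with N = 16*16 positions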
--------------------------------------------------------------------------------
/models/modules/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .modules import spectral_norm
5 |
6 | # Defines the Unet generator.
7 | # |num_downs|: number of downsamplings in UNet. For example,
8 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
9 | # at the bottleneck
10 | class UnetGenerator(nn.Module):
11 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
12 | norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
13 | super(UnetGenerator, self).__init__()
14 |
15 | # construct unet structure
16 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_spectral_norm=use_spectral_norm)
17 | for i in range(num_downs - 5):
18 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
19 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
20 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
21 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
22 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
23 |
24 | self.model = unet_block
25 |
26 | def forward(self, input):
27 | return self.model(input)
28 |
29 | # construct network from the inside to the outside.
30 | # Defines the submodule with skip connection.
31 | # X -------------------identity---------------------- X
32 | # |-- downsampling -- |submodule| -- upsampling --|
33 | class UnetSkipConnectionBlock(nn.Module):
34 | def __init__(self, outer_nc, inner_nc, input_nc,
35 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
36 | super(UnetSkipConnectionBlock, self).__init__()
37 | self.outermost = outermost
38 |
39 | if input_nc is None:
40 | input_nc = outer_nc
41 |
42 | downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
43 | stride=2, padding=1), use_spectral_norm)
44 | downrelu = nn.LeakyReLU(0.2, True)
45 | downnorm = norm_layer(inner_nc)
46 | uprelu = nn.ReLU(True)
47 | upnorm = norm_layer(outer_nc)
48 |
49 | # Blocks at different positions only differ in `upconv`.
50 | # For the outermost block, the special part is the final `tanh`.
51 | if outermost:
52 | upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
53 | kernel_size=4, stride=2,
54 | padding=1), use_spectral_norm)
55 | down = [downconv]
56 | up = [uprelu, upconv, nn.Tanh()]
57 | model = down + [submodule] + up
58 | # For the innermost block, the special part is using `inner_nc` instead of `inner_nc*2`.
59 | elif innermost:
60 | upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
61 | kernel_size=4, stride=2,
62 | padding=1), use_spectral_norm)
63 | down = [downrelu, downconv] # for the innermost, no submodule, and delete the bn
64 | up = [uprelu, upconv, upnorm]
65 | model = down + up
66 | # else, the normal
67 | else:
68 | upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
69 | kernel_size=4, stride=2,
70 | padding=1), use_spectral_norm)
71 | down = [downrelu, downconv, downnorm]
72 | up = [uprelu, upconv, upnorm]
73 |
74 | model = down + [submodule] + up
75 |
76 | self.model = nn.Sequential(*model)
77 |
78 | def forward(self, x):
79 | if self.outermost: # if it is the outermost, directly pass the input in.
80 | return self.model(x)
81 | else:
82 | x_latter = self.model(x)
83 | _, _, h, w = x.size()
84 | if h != x_latter.size(2) or w != x_latter.size(3):
85 | x_latter = F.interpolate(x_latter, (h, w), mode='bilinear')
86 | return torch.cat([x_latter, x], 1) # cat in the C channel
87 |
88 |
89 | # A simpler kind of UNet, written out layer by layer instead of being built from UnetSkipConnectionBlocks.
90 | # In this way, everything is much clearer and more flexible to extend.
91 | class EasyUnetGenerator(nn.Module):
92 | def __init__(self, input_nc, output_nc, ngf=64,
93 | norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
94 | super(EasyUnetGenerator, self).__init__()
95 |
96 | # Encoder layers
97 | self.e1_c = spectral_norm(nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
98 |
99 | self.e2_c = spectral_norm(nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
100 | self.e2_norm = norm_layer(ngf*2)
101 |
102 | self.e3_c = spectral_norm(nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
103 | self.e3_norm = norm_layer(ngf*4)
104 |
105 | self.e4_c = spectral_norm(nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
106 | self.e4_norm = norm_layer(ngf*8)
107 |
108 | self.e5_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
109 | self.e5_norm = norm_layer(ngf*8)
110 |
111 | self.e6_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
112 | self.e6_norm = norm_layer(ngf*8)
113 |
114 | self.e7_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
115 | self.e7_norm = norm_layer(ngf*8)
116 |
117 | self.e8_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
118 |
119 | # Decoder layers
120 | self.d1_c = spectral_norm(nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
121 | self.d1_norm = norm_layer(ngf*8)
122 |
123 | self.d2_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2 , ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
124 | self.d2_norm = norm_layer(ngf*8)
125 |
126 | self.d3_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
127 | self.d3_norm = norm_layer(ngf*8)
128 |
129 | self.d4_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
130 | self.d4_norm = norm_layer(ngf*8)
131 |
132 | self.d5_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
133 | self.d5_norm = norm_layer(ngf*4)
134 |
135 | self.d6_c = spectral_norm(nn.ConvTranspose2d(ngf*4*2, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
136 | self.d6_norm = norm_layer(ngf*2)
137 |
138 | self.d7_c = spectral_norm(nn.ConvTranspose2d(ngf*2*2, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
139 | self.d7_norm = norm_layer(ngf)
140 |
141 | self.d8_c = spectral_norm(nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1), use_spectral_norm)
142 |
143 |
144 | # Written this way, the UNet construction is very flexible.
145 | def forward(self, input):
146 | # Encoder
147 | # No norm on the first layer
148 | e1 = self.e1_c(input)
149 | e2 = self.e2_norm(self.e2_c(F.leaky_relu_(e1, negative_slope=0.2)))
150 | e3 = self.e3_norm(self.e3_c(F.leaky_relu_(e2, negative_slope=0.2)))
151 | e4 = self.e4_norm(self.e4_c(F.leaky_relu_(e3, negative_slope=0.2)))
152 | e5 = self.e5_norm(self.e5_c(F.leaky_relu_(e4, negative_slope=0.2)))
153 | e6 = self.e6_norm(self.e6_c(F.leaky_relu_(e5, negative_slope=0.2)))
154 | e7 = self.e7_norm(self.e7_c(F.leaky_relu_(e6, negative_slope=0.2)))
155 | # No norm on the innermost layer
156 | e8 = self.e8_c(F.leaky_relu_(e7, negative_slope=0.2))
157 |
158 | # Decoder
159 | d1 = self.d1_norm(self.d1_c(F.relu_(e8)))
160 | d2 = self.d2_norm(self.d2_c(F.relu_(torch.cat([d1, e7], dim=1))))
161 | d3 = self.d3_norm(self.d3_c(F.relu_(torch.cat([d2, e6], dim=1))))
162 | d4 = self.d4_norm(self.d4_c(F.relu_(torch.cat([d3, e5], dim=1))))
163 | d5 = self.d5_norm(self.d5_c(F.relu_(torch.cat([d4, e4], dim=1))))
164 | d6 = self.d6_norm(self.d6_c(F.relu_(torch.cat([d5, e3], dim=1))))
165 | d7 = self.d7_norm(self.d7_c(F.relu_(torch.cat([d6, e2], dim=1))))
166 | # No norm on the last layer
167 | d8 = self.d8_c(F.relu_(torch.cat([d7, e1], 1)))
168 |
169 | d8 = torch.tanh(d8)
170 |
171 | return d8
172 |
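A minimal usage sketch (illustrative only; it assumes the EasyUnetGenerator class above is in scope, and input_nc=4 mimics the add_mask2input setting where the mask is appended as an extra channel):

    import torch

    netG = EasyUnetGenerator(input_nc=4, output_nc=3, ngf=64)
    x = torch.randn(1, 4, 256, 256)   # masked image plus one mask channel
    y = netG(x)
    print(y.shape)                    # torch.Size([1, 3, 256, 256]), values in (-1, 1) from the final tanh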
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 | from torch.nn import init
3 | from torch.optim import lr_scheduler
4 | from torchvision import models
5 |
6 |
7 | from .modules import *
8 |
9 | ###############################################################################
10 | # Functions
11 | ###############################################################################
12 | def get_norm_layer(norm_type='instance'):
13 | if norm_type == 'batch':
14 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
15 | elif norm_type == 'instance':
16 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=False)
17 | elif norm_type == 'switchable':
18 | norm_layer = functools.partial(SwitchNorm2d)
19 | elif norm_type == 'none':
20 | norm_layer = None
21 | else:
22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
23 | return norm_layer
24 |
25 |
26 | def get_scheduler(optimizer, opt):
27 | if opt.lr_policy == 'lambda':
28 | def lambda_rule(epoch):
29 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
30 | return lr_l
31 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
32 | elif opt.lr_policy == 'step':
33 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
34 | elif opt.lr_policy == 'plateau':
35 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
36 | elif opt.lr_policy == 'cosine':
37 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
38 | else:
39 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
40 | return scheduler
41 |
42 |
43 | def init_weights(net, init_type='normal', gain=0.02):
44 | def init_func(m):
45 | classname = m.__class__.__name__
46 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
47 | if init_type == 'normal':
48 | init.normal_(m.weight.data, 0.0, gain)
49 | elif init_type == 'xavier':
50 | init.xavier_normal_(m.weight.data, gain=gain)
51 | elif init_type == 'kaiming':
52 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
53 | elif init_type == 'orthogonal':
54 | init.orthogonal_(m.weight.data, gain=gain)
55 | else:
56 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
57 | if hasattr(m, 'bias') and m.bias is not None:
58 | init.constant_(m.bias.data, 0.0)
59 | elif classname.find('BatchNorm2d') != -1:
60 | init.normal_(m.weight.data, 1.0, gain)
61 | init.constant_(m.bias.data, 0.0)
62 |
63 | print('initialize network with %s' % init_type)
64 | net.apply(init_func)
65 |
66 |
67 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
68 | if len(gpu_ids) > 0:
69 | assert(torch.cuda.is_available())
70 | net.to(gpu_ids[0])
71 | net = torch.nn.DataParallel(net, gpu_ids)
72 | init_weights(net, init_type, gain=init_gain)
73 | return net
74 |
75 |
76 | # Note: Adding SN to G tends to give inferior results. Need more checking.
77 | def define_G(input_nc, output_nc, ngf, which_model_netG, opt, mask_global, norm='batch', use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02):
78 | netG = None
79 | norm_layer = get_norm_layer(norm_type=norm)
80 |
81 | innerCos_list = []
82 | shift_list = []
83 |
84 | print('input_nc {}'.format(input_nc))
85 | print('output_nc {}'.format(output_nc))
86 | print('which_model_netG {}'.format(which_model_netG))
87 |
88 | # Here we need to initialize an artificial mask_global to construct the initial model.
89 | # When training, we need to set the mask for the special layers (mostly the Shift layers) first.
90 | # If the mask is fixed during training, we only need to set the mask for these layers once;
91 | # otherwise we need to set the masks each iteration: generate new random masks, mask the input,
92 | # and set the masks for these special layers.
93 | print('[CREATING] MODEL')
94 | if which_model_netG == 'unet_256':
95 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
96 | elif which_model_netG == 'easy_unet_256':
97 | netG = EasyUnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
98 | elif which_model_netG == 'face_unet_shift_triple':
99 | netG = FaceUnetGenerator(input_nc, output_nc, innerCos_list, shift_list, mask_global, opt, \
100 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
101 | elif which_model_netG == 'unet_shift_triple':
102 | netG = UnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
103 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
104 | elif which_model_netG == 'res_unet_shift_triple':
105 | netG = ResUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
106 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
107 | elif which_model_netG == 'patch_soft_unet_shift_triple':
108 | netG = PatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
109 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
110 | elif which_model_netG == 'res_patch_soft_unet_shift_triple':
111 | netG = ResPatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
112 | ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
113 | else:
114 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
115 | print('[CREATED] MODEL')
116 | print('Constraint in netG:')
117 | print(innerCos_list)
118 |
119 | print('Shift in netG:')
120 | print(shift_list)
121 |
122 | print('NetG:')
123 | print(netG)
124 |
125 | return init_net(netG, init_type, init_gain, gpu_ids), innerCos_list, shift_list
126 |
127 |
128 | def define_D(input_nc, ndf, which_model_netD,
129 | n_layers_D=3, norm='batch', use_sigmoid=False, use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02):
130 | netD = None
131 | norm_layer = get_norm_layer(norm_type=norm)
132 |
133 | if which_model_netD == 'basic':
134 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
135 |
136 | elif which_model_netD == 'n_layers':
137 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
138 |
139 | elif which_model_netD == 'densenet':
140 | netD = DenseNetDiscrimator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
141 |
142 | else:
143 |         raise NotImplementedError('Discriminator model name [%s] is not recognized' %
144 |                 which_model_netD)
145 |
146 | print('NetD:')
147 | print(netD)
148 | return init_net(netD, init_type, init_gain, gpu_ids)
149 |
150 |
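The scheduler choices above are driven entirely by fields on `opt`. As a rough usage sketch (the `opt` values below are hypothetical placeholders for whatever the option parsers supply, and the tiny `nn.Linear` stands in for netG), the 'lambda' policy keeps the learning rate flat for `niter` epochs and then decays it linearly to zero over `niter_decay` epochs:

import torch
from torch import nn
from torch.optim import lr_scheduler
from argparse import Namespace

# Hypothetical option values; in the repository they come from the options parsers.
opt = Namespace(lr_policy='lambda', niter=20, niter_decay=100, epoch_count=1)

net = nn.Linear(8, 8)  # stand-in for netG
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Same 'lambda' branch as get_scheduler above: flat for `niter` epochs,
# then linear decay to 0 over the next `niter_decay` epochs.
def lambda_rule(epoch):
    return 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

for epoch in range(3):          # placeholder training loop
    optimizer.step()
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])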
--------------------------------------------------------------------------------
/models/patch_soft_shift/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/patch_soft_shift/__init__.py
--------------------------------------------------------------------------------
/models/patch_soft_shift/innerPatchSoftShiftTriple.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import util.util as util
4 | from .innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule
5 |
6 |
7 | # TODO: Make it compatible for show_flow.
8 | #
9 | class InnerPatchSoftShiftTriple(nn.Module):
10 | def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3):
11 | super(InnerPatchSoftShiftTriple, self).__init__()
12 |
13 | self.shift_sz = shift_sz
14 | self.stride = stride
15 | self.mask_thred = mask_thred
16 | self.triple_weight = triple_weight
17 |         self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
18 |         self.flow_srcs = None # Indicates the flow sources (pixels in the non-masked region that will shift into the masked region).
19 | self.fuse = fuse
20 | self.layer_to_last = layer_to_last
21 | self.softShift = InnerPatchSoftShiftTripleModule()
22 |
23 | def set_mask(self, mask_global):
24 | mask = util.cal_feat_mask(mask_global, self.layer_to_last)
25 | self.mask = mask
26 | return self.mask
27 |
28 | # If mask changes, then need to set cal_fix_flag true each iteration.
29 | def forward(self, input):
30 | _, self.c, self.h, self.w = input.size()
31 |
32 | # Just pass self.mask in, instead of self.flag.
33 | final_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse)
34 | if self.show_flow:
35 | self.flow_srcs = self.softShift.get_flow_src()
36 | return final_out
37 |
38 | def get_flow(self):
39 | return self.flow_srcs
40 |
41 | def set_flow_true(self):
42 | self.show_flow = True
43 |
44 | def set_flow_false(self):
45 | self.show_flow = False
46 |
47 | def __repr__(self):
48 | return self.__class__.__name__+ '(' \
49 | + ' ,triple_weight ' + str(self.triple_weight) + ')'
50 |
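A possible usage sketch for this layer (sizes are illustrative: a 256x256 boolean hole mask and a 32x32 feature map of 512 channels, i.e. a 256-channel decoder feature concatenated with a 256-channel encoder feature; it assumes the repository root is on PYTHONPATH and mirrors how ShiftNetModel passes its boolean mask_global to the shift layers):

import torch
from models.patch_soft_shift.innerPatchSoftShiftTriple import InnerPatchSoftShiftTriple

device = 'cuda' if torch.cuda.is_available() else 'cpu'

shift = InnerPatchSoftShiftTriple(shift_sz=3, stride=1, mask_thred=1, fuse=True, layer_to_last=3)

mask_global = torch.zeros(1, 1, 256, 256, dtype=torch.bool, device=device)
mask_global[:, :, 96:160, 96:160] = 1          # a square hole in image space
shift.set_mask(mask_global)                    # downsampled internally to the feature resolution

feat = torch.randn(1, 512, 32, 32, device=device)   # cat([decoder, encoder], dim=1)
out = shift(feat)
print(out.shape)                               # expected (1, 768, 32, 32): former, latter and shifted parts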
--------------------------------------------------------------------------------
/models/patch_soft_shift/innerPatchSoftShiftTripleModule.py:
--------------------------------------------------------------------------------
1 | from util.NonparametricShift import Modified_NonparametricShift
2 | from torch.nn import functional as F
3 | import torch.nn as nn
4 | import torch
5 | import util.util as util
6 |
7 |
8 | class InnerPatchSoftShiftTripleModule(nn.Module):
9 | def forward(self, input, stride, triple_w, mask, mask_thred, shift_sz, show_flow, fuse=True):
10 | assert input.dim() == 4, "Input Dim has to be 4"
11 | assert mask.dim() == 4, "Mask Dim has to be 4"
12 | self.triple_w = triple_w
13 | self.mask = mask
14 | self.mask_thred = mask_thred
15 | self.show_flow = show_flow
16 |
17 | self.bz, self.c, self.h, self.w = input.size()
18 |
19 |         self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor  # call is_available(); the bare function is always truthy
20 |
21 | self.ind_lst = self.Tensor(self.bz, self.h * self.w, self.h * self.w).zero_()
22 |
23 | # former and latter are all tensors
24 | former_all = input.narrow(1, 0, self.c//2) ### decoder feature
25 | latter_all = input.narrow(1, self.c//2, self.c//2) ### encoder feature
26 | shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all) # addition feature
27 |
28 | self.mask = self.mask.to(input)
29 |
30 | # extract patches from latter.
31 | latter_all_pad = F.pad(latter_all, [shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2], 'constant', 0)
32 | latter_all_windows = latter_all_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
33 | latter_all_windows = latter_all_windows.contiguous().view(self.bz, -1, self.c//2, shift_sz, shift_sz)
34 |
35 | # Extract patches from mask
36 | # Mention: mask here must be 1*1*H*W
37 | m_pad = F.pad(self.mask, (shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2), 'constant', 0)
38 | m = m_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
39 | m = m.contiguous().view(self.bz, 1, -1, shift_sz, shift_sz)
40 |
41 |         # It implements similar functionality to `cal_flag_given_mask_thred`.
42 |         # However, the meaning of `mm` differs:
43 |         # here the masked region of mm is filled with 0 and the non-masked region with 1,
44 |         # while in `cal_flag_given_mask_thred` it is the opposite.
45 | m = torch.mean(torch.mean(m, dim=3, keepdim=True), dim=4, keepdim=True)
46 | mm = m.le(self.mask_thred/(1.*shift_sz**2)).float() # bz*1*(32*32)*1*1
47 |
48 | fuse_weight = torch.eye(shift_sz).view(1, 1, shift_sz, shift_sz).type_as(input)
49 |
50 | self.shift_offsets = []
51 | for idx in range(self.bz):
52 | mm_cur = mm[idx]
53 | # latter_win = latter_all_windows.narrow(0, idx, 1)[0]
54 | latter_win = latter_all_windows.narrow(0, idx, 1)[0]
55 | former = former_all.narrow(0, idx, 1)
56 |
57 | # normalize latter for each patch.
58 | latter_den = torch.sqrt(torch.einsum("bcij,bcij->b", [latter_win, latter_win]))
59 | latter_den = torch.max(latter_den, self.Tensor([1e-4]))
60 |
61 | latter_win_normed = latter_win/latter_den.view(-1, 1, 1, 1)
62 | y_i = F.conv2d(former, latter_win_normed, stride=1, padding=shift_sz//2)
63 |
64 | # conv implementation for fuse scores to encourage large patches
65 | if fuse:
66 | y_i = y_i.view(1, 1, self.h*self.w, self.h*self.w) # make all of depth of spatial resolution.
67 | y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1)
68 |
69 | y_i = y_i.contiguous().view(1, self.h, self.w, self.h, self.w)
70 | y_i = y_i.permute(0, 2, 1, 4, 3)
71 | y_i = y_i.contiguous().view(1, 1, self.h*self.w, self.h*self.w)
72 |
73 | y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1)
74 | y_i = y_i.contiguous().view(1, self.w, self.h, self.w, self.h)
75 | y_i = y_i.permute(0, 2, 1, 4, 3)
76 |
77 | y_i = y_i.contiguous().view(1, self.h*self.w, self.h, self.w) # 1*(32*32)*32*32
78 |
79 |             # First, wash away the masked region.
80 |             # Multiplying by `mm` means (:, index_masked, :, :) will be 0.
81 | y_i = y_i * mm_cur
82 |
83 | # Then apply softmax to the nonmasked region.
84 | cosine = F.softmax(y_i*10, dim=1)
85 |
86 |             # Finally, dummy values in the masked region are filtered out.
87 | cosine = cosine * mm_cur
88 |
89 | # paste
90 | shift_i = F.conv_transpose2d(cosine, latter_win, stride=1, padding=shift_sz//2)/9.
91 | shift_masked_all[idx] = shift_i
92 |
93 | # Addition: show shift map
94 | # TODO: fix me.
95 |             # cosine here has the full size of 32*32, not only the masked region as in `shift_net`,
96 |             # which means the code cannot be reused directly.
97 | # torch.set_printoptions(threshold=2015)
98 | # if self.show_flow:
99 | # _, indexes = torch.max(cosine, dim=1)
100 | # # calculate self.flag from self.m
101 | # self.flag = (1 - mm).view(-1)
102 | # torch.set_printoptions(threshold=1025)
103 | # print(self.flag)
104 | # non_mask_indexes = (self.flag == 0.).nonzero()
105 | # non_mask_indexes = non_mask_indexes[indexes]
106 | # print('ll')
107 | # print(non_mask_indexes.size())
108 | # print(non_mask_indexes)
109 | # # Here non_mask_index is too large, should be 192.
110 | # shift_offset = torch.stack([non_mask_indexes.squeeze() // self.w, non_mask_indexes.squeeze() % self.w], dim=-1)
111 | # print(shift_offset.size())
112 | # self.shift_offsets.append(shift_offset)
113 |
114 | # print('cc')
115 | # if self.show_flow:
116 | # # Note: Here we assume that each mask is the same for the same batch image.
117 | # self.shift_offsets = torch.cat(self.shift_offsets, dim=0).float() # make it cudaFloatTensor
118 | # # Assume mask is the same for each image in a batch.
119 | # mask_nums = self.shift_offsets.size(0)//self.bz
120 | # self.flow_srcs = torch.zeros(self.bz, 3, self.h, self.w).type_as(input)
121 |
122 | # for idx in range(self.bz):
123 | # shift_offset = self.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
124 | # # reconstruct the original shift_map.
125 | # shift_offsets_map = torch.zeros(1, self.h, self.w, 2).type_as(input)
126 | # print(shift_offsets_map.size())
127 | # print(shift_offset.unsqueeze(0).size())
128 |
129 | # print(shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :].size())
130 | # shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :] = \
131 | # shift_offset.unsqueeze(0)
132 | # # It is indicating the pixels(non-masked) that will shift the the masked region.
133 | # flow_src = util.highlight_flow(shift_offsets_map, self.flag.unsqueeze(0))
134 | # self.flow_srcs[idx] = flow_src
135 |
136 | return torch.cat((former_all, latter_all, shift_masked_all), 1)
137 |
138 | def get_flow_src(self):
139 | return self.flow_srcs
140 |
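The unfold-based patch extraction and the mask-average threshold above are the heart of the soft shift. A standalone sketch of the same idea on toy shapes (not a line-for-line copy: it adds an explicit permute so each patch really corresponds to one spatial location):

import torch
import torch.nn.functional as F

bz, c, h, w = 1, 4, 8, 8
shift_sz, stride, mask_thred = 3, 1, 1

latter = torch.randn(bz, c, h, w)                      # encoder feature
mask = torch.zeros(bz, 1, h, w)                        # 1 = hole, 0 = known
mask[:, :, 2:6, 2:6] = 1

pad = shift_sz // 2
# One k*k patch per spatial location (zero padding keeps exactly h*w patches).
patches = F.pad(latter, [pad] * 4).unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(bz, h * w, c, shift_sz, shift_sz)

# Average each mask patch: a patch whose masked fraction is at most mask_thred / k^2
# counts as a usable "known" patch (mm == 1); everything else is masked out (mm == 0).
m = F.pad(mask, [pad] * 4).unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
m = m.contiguous().view(bz, h * w, -1).mean(dim=2)
mm = m.le(mask_thred / (1. * shift_sz ** 2)).float()

print(patches.shape, mm.shape, int(mm.sum()))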
--------------------------------------------------------------------------------
/models/patch_soft_shift/patch_soft_shiftnet_model.py:
--------------------------------------------------------------------------------
1 | from models.shift_net.shiftnet_model import ShiftNetModel
2 |
3 |
4 | class PatchSoftShiftNetModel(ShiftNetModel):
5 | def name(self):
6 | return 'PatchSoftShiftNetModel'
7 |
--------------------------------------------------------------------------------
/models/res_patch_soft_shift/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/res_patch_soft_shift/__init__.py
--------------------------------------------------------------------------------
/models/res_patch_soft_shift/innerResPatchSoftShiftTriple.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import util.util as util
4 |
5 | from models.patch_soft_shift.innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule
6 |
7 |
8 | # TODO: Make it compatible for show_flow.
9 | #
10 | class InnerResPatchSoftShiftTriple(nn.Module):
11 | def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3):
12 | super(InnerResPatchSoftShiftTriple, self).__init__()
13 |
14 | self.shift_sz = shift_sz
15 | self.stride = stride
16 | self.mask_thred = mask_thred
17 | self.triple_weight = triple_weight
18 |         self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
19 |         self.flow_srcs = None # Indicates the flow sources (pixels in the non-masked region that will shift into the masked region).
20 | self.fuse = fuse
21 | self.layer_to_last = layer_to_last
22 | self.softShift = InnerPatchSoftShiftTripleModule()
23 |
24 | # Additional for ResShift.
25 | self.inner_nc = inner_nc
26 | self.res_net = nn.Sequential(
27 | nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1),
28 | nn.InstanceNorm2d(inner_nc),
29 | nn.ReLU(True),
30 | nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1),
31 | nn.InstanceNorm2d(inner_nc)
32 | )
33 |
34 | def set_mask(self, mask_global):
35 | mask = util.cal_feat_mask(mask_global, self.layer_to_last)
36 | self.mask = mask
37 | return self.mask
38 |
39 | # If mask changes, then need to set cal_fix_flag true each iteration.
40 | def forward(self, input):
41 | _, self.c, self.h, self.w = input.size()
42 |
43 | # Just pass self.mask in, instead of self.flag.
44 |         # Try to make it faster by avoiding `cal_flag_given_mask_thred`.
45 | shift_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse)
46 |
47 | c_out = shift_out.size(1)
48 | # get F_c, F_s, F_shift
49 | F_c = shift_out.narrow(1, 0, c_out//3)
50 | F_s = shift_out.narrow(1, c_out//3, c_out//3)
51 | F_shift = shift_out.narrow(1, c_out*2//3, c_out//3)
52 | F_fuse = F_c * F_shift
53 | F_com = torch.cat([F_c, F_fuse], dim=1)
54 |
55 | res_out = self.res_net(F_com)
56 | F_c = F_c + res_out
57 |
58 | final_out = torch.cat([F_c, F_s], dim=1)
59 |
60 | if self.show_flow:
61 | self.flow_srcs = self.softShift.get_flow_src()
62 | return final_out
63 |
64 | def get_flow(self):
65 | return self.flow_srcs
66 |
67 | def set_flow_true(self):
68 | self.show_flow = True
69 |
70 | def set_flow_false(self):
71 | self.show_flow = False
72 |
73 | def __repr__(self):
74 | return self.__class__.__name__+ '(' \
75 | + ' ,triple_weight ' + str(self.triple_weight) + ')'
76 |
--------------------------------------------------------------------------------
/models/res_patch_soft_shift/res_patch_soft_shiftnet_model.py:
--------------------------------------------------------------------------------
1 | from models.shift_net.shiftnet_model import ShiftNetModel
2 |
3 |
4 | class ResPatchSoftShiftNetModel(ShiftNetModel):
5 | def name(self):
6 | return 'ResPatchSoftShiftNetModel'
7 |
--------------------------------------------------------------------------------
/models/res_shift_net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/res_shift_net/__init__.py
--------------------------------------------------------------------------------
/models/res_shift_net/innerResShiftTriple.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import util.util as util
4 |
5 | from models.shift_net.InnerShiftTripleFunction import InnerShiftTripleFunction
6 |
7 | class InnerResShiftTriple(nn.Module):
8 | def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3):
9 | super(InnerResShiftTriple, self).__init__()
10 |
11 | self.shift_sz = shift_sz
12 | self.stride = stride
13 | self.mask_thred = mask_thred
14 | self.triple_weight = triple_weight
15 |         self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
16 |         self.flow_srcs = None # Indicates the flow sources (pixels in the non-masked region that will shift into the masked region).
17 | self.layer_to_last = layer_to_last
18 |
19 | # Additional for ResShift.
20 | self.inner_nc = inner_nc
21 | self.res_net = nn.Sequential(
22 | nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1),
23 | nn.InstanceNorm2d(inner_nc),
24 | nn.ReLU(True),
25 | nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1),
26 | nn.InstanceNorm2d(inner_nc)
27 | )
28 |
29 |
30 | def set_mask(self, mask_global):
31 | mask = util.cal_feat_mask(mask_global, self.layer_to_last)
32 | self.mask = mask.squeeze()
33 | return self.mask
34 |
35 | # If mask changes, then need to set cal_fix_flag true each iteration.
36 | def forward(self, input):
37 | #print(input.shape)
38 | _, self.c, self.h, self.w = input.size()
39 | self.flag = util.cal_flag_given_mask_thred(self.mask, self.shift_sz, self.stride, self.mask_thred)
40 | shift_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
41 |
42 | c_out = shift_out.size(1)
43 | # get F_c, F_s, F_shift
44 | F_c = shift_out.narrow(1, 0, c_out//3)
45 | F_s = shift_out.narrow(1, c_out//3, c_out//3)
46 | F_shift = shift_out.narrow(1, c_out*2//3, c_out//3)
47 | F_fuse = F_c * F_shift
48 | F_com = torch.cat([F_c, F_fuse], dim=1)
49 |
50 | res_out = self.res_net(F_com)
51 | F_c = F_c + res_out
52 |
53 | final_out = torch.cat([F_c, F_s], dim=1)
54 |
55 | if self.show_flow:
56 | self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
57 | return final_out
58 |
59 | def get_flow(self):
60 | return self.flow_srcs
61 |
62 | def set_flow_true(self):
63 | self.show_flow = True
64 |
65 | def set_flow_false(self):
66 | self.show_flow = False
67 |
68 | def __repr__(self):
69 | return self.__class__.__name__+ '(' \
70 | + ' ,triple_weight ' + str(self.triple_weight) + ')'
71 |
--------------------------------------------------------------------------------
/models/res_shift_net/shiftnet_model.py:
--------------------------------------------------------------------------------
1 | from models.shift_net.shiftnet_model import ShiftNetModel
2 |
3 |
4 | class ResShiftNetModel(ShiftNetModel):
5 | def name(self):
6 | return 'ResShiftNetModel'
--------------------------------------------------------------------------------
/models/shift_net/InnerCos.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import util.util as util
5 | from .InnerCosFunction import InnerCosFunction
6 |
7 | class InnerCos(nn.Module):
8 | def __init__(self, crit='MSE', strength=1, skip=0, layer_to_last=3, device='gpu'):
9 | super(InnerCos, self).__init__()
10 | self.crit = crit
11 | self.criterion = torch.nn.MSELoss() if self.crit == 'MSE' else torch.nn.L1Loss()
12 | self.strength = strength
13 | # To define whether this layer is skipped.
14 | self.skip = skip
15 | self.layer_to_last = layer_to_last
16 | self.device = device
17 |         # Initializing with a dummy value is fine.
18 | self.target = torch.tensor(1.0)
19 |
20 | def set_mask(self, mask_global):
21 | mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
22 | self.mask_all = mask_all.float()
23 |
24 |
25 | def _split_mask(self, cur_bsize):
26 | # get the visible indexes of gpus and assign correct mask to set of images
27 | cur_device = torch.cuda.current_device()
28 | self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
29 |
30 | def forward(self, in_data):
31 | self.bz, self.c, _, _ = in_data.size()
32 | if self.device != 'cpu':
33 | self._split_mask(self.bz)
34 | else:
35 | self.cur_mask = self.mask_all
36 | self.cur_mask = self.cur_mask.to(in_data)
37 | if not self.skip:
38 | # It works like this:
39 |             # Each iteration contains 2 forward passes. In the first forward pass, we input a GT image, just to get the target.
40 |             # In the second forward pass, we input the corresponding corrupted image and back-propagate the network; the guidance loss then works as expected.
41 | self.output = InnerCosFunction.apply(in_data, self.criterion, self.strength, self.target, self.cur_mask)
42 | self.target = in_data.narrow(1, self.c // 2, self.c // 2).detach() # the latter part
43 | else:
44 | self.output = in_data
45 | return self.output
46 |
47 |
48 | def __repr__(self):
49 |         skip_str = 'True' if self.skip else 'False'
50 | return self.__class__.__name__+ '(' \
51 | + 'skip: ' + skip_str \
52 | + 'layer ' + str(self.layer_to_last) + ' to last' \
53 | + ' ,strength: ' + str(self.strength) + ')'
54 |
--------------------------------------------------------------------------------
/models/shift_net/InnerCosFunction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class InnerCosFunction(torch.autograd.Function):
5 |
6 | @staticmethod
7 | def forward(ctx, input, criterion, strength, target, mask):
8 | ctx.c = input.size(1)
9 | ctx.strength = strength
10 | ctx.criterion = criterion
11 | if len(target.size()) == 0: # For the first iteration.
12 | target = target.expand_as(input.narrow(1, ctx.c // 2, ctx.c // 2)).type_as(input)
13 |
14 | ctx.save_for_backward(input, target, mask)
15 | return input
16 |
17 |
18 | @staticmethod
19 | def backward(ctx, grad_output):
20 |
21 | with torch.enable_grad():
22 | input, target, mask = ctx.saved_tensors
23 | former = input.narrow(1, 0, ctx.c//2)
24 | former_in_mask = torch.mul(former, mask)
25 | if former_in_mask.size() != target.size(): # For the last iteration of one epoch
26 | target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)
27 |
28 | former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
29 | ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
30 | ctx.loss.backward()
31 |
32 | grad_output[:,0:ctx.c//2, :,:] += former_in_mask_clone.grad
33 |
34 | return grad_output, None, None, None, None
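The pattern here — an identity forward that injects the gradient of an auxiliary loss during backward — is easy to miss on first read. A minimal, self-contained toy of the same pattern (a plain MSE to a fixed target instead of the masked guidance loss):

import torch

class IdentityWithAuxLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, target, strength):
        ctx.strength = strength
        ctx.save_for_backward(x, target)
        return x                                  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        x, target = ctx.saved_tensors
        with torch.enable_grad():
            x_leaf = x.detach().requires_grad_(True)
            aux = torch.nn.functional.mse_loss(x_leaf, target) * ctx.strength
            aux.backward()                        # gradient of the auxiliary loss w.r.t. x
        return grad_output + x_leaf.grad, None, None

x = torch.randn(2, 3, requires_grad=True)
y = IdentityWithAuxLoss.apply(x, torch.zeros(2, 3), 0.5)
y.sum().backward()
print(x.grad)    # ones (from the sum) plus the scaled MSE gradient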
--------------------------------------------------------------------------------
/models/shift_net/InnerShiftTriple.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import util.util as util
4 | from .InnerShiftTripleFunction import InnerShiftTripleFunction
5 |
6 | class InnerShiftTriple(nn.Module):
7 | def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3, device='gpu'):
8 | super(InnerShiftTriple, self).__init__()
9 |
10 | self.shift_sz = shift_sz
11 | self.stride = stride
12 | self.mask_thred = mask_thred
13 | self.triple_weight = triple_weight
14 | self.layer_to_last = layer_to_last
15 | self.device = device
16 |         self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
17 |         self.flow_srcs = None # Indicates the flow sources (pixels in the non-masked region that will shift into the masked region).
18 |
19 |
20 | def set_mask(self, mask_global):
21 | self.mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
22 |
23 | def _split_mask(self, cur_bsize):
24 | # get the visible indexes of gpus and assign correct mask to set of images
25 | cur_device = torch.cuda.current_device()
26 | self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
27 |
28 | # If mask changes, then need to set cal_fix_flag true each iteration.
29 | def forward(self, input):
30 | self.bz, self.c, self.h, self.w = input.size()
31 | if self.device != 'cpu':
32 | self._split_mask(self.bz)
33 | else:
34 | self.cur_mask = self.mask_all
35 | self.flag = util.cal_flag_given_mask_thred(self.cur_mask, self.shift_sz, self.stride, self.mask_thred)
36 | final_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
37 | if self.show_flow:
38 | self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
39 | return final_out
40 |
41 | def get_flow(self):
42 | return self.flow_srcs
43 |
44 | def set_flow_true(self):
45 | self.show_flow = True
46 |
47 | def set_flow_false(self):
48 | self.show_flow = False
49 |
50 | def __repr__(self):
51 | return self.__class__.__name__+ '(' \
52 | + ' ,triple_weight ' + str(self.triple_weight) + ')'
53 |
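A small note on `_split_mask` above: with `nn.DataParallel` the batch is scattered across GPUs, so each replica only sees `cur_bsize` images and must select the matching slice of the full mask batch. The arithmetic alone (no GPU needed, the sizes are hypothetical):

import torch

batchSize, n_gpus = 4, 2
mask_all = torch.arange(batchSize).view(batchSize, 1, 1, 1)   # stand-in for a (B,1,H,W) mask batch

cur_bsize = batchSize // n_gpus
for cur_device in range(n_gpus):                              # torch.cuda.current_device() on each replica
    cur_mask = mask_all[cur_device * cur_bsize:(cur_device + 1) * cur_bsize]
    print(cur_device, cur_mask.view(-1).tolist())             # device 0 -> [0, 1], device 1 -> [2, 3]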
--------------------------------------------------------------------------------
/models/shift_net/InnerShiftTripleFunction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift
3 | import torch
4 | import util.util as util
5 | import time
6 |
7 |
8 | class InnerShiftTripleFunction(torch.autograd.Function):
9 | ctx = None
10 |
11 | @staticmethod
12 | def forward(ctx, input, shift_sz, stride, triple_w, flag, show_flow):
13 | InnerShiftTripleFunction.ctx = ctx
14 | assert input.dim() == 4, "Input Dim has to be 4"
15 | ctx.triple_w = triple_w
16 | ctx.flag = flag
17 | ctx.show_flow = show_flow
18 |
19 | ctx.bz, c_real, ctx.h, ctx.w = input.size()
20 | c = c_real
21 |
22 | ctx.ind_lst = torch.Tensor(ctx.bz, ctx.h * ctx.w, ctx.h * ctx.w).zero_().to(input)
23 |
24 | # former and latter are all tensors
25 | former_all = input.narrow(1, 0, c//2) ### decoder feature
26 | latter_all = input.narrow(1, c//2, c//2) ### encoder feature
27 | shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all).zero_() # addition feature
28 |
29 | ctx.flag = ctx.flag.to(input).long()
30 |
31 | # None batch version
32 | bNonparm = Batch_NonShift()
33 | ctx.shift_offsets = []
34 |
35 | # batch version
36 | cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former_all.clone(), latter_all.clone(), 1, stride, flag)
37 |
38 | _, indexes = torch.max(cosine, dim=2)
39 |
40 | mask_indexes = (flag==1).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1)
41 |
42 | non_mask_indexes = (flag==0).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1).gather(1, indexes)
43 |
44 | idx_b = torch.arange(ctx.bz).long().unsqueeze(1).expand(ctx.bz, mask_indexes.size(1))
45 |         # set the elements indexed by [mask_indexes, non_mask_indexes] to 1.
46 | # It is a batch version
47 | ctx.ind_lst[(idx_b, mask_indexes, non_mask_indexes)] = 1
48 |
49 | shift_masked_all = bNonparm._paste(latter_windows, ctx.ind_lst, i_2, i_3, i_1)
50 |
51 |
52 | # --- Non-batch version ----
53 | #for idx in range(ctx.bz):
54 | # flag_cur = ctx.flag[idx]
55 | # latter = latter_all.narrow(0, idx, 1) ### encoder feature
56 | # former = former_all.narrow(0, idx, 1) ### decoder feature
57 |
58 | # #GET COSINE, RESHAPED LATTER AND ITS INDEXES
59 | # cosine, latter_windows, i_2, i_3, i_1 = Nonparm.cosine_similarity(former.clone().squeeze(), latter.clone().squeeze(), 1, stride, flag_cur)
60 |
61 | # ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
62 | # _, indexes = torch.max(cosine, dim=1)
63 |
64 | # # SET TRANSITION MATRIX
65 | # mask_indexes = (flag_cur == 1).nonzero()
66 | # non_mask_indexes = (flag_cur == 0).nonzero()[indexes]
67 | # ctx.ind_lst[idx][mask_indexes, non_mask_indexes] = 1
68 |
69 | # # GET FINAL SHIFT FEATURE
70 | # shift_masked_all[idx] = Nonparm._paste(latter_windows, ctx.ind_lst[idx], i_2, i_3, i_1)
71 |
72 | # if ctx.show_flow:
73 | # shift_offset = torch.stack([non_mask_indexes.squeeze() // ctx.w, non_mask_indexes.squeeze() % ctx.w], dim=-1)
74 | # ctx.shift_offsets.append(shift_offset)
75 |
76 | if ctx.show_flow:
77 |             assert 1==2, "I do not want to maintain the functionality of `show flow`... ^_^"
78 | ctx.shift_offsets = torch.cat(ctx.shift_offsets, dim=0).float() # make it cudaFloatTensor
79 | # Assume mask is the same for each image in a batch.
80 | mask_nums = ctx.shift_offsets.size(0)//ctx.bz
81 | ctx.flow_srcs = torch.zeros(ctx.bz, 3, ctx.h, ctx.w).type_as(input)
82 |
83 | for idx in range(ctx.bz):
84 | shift_offset = ctx.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
85 | # reconstruct the original shift_map.
86 | shift_offsets_map = torch.zeros(1, ctx.h, ctx.w, 2).type_as(input)
87 | shift_offsets_map[:, (flag_cur == 1).nonzero(as_tuple=False).squeeze() // ctx.w, (flag_cur == 1).nonzero(as_tuple=False).squeeze() % ctx.w, :] = \
88 | shift_offset.unsqueeze(0)
89 |                 # It indicates the pixels (non-masked) that will shift into the masked region.
90 | flow_src = util.highlight_flow(shift_offsets_map, flag_cur.unsqueeze(0))
91 | ctx.flow_srcs[idx] = flow_src
92 |
93 | return torch.cat((former_all, latter_all, shift_masked_all), 1)
94 |
95 |
96 | @staticmethod
97 | def get_flow_src():
98 | return InnerShiftTripleFunction.ctx.flow_srcs
99 |
100 | @staticmethod
101 | def backward(ctx, grad_output):
102 | ind_lst = ctx.ind_lst
103 |
104 | c = grad_output.size(1)
105 |
106 |         # # the former and the latter are kept original. Only the third part is shifted.
107 | # C: content, pixels in masked region of the former part.
108 | # S: style, pixels in the non-masked region of the latter part.
109 | # N: the shifted feature, the new feature that will be used as the third part of features maps.
110 | # W_mat: ind_lst[idx], shift matrix.
111 | # Note: **only the masked region in N has values**.
112 |
113 |         # The gradient of the shifted feature should be added back to the latter part (to be precise: S).
114 |         # `ind_lst[idx][i,j] = 1` means that the i_th pixel will **be replaced** by the j_th pixel in the forward pass.
115 |         # When applying `S mm W_mat`, S is transferred to N
116 |         # (pixels in the non-masked region of the latter part are shifted to the masked region in the third part).
117 |         # However, we need to transfer the gradient of the third part back to S.
118 |         # This means the gradient in S will **`be replaced`(to be precise, enhanced)** by the gradient of N.
119 | grad_former_all = grad_output[:, 0:c//3, :, :]
120 | grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
121 | grad_shifted_all = grad_output[:, c*2//3:c, :, :].clone()
122 |
123 | W_mat_t = ind_lst.permute(0, 2, 1).contiguous()
124 | grad = grad_shifted_all.view(ctx.bz, c//3, -1).permute(0, 2, 1)
125 | grad_shifted_weighted = torch.bmm(W_mat_t, grad)
126 | grad_shifted_weighted = grad_shifted_weighted.permute(0, 2, 1).contiguous().view(ctx.bz, c//3, ctx.h, ctx.w)
127 | grad_latter_all = torch.add(grad_latter_all, grad_shifted_weighted.mul(ctx.triple_w))
128 |
129 | # ----- 'Non_batch version here' --------------------
130 | # for idx in range(ctx.bz):
131 | # # So we need to transpose `W_mat`
132 | # W_mat_t = ind_lst[idx].t()
133 |
134 | # grad = grad_shifted_all[idx].view(c//3, -1).t()
135 |
136 | # grad_shifted_weighted = torch.mm(W_mat_t, grad)
137 |
138 | # # Then transpose it back
139 | # grad_shifted_weighted = grad_shifted_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
140 | # grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_shifted_weighted.mul(ctx.triple_w))
141 |
142 | # note the input channel and the output channel are all c, as no mask input for now.
143 | grad_input = torch.cat([grad_former_all, grad_latter_all], 1)
144 |
145 | return grad_input, None, None, None, None, None, None
146 |
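The bookkeeping above is essentially a permutation matrix: every masked location gets a one-hot row pointing at its most similar known location, the forward pass uses the matrix to copy features, and the backward pass uses its transpose to route gradients back. A minimal sketch of that idea on flattened features (toy sizes, random similarities):

import torch

hw, c = 6, 4
flag = torch.tensor([0, 0, 1, 1, 0, 0])                 # 1 = masked location
latter = torch.randn(c, hw)                             # known (encoder) features, flattened

cosine = torch.randn(int(flag.sum()), int((flag == 0).sum()))   # masked x known similarities
_, best = torch.max(cosine, dim=1)

mask_idx = (flag == 1).nonzero(as_tuple=False).squeeze(1)
non_mask_idx = (flag == 0).nonzero(as_tuple=False).squeeze(1)[best]

W = torch.zeros(hw, hw)
W[mask_idx, non_mask_idx] = 1                           # W[i, j] = 1: masked i copies from known j

shifted = latter @ W.t()                                # forward: only masked columns get filled
grad_shifted = torch.randn(c, hw)
grad_to_latter = grad_shifted @ W                       # backward: route the gradient back to its source
print(shifted[:, mask_idx].shape, grad_to_latter.abs().sum(dim=0))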
--------------------------------------------------------------------------------
/models/shift_net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/models/shift_net/__init__.py
--------------------------------------------------------------------------------
/models/shift_net/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 |
5 |
6 | class BaseModel():
7 | def name(self):
8 | return 'BaseModel'
9 |
10 | def initialize(self, opt):
11 | self.opt = opt
12 | self.gpu_ids = opt.gpu_ids
13 | self.isTrain = opt.isTrain
14 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
16 | if opt.resize_or_crop != 'scale_width':
17 | torch.backends.cudnn.benchmark = True
18 | self.loss_names = []
19 | self.model_names = []
20 | self.visual_names = []
21 | self.image_paths = []
22 |
23 | def set_input(self, input):
24 | self.input = input
25 |
26 | def forward(self):
27 | pass
28 |
29 | # used in test time, wrapping `forward` in no_grad() so we don't save
30 | # intermediate steps for backprop
31 | def test(self):
32 | with torch.no_grad():
33 | self.forward()
34 |
35 | # get image paths
36 | def get_image_paths(self):
37 | return self.image_paths
38 |
39 | def optimize_parameters(self):
40 | pass
41 |
42 | # update learning rate (called once every epoch)
43 | def update_learning_rate(self):
44 | for scheduler in self.schedulers:
45 | scheduler.step()
46 | lr = self.optimizers[0].param_groups[0]['lr']
47 | print('learning rate = %.7f' % lr)
48 |
49 | # return visualization images. train.py will display these images, and save the images to a html
50 | def get_current_visuals(self):
51 | visual_ret = OrderedDict()
52 | for name in self.visual_names:
53 | if isinstance(name, str):
54 | visual_ret[name] = getattr(self, name)
55 | return visual_ret
56 |
57 |     # return training losses/errors. train.py will print out these errors as debugging information
58 | def get_current_losses(self):
59 | errors_ret = OrderedDict()
60 | for name in self.loss_names:
61 | if isinstance(name, str):
62 | # float(...) works for both scalar tensor and float number
63 | errors_ret[name] = float(getattr(self, 'loss_' + name))
64 | return errors_ret
65 |
66 | # save models to the disk
67 | def save_networks(self, which_epoch):
68 | for name in self.model_names:
69 | if isinstance(name, str):
70 | save_filename = '%s_net_%s.pth' % (which_epoch, name)
71 | save_path = os.path.join(self.save_dir, save_filename)
72 | net = getattr(self, 'net' + name)
73 |
74 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
75 | torch.save(net.module.cpu().state_dict(), save_path)
76 | net.cuda(self.gpu_ids[0])
77 | else:
78 | torch.save(net.cpu().state_dict(), save_path)
79 |
80 |
81 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
82 | key = keys[i]
83 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
84 | if module.__class__.__name__.startswith('InstanceNorm') and \
85 | (key == 'running_mean' or key == 'running_var'):
86 | if getattr(module, key) is None:
87 | state_dict.pop('.'.join(keys))
88 | else:
89 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
90 |
91 | # load models from the disk
92 | def load_networks(self, which_epoch):
93 | for name in self.model_names:
94 | if isinstance(name, str):
95 | load_filename = '%s_net_%s.pth' % (which_epoch, name)
96 | load_path = os.path.join(self.save_dir, load_filename)
97 | net = getattr(self, 'net' + name)
98 | if isinstance(net, torch.nn.DataParallel):
99 | net = net.module
100 | # if you are using PyTorch newer than 0.4 (e.g., built from
101 | # GitHub source), you can remove str() on self.device
102 | state_dict = torch.load(load_path, map_location=str(self.device))
103 | # patch InstanceNorm checkpoints prior to 0.4
104 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
105 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
106 | net.load_state_dict(state_dict)
107 |
108 | # print network information
109 | def print_networks(self, verbose):
110 | print('---------- Networks initialized -------------')
111 | for name in self.model_names:
112 | if isinstance(name, str):
113 | net = getattr(self, 'net' + name)
114 | num_params = 0
115 | for param in net.parameters():
116 | num_params += param.numel()
117 | if verbose:
118 | print(net)
119 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
120 | print('-----------------------------------------------')
121 |
122 | def set_requires_grad(self, nets, requires_grad=False):
123 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
124 | Parameters:
125 | nets (network list) -- a list of networks
126 | requires_grad (bool) -- whether the networks require gradients or not
127 | """
128 | if not isinstance(nets, list):
129 | nets = [nets]
130 | for net in nets:
131 | if net is not None:
132 | for param in net.parameters():
133 | param.requires_grad = requires_grad
--------------------------------------------------------------------------------
/models/shift_net/shiftnet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | import util.util as util
4 | from models import networks
5 | from models.shift_net.base_model import BaseModel
6 | import time
7 | import torchvision.transforms as transforms
8 | import os
9 | import numpy as np
10 | from PIL import Image
11 |
12 | class ShiftNetModel(BaseModel):
13 | def name(self):
14 | return 'ShiftNetModel'
15 |
16 |
17 | def create_random_mask(self):
18 | if self.opt.mask_type == 'random':
19 | if self.opt.mask_sub_type == 'fractal':
20 | assert 1==2, "It is broken somehow, use another mask_sub_type please"
21 | mask = util.create_walking_mask() # create an initial random mask.
22 |
23 | elif self.opt.mask_sub_type == 'rect':
24 | mask, rand_t, rand_l = util.create_rand_mask(self.opt)
25 | self.rand_t = rand_t
26 | self.rand_l = rand_l
27 | return mask
28 |
29 | elif self.opt.mask_sub_type == 'island':
30 | mask = util.wrapper_gmask(self.opt)
31 | return mask
32 |
33 | def initialize(self, opt):
34 | BaseModel.initialize(self, opt)
35 | self.opt = opt
36 | self.isTrain = opt.isTrain
37 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
38 | self.loss_names = ['G_GAN', 'G_L1', 'D', 'style', 'content', 'tv']
39 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
40 | if self.opt.show_flow:
41 | self.visual_names = ['real_A', 'fake_B', 'real_B', 'flow_srcs']
42 | else:
43 | self.visual_names = ['real_A', 'fake_B', 'real_B']
44 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
45 | if self.isTrain:
46 | self.model_names = ['G', 'D']
47 | else: # during test time, only load Gs
48 | self.model_names = ['G']
49 |
50 |
51 | # batchsize should be 1 for mask_global
52 | self.mask_global = torch.zeros((self.opt.batchSize, 1, \
53 | opt.fineSize, opt.fineSize), dtype=torch.bool)
54 |
55 | # Here we need to set an artificial mask_global(center hole is ok.)
56 | self.mask_global.zero_()
57 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
58 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
59 |
60 | if len(opt.gpu_ids) > 0:
61 | self.mask_global = self.mask_global.to(self.device)
62 |
63 | # load/define networks
64 | # self.ng_innerCos_list is the guidance loss list in netG inner layers.
65 | # self.ng_shift_list is the mask list constructing shift operation.
66 | if opt.add_mask2input:
67 | input_nc = opt.input_nc + 1
68 | else:
69 | input_nc = opt.input_nc
70 |
71 | self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(input_nc, opt.output_nc, opt.ngf,
72 | opt.which_model_netG, opt, self.mask_global, opt.norm, opt.use_spectral_norm_G, opt.init_type, self.gpu_ids, opt.init_gain)
73 |
74 | if self.isTrain:
75 | use_sigmoid = False
76 | if opt.gan_type == 'vanilla':
77 | use_sigmoid = True # only vanilla GAN using BCECriterion
78 | # don't use cGAN
79 | self.netD = networks.define_D(opt.input_nc, opt.ndf,
80 | opt.which_model_netD,
81 | opt.n_layers_D, opt.norm, use_sigmoid, opt.use_spectral_norm_D, opt.init_type, self.gpu_ids, opt.init_gain)
82 |
83 | # add style extractor
84 | self.vgg16_extractor = util.VGG16FeatureExtractor()
85 | if len(opt.gpu_ids) > 0:
86 | self.vgg16_extractor = self.vgg16_extractor.to(self.gpu_ids[0])
87 | self.vgg16_extractor = torch.nn.DataParallel(self.vgg16_extractor, self.gpu_ids)
88 |
89 | if self.isTrain:
90 | self.old_lr = opt.lr
91 | # define loss functions
92 | self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
93 | self.criterionL1 = torch.nn.L1Loss()
94 | self.criterionL1_mask = networks.Discounted_L1(opt).to(self.device) # make weights/buffers transfer to the correct device
95 | # VGG loss
96 | self.criterionL2_style_loss = torch.nn.MSELoss()
97 | self.criterionL2_content_loss = torch.nn.MSELoss()
98 | # TV loss
99 | self.tv_criterion = networks.TVLoss(self.opt.tv_weight)
100 |
101 | # initialize optimizers
102 | self.schedulers = []
103 | self.optimizers = []
104 | if self.opt.gan_type == 'wgan_gp':
105 | opt.beta1 = 0
106 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
107 | lr=opt.lr, betas=(opt.beta1, 0.9))
108 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
109 | lr=opt.lr, betas=(opt.beta1, 0.9))
110 | else:
111 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
112 | lr=opt.lr, betas=(opt.beta1, 0.999))
113 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
114 | lr=opt.lr, betas=(opt.beta1, 0.999))
115 | self.optimizers.append(self.optimizer_G)
116 | self.optimizers.append(self.optimizer_D)
117 | for optimizer in self.optimizers:
118 | self.schedulers.append(networks.get_scheduler(optimizer, opt))
119 |
120 | if not self.isTrain or opt.continue_train:
121 | self.load_networks(opt.which_epoch)
122 |
123 | self.print_networks(opt.verbose)
124 |
125 | def set_input(self, input):
126 | self.image_paths = input['A_paths']
127 | real_A = input['A'].to(self.device)
128 | real_B = input['B'].to(self.device)
129 | # directly load mask offline
130 | self.mask_global = input['M'].to(self.device).byte()
131 | self.mask_global = self.mask_global.narrow(1,0,1).bool()
132 |
133 | # create mask online
134 | if not self.opt.offline_loading_mask:
135 | if self.opt.mask_type == 'center':
136 | self.mask_global.zero_()
137 | self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
138 | int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
139 | self.rand_t, self.rand_l = int(self.opt.fineSize/4) + self.opt.overlap, int(self.opt.fineSize/4) + self.opt.overlap
140 | elif self.opt.mask_type == 'random':
141 | self.mask_global = self.create_random_mask().type_as(self.mask_global).view(1, *self.mask_global.size()[-3:])
142 |                 # Generating random masks online is computation-heavy,
143 |                 # so just generate one random mask for a batch of images.
144 | self.mask_global = self.mask_global.expand(self.opt.batchSize, *self.mask_global.size()[-3:])
145 | else:
146 | raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type)
147 |             # When loading masks offline, we also overwrite 'opt.mask_type' and 'opt.mask_sub_type'
148 |             # so that such settings are not forgotten.
149 | else:
150 | self.opt.mask_type = 'random'
151 | self.opt.mask_sub_type = 'island'
152 |
153 | self.set_latent_mask(self.mask_global)
154 |
155 | real_A.narrow(1,0,1).masked_fill_(self.mask_global, 0.)#2*123.0/255.0 - 1.0
156 | real_A.narrow(1,1,1).masked_fill_(self.mask_global, 0.)#2*104.0/255.0 - 1.0
157 | real_A.narrow(1,2,1).masked_fill_(self.mask_global, 0.)#2*117.0/255.0 - 1.0
158 |
159 | if self.opt.add_mask2input:
160 |             # Add the mask as a 4th channel.
161 |             # Note: in this extra channel, the masked part is filled with 0 and the non-masked part with 1.
162 | real_A = torch.cat((real_A, (~self.mask_global).expand(real_A.size(0), 1, real_A.size(2), real_A.size(3)).type_as(real_A)), dim=1)
163 |
164 | self.real_A = real_A
165 | self.real_B = real_B
166 |
167 |
168 | def set_latent_mask(self, mask_global):
169 | for ng_shift in self.ng_shift_list: # ITERATE OVER THE LIST OF ng_shift_list
170 | ng_shift.set_mask(mask_global)
171 | for ng_innerCos in self.ng_innerCos_list: # ITERATE OVER THE LIST OF ng_innerCos_list:
172 | ng_innerCos.set_mask(mask_global)
173 |
174 | def set_gt_latent(self):
175 | if not self.opt.skip:
176 | if self.opt.add_mask2input:
177 |                 # Add the mask as a 4th channel.
178 |                 # Note: in this extra channel, the masked part is filled with 0 and the non-masked part with 1.
179 | real_B = torch.cat([self.real_B, (~self.mask_global).expand(self.real_B.size(0), 1, self.real_B.size(2), self.real_B.size(3)).type_as(self.real_B)], dim=1)
180 | else:
181 | real_B = self.real_B
182 | self.netG(real_B) # input ground truth
183 |
184 |
185 | def forward(self):
186 | self.set_gt_latent()
187 | self.fake_B = self.netG(self.real_A)
188 |
189 | # Just assume one shift layer.
190 | def set_flow_src(self):
191 | self.flow_srcs = self.ng_shift_list[0].get_flow()
192 | self.flow_srcs = F.interpolate(self.flow_srcs, scale_factor=8, mode='nearest')
193 |         # Just to avoid forgetting to set show_map_false.
194 | self.set_show_map_false()
195 |
196 | # Just assume one shift layer.
197 | def set_show_map_true(self):
198 | self.ng_shift_list[0].set_flow_true()
199 |
200 | def set_show_map_false(self):
201 | self.ng_shift_list[0].set_flow_false()
202 |
203 | def get_image_paths(self):
204 | return self.image_paths
205 |
206 | def backward_D(self):
207 | fake_B = self.fake_B
208 | # Real
209 | real_B = self.real_B # GroundTruth
210 |
211 |         # It has been verified that, for a square mask, letting D discriminate the masked patch improves the results.
212 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
213 | # Using the cropped fake_B as the input of D.
214 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
215 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
216 |
217 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
218 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
219 |
220 | self.pred_fake = self.netD(fake_B.detach())
221 | self.pred_real = self.netD(real_B)
222 |
223 | if self.opt.gan_type == 'wgan_gp':
224 | gradient_penalty, _ = util.cal_gradient_penalty(self.netD, real_B, fake_B.detach(), self.device, constant=1, lambda_gp=self.opt.gp_lambda)
225 | self.loss_D_fake = torch.mean(self.pred_fake)
226 | self.loss_D_real = -torch.mean(self.pred_real)
227 |
228 | self.loss_D = self.loss_D_fake + self.loss_D_real + gradient_penalty
229 | else:
230 | if self.opt.gan_type in ['vanilla', 'lsgan']:
231 | self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
232 | self.loss_D_real = self.criterionGAN (self.pred_real, True)
233 |
234 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
235 |
236 | elif self.opt.gan_type == 're_s_gan':
237 | self.loss_D = self.criterionGAN(self.pred_real - self.pred_fake, True)
238 |
239 | self.loss_D.backward()
240 |
241 |
242 | def backward_G(self):
243 | # First, G(A) should fake the discriminator
244 | fake_B = self.fake_B
245 |         # It has been verified that, for a square mask, letting D discriminate the masked patch improves the results.
246 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
247 | # Using the cropped fake_B as the input of D.
248 | fake_B = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
249 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
250 | real_B = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
251 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
252 | else:
253 | real_B = self.real_B
254 |
255 | pred_fake = self.netD(fake_B)
256 |
257 |
258 | if self.opt.gan_type == 'wgan_gp':
259 | self.loss_G_GAN = -torch.mean(pred_fake)
260 | else:
261 | if self.opt.gan_type in ['vanilla', 'lsgan']:
262 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) * self.opt.gan_weight
263 |
264 | elif self.opt.gan_type == 're_s_gan':
265 | pred_real = self.netD (real_B)
266 | self.loss_G_GAN = self.criterionGAN (pred_fake - pred_real, True) * self.opt.gan_weight
267 |
268 | elif self.opt.gan_type == 're_avg_gan':
269 | self.pred_real = self.netD(real_B)
270 | self.loss_G_GAN = (self.criterionGAN (self.pred_real - torch.mean(self.pred_fake), False) \
271 | + self.criterionGAN (self.pred_fake - torch.mean(self.pred_real), True)) / 2.
272 | self.loss_G_GAN *= self.opt.gan_weight
273 |
274 |
275 |         # If we change the mask to 'center with random position', then we can replace loss_G_L1_m with 'Discounted L1'.
276 | self.loss_G_L1, self.loss_G_L1_m = 0, 0
277 | self.loss_G_L1 += self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
278 |         # calculate the mask reconstruction loss
279 |         # When mask_type is 'center' or mask_sub_type is 'rect', we can add an additional mask-region reconstruction loss (traditional L1).
280 |         # Only when 'discounting' is 1 does the mask-region reconstruction loss change to 'discounting L1' instead of normal L1.
281 | if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
282 | mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
283 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
284 | mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
285 | self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
286 | # Using Discounting L1 loss
287 | self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight_G
288 |
289 | self.loss_G = self.loss_G_L1 + self.loss_G_L1_m + self.loss_G_GAN
290 |
291 | # Then, add TV loss
292 | self.loss_tv = self.tv_criterion(self.fake_B*self.mask_global.float())
293 |
294 | # Finally, add style loss
295 | vgg_ft_fakeB = self.vgg16_extractor(fake_B)
296 | vgg_ft_realB = self.vgg16_extractor(real_B)
297 | self.loss_style = 0
298 | self.loss_content = 0
299 |
300 | for i in range(3):
301 | self.loss_style += self.criterionL2_style_loss(util.gram_matrix(vgg_ft_fakeB[i]), util.gram_matrix(vgg_ft_realB[i]))
302 | self.loss_content += self.criterionL2_content_loss(vgg_ft_fakeB[i], vgg_ft_realB[i])
303 |
304 | self.loss_style *= self.opt.style_weight
305 | self.loss_content *= self.opt.content_weight
306 |
307 | self.loss_G += (self.loss_style + self.loss_content + self.loss_tv)
308 |
309 | self.loss_G.backward()
310 |
311 | def optimize_parameters(self):
312 | self.forward()
313 | # update D
314 | self.set_requires_grad(self.netD, True)
315 | self.optimizer_D.zero_grad()
316 | self.backward_D()
317 | self.optimizer_D.step()
318 |
319 | # update G
320 | self.set_requires_grad(self.netD, False)
321 | self.optimizer_G.zero_grad()
322 | self.backward_G()
323 | self.optimizer_G.step()
324 |
325 |
326 |
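For reference, the center-mask geometry used in `initialize` and `set_input` works out as follows with the defaults from options/base_options.py (fineSize 256, overlap 4): the hole spans rows and columns 68..187, and the crop fed to the discriminator has the same 120-pixel side.

import torch

fineSize, overlap = 256, 4

top = fineSize // 4 + overlap                      # 68
bottom = fineSize // 2 + fineSize // 4 - overlap   # 188 (exclusive)
side = fineSize // 2 - 2 * overlap                 # 120, also the crop size used for netD

mask_global = torch.zeros(1, 1, fineSize, fineSize, dtype=torch.bool)
mask_global[:, :, top:bottom, top:bottom] = 1
print(top, bottom, side, int(mask_global.sum()))   # 68 188 120 14400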
--------------------------------------------------------------------------------
/notebooks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/notebooks/__init__.py
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self.initialized = False
9 |
10 | def initialize(self, parser):
11 | parser.add_argument('--dataroot', default='./datasets/Paris/train', help='path to training/testing images')
12 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
13 | parser.add_argument('--loadSize', type=int, default=350, help='scale images to this size')
14 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
15 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
16 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
17 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
18 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
19 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD, [basic|densenet]')
20 | parser.add_argument('--which_model_netG', type=str, default='unet_shift_triple', help='selects model to use for netG [unet_256| unet_shift_triple| \
21 | res_unet_shift_triple|patch_soft_unet_shift_triple| \
22 | res_patch_soft_unet_shift_triple| face_unet_shift_triple]')
23 | parser.add_argument('--model', type=str, default='shiftnet', \
24 | help='chooses which model to use. [shiftnet|res_shiftnet|patch_soft_shiftnet|res_patch_soft_shiftnet|test]')
25 | parser.add_argument('--triple_weight', type=float, default=1, help='The weight on the gradient of skip connections from the gradient of shifted')
26 | parser.add_argument('--name', type=str, default='exp', help='name of the experiment. It decides where to store samples and models')
27 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
28 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU training/testing')
29 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [aligned | aligned_resized | single]')
30 | parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
31 | parser.add_argument('--checkpoints_dir', type=str, default='./log', help='models are saved here')
32 | parser.add_argument('--norm', type=str, default='instance', help='[instance|batch|switchable] normalization')
33 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
34 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
35 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
36 |
37 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
38 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width]')
39 |         parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
40 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
41 | parser.add_argument('--show_flow', type=int, default=0, help='show the flow information. WARNING: set display_freq a large number as it is quite slow when showing flow')
42 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
43 | ## model specific
44 | parser.add_argument('--mask_type', type=str, default='center',
45 | help='the type of mask you want to apply, \'center\' or \'random\'')
46 | parser.add_argument('--mask_sub_type', type=str, default='island',
47 |                             help='the type of mask you want to apply, \'rect\', \'fractal\' or \'island\'')
48 | parser.add_argument('--lambda_A', type=int, default=100, help='weight on L1 term in objective')
49 | parser.add_argument('--stride', type=int, default=1, help='should be dense, 1 is a good option.')
50 | parser.add_argument('--shift_sz', type=int, default=1, help='shift_sz>1 only for \'soft_shift_patch\'.')
51 | parser.add_argument('--mask_thred', type=int, default=1, help='number to decide whether a patch is masked')
52 | parser.add_argument('--overlap', type=int, default=4, help='the overlap for center mask')
53 | parser.add_argument('--bottleneck', type=int, default=512, help='number of neurons in the fully-connected bottleneck')
54 | parser.add_argument('--gp_lambda', type=float, default=10.0, help='gradient penalty coefficient')
55 | parser.add_argument('--constrain', type=str, default='MSE', help='guidance loss type')
56 | parser.add_argument('--strength', type=float, default=1, help='the weight of guidance loss')
57 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
58 | parser.add_argument('--skip', type=int, default=0, help='whether to skip the guidance loss; skipping it makes training considerably faster but degrades performance')
59 | parser.add_argument('--fuse', type=int, default=0, help='fusing may encourage shifting of larger patches when using \'patch_soft_shift\'')
60 | parser.add_argument('--gan_type', type=str, default='vanilla', help='wgan_gp, '
61 | 'lsgan, '
62 | 'vanilla, '
63 | 're_s_gan (Relativistic Standard GAN), ')
64 | parser.add_argument('--gan_weight', type=float, default=0.2, help='the weight of gan loss')
65 | # New added
66 | parser.add_argument('--style_weight', type=float, default=10.0, help='the weight of style loss')
67 | parser.add_argument('--content_weight', type=float, default=1.0, help='the weight of content loss')
68 | parser.add_argument('--tv_weight', type=float, default=0.0, help='the weight of tv loss, you can set a small value, such as 0.1/0.01')
69 | parser.add_argument('--offline_loading_mask', type=int, default=0, help='whether to load mask offline randomly')
70 | parser.add_argument('--mask_weight_G', type=float, default=400.0, help='the weight of the mask part in the output of G; you can try different values')
71 | parser.add_argument('--discounting', type=int, default=1, help='loss type for the mask part: discounted L1 loss (1) or plain L1 (0)')
72 | parser.add_argument('--use_spectral_norm_D', type=int, default=1, help='whether to add spectral norm to D, it helps improve results')
73 | parser.add_argument('--use_spectral_norm_G', type=int, default=0, help='whether to add spectral norm in G; adding SN to G tends to hurt results')
74 | parser.add_argument('--only_lastest', type=int, default=0,
75 | help='If True, it will save only the latest weights')
76 | parser.add_argument('--add_mask2input', type=int, default=1,
77 | help='If True, it will add the mask as a fourth channel of the input')
78 |
79 | self.initialized = True
80 | return parser
81 |
82 | def gather_options(self, options=None):
83 | # initialize parser with basic options
84 | if not self.initialized:
85 | parser = argparse.ArgumentParser(
86 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87 | parser = self.initialize(parser)
88 |
89 |
90 | self.parser = parser
91 | if options is None:
92 | return parser.parse_args()
93 | else:
94 | return parser.parse_args(options)
95 |
96 | def print_options(self, opt):
97 | message = ''
98 | message += '----------------- Options ---------------\n'
99 | for k, v in sorted(vars(opt).items()):
100 | comment = ''
101 | default = self.parser.get_default(k)
102 | if v != default:
103 | comment = '\t[default: %s]' % str(default)
104 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
105 | message += '----------------- End -------------------'
106 | print(message)
107 |
108 | # save to the disk
109 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
110 | util.mkdirs(expr_dir)
111 | file_name = os.path.join(expr_dir, 'opt.txt')
112 | with open(file_name, 'wt') as opt_file:
113 | opt_file.write(message)
114 | opt_file.write('\n')
115 |
116 | def parse(self, options=None):
117 |
118 | opt = self.gather_options(options=options)
119 | opt.isTrain = self.isTrain # train or test
120 |
121 | # process opt.suffix
122 | if opt.suffix:
123 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
124 | opt.name = opt.name + suffix
125 |
126 | self.print_options(opt)
127 |
128 | # set gpu ids
129 | os.environ["CUDA_VISIBLE_DEVICES"]=opt.gpu_ids
130 | str_ids = opt.gpu_ids.split(',')
131 | opt.gpu_ids = []
132 | for str_id in str_ids:
133 | id = int(str_id)
134 | if id >= 0:
135 | opt.gpu_ids.append(id)
136 | # re-order gpu ids: CUDA_VISIBLE_DEVICES (set above) re-indexes the visible devices, so remap the ids to 0..n-1
137 | opt.gpu_ids = [i.item() for i in torch.arange(len(opt.gpu_ids))]
138 | if len(opt.gpu_ids) > 0:
139 | torch.cuda.set_device(opt.gpu_ids[0])
140 |
141 | self.opt = opt
142 | return self.opt
143 |
--------------------------------------------------------------------------------
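The option classes above assemble a single argparse parser, and gather_options() accepts an explicit list of strings, so the same flags can be driven programmatically instead of from the command line. A minimal sketch (not part of the repository; the flag values are illustrative, and TrainOptions is the subclass defined in options/train_options.py):

# Minimal sketch: parse options programmatically instead of via the command line.
# '--gpu_ids -1' keeps everything on the CPU; other values are illustrative.
from options.train_options import TrainOptions

opt = TrainOptions().parse(['--name', 'exp_demo',
                            '--mask_type', 'random',
                            '--mask_sub_type', 'island',
                            '--gpu_ids', '-1'])
print(opt.isTrain, opt.gpu_ids)   # True, []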
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11 | parser.add_argument('--which_epoch', type=str, default='20', help='which epoch to load? set to latest to use latest cached model')
12 | parser.add_argument('--how_many', type=int, default=1000, help='how many test images to run')
13 | parser.add_argument('--testing_mask_folder', type=str, default='masks/testing_masks', help='prepared masks for testing')
14 | self.isTrain = False
15 |
16 | return parser
17 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | # Options specific to training.
4 |
5 | class TrainOptions(BaseOptions):
6 | def initialize(self, parser):
7 | parser = BaseOptions.initialize(self, parser)
8 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
9 | parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
10 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
11 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
12 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
13 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
14 | parser.add_argument('--print_freq', type=int, default=50, help='frequency of showing training results on console')
15 | parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
16 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
17 | parser.add_argument('--save_epoch_freq', type=int, default=2, help='frequency of saving checkpoints at the end of epochs')
18 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
19 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
20 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
21 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
22 | parser.add_argument('--niter', type=int, default=30, help='# of iter at starting learning rate')
23 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
24 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
25 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
26 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
27 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
28 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
29 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
30 | parser.add_argument('--training_mask_folder', type=str, default='masks/training_masks', help='prepared masks for training')
31 | self.isTrain = True
32 |
33 | return parser
34 |
--------------------------------------------------------------------------------
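--niter and --niter_decay together describe the schedule: the learning rate stays at --lr for --niter epochs and then decays linearly to zero over --niter_decay epochs. The sketch below illustrates that rule; it is an assumption modelled on the usual 'lambda' policy of pix2pix-style codebases, not copied from this repository's networks.py.

# Assumed linear-decay rule for lr_policy='lambda' (illustrative only).
def lr_at_epoch(epoch, lr=0.0002, niter=20, niter_decay=10, epoch_count=1):
    # Constant for the first `niter` epochs, then linear decay to zero
    # over the next `niter_decay` epochs.
    decay_factor = 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
    return lr * max(0.0, decay_factor)

for e in (1, 20, 25, 30):
    print(e, lr_at_epoch(e))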
/shift_layer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/shift_layer.png
--------------------------------------------------------------------------------
/show_map.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | import torch
3 | import util.util as util
4 | from util.NonparametricShift import Modified_NonparametricShift
5 | from torch.nn import functional as F
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 | bz = 1
10 | c = 2 # at least 2
11 | w = 4
12 | h = 4
13 |
14 | feature_size = [bz, c, w, h]
15 |
16 | former = torch.rand(c*h*w).mul_(50).reshape(c, h, w).int().float()
17 | latter = torch.rand(c*h*w).mul_(50).reshape(c, h, w).int().float()
18 |
19 |
20 | flag = torch.zeros(h,w).byte()
21 | flag[h//4:h//2+1, h//4:h//2+1] = 1
22 | flag = flag.view(h*w)
23 |
24 | ind_lst = torch.FloatTensor(h*w, h*w).zero_()
25 | shift_offsets = []
26 |
27 | Nonparm = Modified_NonparametricShift()
28 | cosine, latter_windows, i_2, i_3, i_1, i_4 = Nonparm.cosine_similarity(former, latter, 1, 1, flag)
29 | ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
30 |
31 | _, indexes = torch.max(cosine, dim=1)
32 |
33 |
34 | # SET TRANSITION MATRIX
35 | mask_indexes = (flag == 1).nonzero()
36 | non_mask_indexes = (flag == 0).nonzero()[indexes]
37 | ind_lst[mask_indexes, non_mask_indexes] = 1
38 |
39 |
40 | # GET FINAL SHIFT FEATURE
41 | shift_masked_all = Nonparm._paste(latter_windows, ind_lst, i_2, i_3, i_1, i_4)
42 |
43 | print('flag')
44 | print(flag.reshape(h,w))
45 | print('ind_lst')
46 | print(ind_lst)
47 | print('out')
48 | print(shift_masked_all)
49 |
50 | # get shift offsets: (row, col) of the source pixel chosen for each masked position
51 | shift_offset = torch.stack([non_mask_indexes.squeeze() // w, torch.fmod(non_mask_indexes.squeeze(), w)], dim=-1)
52 |
53 |
54 | shift_offsets.append(shift_offset)
55 | shift_offsets = torch.cat(shift_offsets, dim=0).float()
56 | print('shift_offset')
57 | print(shift_offset)
58 | print(shift_offset.size()) # (num_masked_points, 2)
59 |
60 | shift_offsets_cl = shift_offsets.clone()
61 |
62 |
63 | #visualize which pixels are attended
64 | print(flag.size()) # (h*w,)
65 |
66 |
67 | # global and N*C*H*W
68 | # put shift_offsets_cl back to the global map.
69 | shift_offsets_map = torch.zeros(bz, h, w, 2).float()
70 | print(shift_offsets_map.size()) # bz*h*w*2
71 |
72 | # mask_indexes are the positions of the points inside the mask region.
73 | # shift_offsets holds, for each masked point, the position of the outside point that will be shifted into it.
74 | shift_offsets_map[:, mask_indexes.squeeze() // w, mask_indexes.squeeze() % w, :] = shift_offsets_cl.unsqueeze(0)
75 | # At this point shift_offsets_map is complete: only positions inside the mask hold values, and each value (2 channels) stores the coordinates of the outside point that replaces it.
76 | print('global shift_offsets_map')
77 | print(shift_offsets_map)
78 | print(shift_offsets_map.size())
79 | print(shift_offsets_map.type())
80 |
81 | flow2 = util.highlight_flow(shift_offsets_map, flag.unsqueeze(0))
82 | print('flow2 size')
83 | print(flow2.size())
84 |
85 | # upflow = F.interpolate(flow, scale_factor=4, mode='nearest')
86 | upflow2 = F.interpolate(flow2, scale_factor=1, mode='nearest')
87 |
88 | print('**After upsample flow2 size**')
89 | print(upflow2.size())
90 |
91 | # upflow = upflow.squeeze().permute(1,2,0)
92 | upflow2 = upflow2.squeeze().permute(1,2,0)
93 | print(upflow2.size())
94 |
95 | # print('flow 1')
96 | # print(upflow)
97 | # print(upflow.size())
98 |
99 | # print('flow 2')
100 | # print(upflow2)
101 | # print(upflow2.size())
102 | plt.imshow(upflow2/255.)
103 | # # axs[0].imshow(upflow)
104 | # axs[1].imshow(upflow2)
105 |
106 | plt.show()
107 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from options.test_options import TestOptions
4 | from data.data_loader import CreateDataLoader
5 | from models import create_model
6 | from util.visualizer import save_images
7 | from util import html
8 |
9 | if __name__ == "__main__":
10 | opt = TestOptions().parse()
11 | opt.nThreads = 1 # test code only supports nThreads = 1
12 | opt.batchSize = 1 # test code only supports batchSize = 1
13 | opt.serial_batches = True # no shuffle
14 | opt.no_flip = True # no flip
15 | opt.display_id = -1 # no visdom display
16 | opt.loadSize = opt.fineSize # Do not scale!
17 |
18 | data_loader = CreateDataLoader(opt)
19 | dataset = data_loader.load_data()
20 | model = create_model(opt)
21 |
22 | # create website
23 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
24 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
25 | # test
26 | for i, data in enumerate(dataset):
27 | if i >= opt.how_many:
28 | break
29 | t1 = time.time()
30 | model.set_input(data)
31 | model.test()
32 | t2 = time.time()
33 | print(t2-t1)
34 | visuals = model.get_current_visuals()
35 | img_path = model.get_image_paths()
36 | print('process image... %s' % img_path)
37 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
38 | webpage.save()
39 |
--------------------------------------------------------------------------------
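test.py writes an HTML summary plus the individual images under opt.results_dir/opt.name/<phase>_<which_epoch>/. A small sketch for inspecting that output, assuming the default option values ('exp', 'test', '20') and that the run has already finished:

# Sketch: inspect the output folder produced by test.py / save_images().
import os

web_dir = os.path.join('./results', 'exp', 'test_20')
img_dir = os.path.join(web_dir, 'images')          # created by html.HTML
pngs = sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))
print('index page  :', os.path.join(web_dir, 'index.html'))
print('saved images:', len(pngs))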
/test_acc_shift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import util.util as util
3 | from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | bz = 2
9 | c = 3 # at least 2
10 | w = 16
11 | h = 16
12 |
13 | feature_size = [bz, c, w, h]
14 |
15 | former = torch.rand(bz*c*h*w).mul_(50).reshape(bz, c, h, w).int().float()
16 | latter = torch.rand(bz*c*h*w).mul_(50).reshape(bz, c, h, w).int().float()
17 |
18 |
19 | flag = torch.zeros(bz, h, w).byte()
20 | flag[:, h//4:h//2+1, h//4:h//2+1] = 1
21 | flag = flag.view(bz, h*w)
22 |
23 | ind_lst = torch.FloatTensor(bz, h*w, h*w).zero_()
24 | shift_offsets = []
25 |
26 | #Nonparm = Modified_NonparametricShift()
27 | bNonparm = Batch_NonShift()
28 | cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former.clone(), latter.clone(), 1, 1, flag)
29 | print(cosine.size())
30 | print(latter_windows.size())
31 | ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
32 |
33 | _, indexes = torch.max(cosine, dim=2)
34 | print('indexes dim')
35 | print(indexes.size())
36 |
37 |
38 | # SET TRANSITION MATRIX
39 | mask_indexes = (flag == 1).nonzero()
40 | mask_indexes = mask_indexes[:,1] # remove indexes that indicates the batch dim
41 | mask_indexes = mask_indexes.view(bz, -1)
42 |
43 | # Also remove indexes of batch
44 | tmp = (flag==0).nonzero()[:,1]
45 | tmp = tmp.view(bz, -1)
46 | print('tmp size')
47 | print(tmp.size())
48 |
49 | idx_tmp = indexes + torch.arange(indexes.size(0)).view(-1,1) * tmp.size(1)
50 | non_mask_indexes = tmp.view(-1)[idx_tmp]
51 |
52 | # Original method
53 | non_mask_indexes_2 = []
54 | for i in range(bz):
55 | non_mask_indexes_tmp = tmp[i][indexes[i]]
56 | non_mask_indexes_2.append(non_mask_indexes_tmp)
57 |
58 | non_mask_indexes_2 = torch.stack(non_mask_indexes_2, dim=0)
59 |
60 | print('These two methods should agree; the difference below should be 0:')
61 | print(torch.sum(non_mask_indexes-non_mask_indexes_2))
62 |
63 | ind_lst2 = ind_lst.clone()
64 | for i in range(bz):
65 | ind_lst[i][mask_indexes[i], non_mask_indexes[i]] = 1
66 |
67 | print(ind_lst.sum())
68 | print(ind_lst)
69 |
70 | for i in range(bz):
71 | for mi, nmi in zip(mask_indexes[i], non_mask_indexes[i]):
72 | print('Pixel %d of tensor %d will be shifted to position %d' % (nmi, i, mi))
73 | print('~~~')
74 |
75 | # GET FINAL SHIFT FEATURE
76 | shift_masked_all = bNonparm._paste(latter_windows, ind_lst, i_2, i_3, i_1)
77 | print(shift_masked_all.size())
78 |
79 | assert 1 == 2  # intentional early stop; the code below is experimental visualization
80 | # print('flag')
81 | # print(flag.reshape(h,w))
82 | # print('ind_lst')
83 | # print(ind_lst)
84 | # print('out')
85 | # print(shift_masked_all)
86 |
87 | # get shift offset ()
88 | shift_offset = torch.stack([non_mask_indexes.squeeze() // w, torch.fmod(non_mask_indexes.squeeze(), w)], dim=-1)
89 | print('shift_offset')
90 | print(shift_offset)
91 | print(shift_offset.size())
92 |
93 | shift_offsets.append(shift_offset)
94 | shift_offsets = torch.cat(shift_offsets, dim=0).float()
95 | print(shift_offsets.size())
96 | print(shift_offsets)
97 |
98 | shift_offsets_cl = shift_offsets.clone()
99 |
100 | lt = (flag==1).nonzero()[0]
101 | rb = (flag==1).nonzero()[-1]
102 |
103 | mask_h = rb//w+1 - lt//w
104 | mask_w = rb%w+1 - lt%w
105 |
106 | shift_offsets = shift_offsets.view([bz] + [2] + [mask_h, mask_w]) # So only appropriate for square mask.
107 | print(shift_offsets.size())
108 | print(shift_offsets)
109 |
110 | h_add = torch.arange(0, float(h)).view([1, 1, h, 1]).float()
111 | h_add = h_add.expand(bz, 1, h, w)
112 | w_add = torch.arange(0, float(w)).view([1, 1, 1, w]).float()
113 | w_add = w_add.expand(bz, 1, h, w)
114 |
115 | com_map = torch.cat([h_add, w_add], dim=1)
116 | print('com_map')
117 | print(com_map)
118 |
119 | com_map_crop = com_map[:, :, lt//w:rb//w+1, lt%w:rb%w+1]
120 | print('com_map crop')
121 | print(com_map_crop)
122 |
123 | shift_offsets = shift_offsets - com_map_crop
124 | print('final shift_offsets')
125 | print(shift_offsets)
126 |
127 |
128 | # to flow image
129 | flow = torch.from_numpy(util.flow_to_image(shift_offsets.permute(0,2,3,1).cpu().data.numpy()))
130 | flow = flow.permute(0,3,1,2)
131 |
132 | #visualize which pixels are attended
133 | print(flag.size())
134 | print(shift_offsets.size())
135 |
136 | # global and N*C*H*W
137 | # put shift_offsets_cl back to the global map.
138 | shift_offsets_map = flag.clone().view(-1)
139 | shift_offsets_map[indexes] = shift_offsets_cl.view(-1)
140 | print(shift_offsets_map)
141 | assert 1 == 2  # intentional early stop; the remaining visualization code is not executed
142 | flow2 = torch.from_numpy(util.highlight_flow((shift_offsets_cl).numpy()))
143 |
144 | upflow = F.interpolate(flow, scale_factor=4, mode='nearest')
145 | upflow2 = F.interpolate(flow2, scale_factor=4, mode='nearest')
146 |
147 |
148 | upflow = upflow.squeeze().permute(1,2,0)
149 | upflow2 = upflow2.squeeze().permute(1,2,0)
150 |
151 | print('flow 1')
152 | print(upflow)
153 | print(upflow.size())
154 |
155 | print('flow 2')
156 | print(upflow2)
157 | print(upflow2.size())
158 |
159 | fig, axs = plt.subplots(ncols=2)
160 | axs[0].imshow(upflow)
161 | axs[1].imshow(upflow2)
162 |
163 | plt.show()
164 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import CreateDataLoader
4 | from models import create_model
5 | from util.visualizer import Visualizer
6 |
7 | if __name__ == "__main__":
8 | opt = TrainOptions().parse()
9 | data_loader = CreateDataLoader(opt)
10 | dataset = data_loader.load_data()
11 | dataset_size = len(data_loader)
12 | print('#training images = %d' % dataset_size)
13 |
14 | model = create_model(opt)
15 | visualizer = Visualizer(opt)
16 |
17 | total_steps = 0
18 |
19 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
20 | epoch_start_time = time.time()
21 | iter_data_time = time.time()
22 | epoch_iter = 0
23 |
24 | for i, data in enumerate(dataset):
25 | iter_start_time = time.time()
26 | if total_steps % opt.print_freq == 0:
27 | t_data = iter_start_time - iter_data_time
28 | visualizer.reset()
29 | total_steps += opt.batchSize
30 | epoch_iter += opt.batchSize
31 |
32 | model.set_input(data) # it not only sets the input data with mask, but also sets the latent mask.
33 |
34 | # Additionally, this should be set before 'optimize_parameters()'.
35 | if total_steps % opt.display_freq == 0:
36 | if opt.show_flow:
37 | model.set_show_map_true()
38 |
39 | model.optimize_parameters()
40 |
41 | if total_steps % opt.display_freq == 0:
42 | save_result = total_steps % opt.update_html_freq == 0
43 | if opt.show_flow:
44 | model.set_flow_src()
45 | model.set_show_map_false()
46 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
47 |
48 | if total_steps % opt.print_freq == 0:
49 | losses = model.get_current_losses()
50 | t = (time.time() - iter_start_time) / opt.batchSize
51 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
52 | if opt.display_id > 0:
53 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
54 |
55 | if total_steps % opt.save_latest_freq == 0:
56 | print('saving the latest model (epoch %d, total_steps %d)' %
57 | (epoch, total_steps))
58 | model.save_networks('latest')
59 |
60 | iter_data_time = time.time()
61 | if epoch % opt.save_epoch_freq == 0:
62 | print('saving the model at the end of epoch %d, iters %d' %
63 | (epoch, total_steps))
64 | model.save_networks('latest')
65 | if not opt.only_lastest:
66 | model.save_networks(epoch)
67 |
68 | print('End of epoch %d / %d \t Time Taken: %d sec' %
69 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
70 | model.update_learning_rate()
71 |
--------------------------------------------------------------------------------
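The loop above refreshes the 'latest' checkpoint every --save_latest_freq steps and, at the end of every --save_epoch_freq-th epoch, also keeps a numbered copy unless --only_lastest is set. A tiny sketch of which epochs get numbered checkpoints under the default options:

# Sketch of the epoch-level checkpoint cadence implied by train.py
# (default values: niter=30, niter_decay=0, save_epoch_freq=2, only_lastest=0).
niter, niter_decay, save_epoch_freq, only_lastest = 30, 0, 2, 0
numbered = [e for e in range(1, niter + niter_decay + 1)
            if e % save_epoch_freq == 0 and not only_lastest]
print(numbered)   # [2, 4, ..., 30]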
/util/NonparametricShift.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from time import time
7 |
8 | # These three functions only work when patch_size = 1x1
9 | class Modified_NonparametricShift(object):
10 |
11 | def _extract_patches_from_flag(self, img, patch_size, stride, flag, value):
12 | input_windows = self._unfold(img, patch_size, stride)
13 |
14 | input_windows = self._filter(input_windows, flag, value)
15 |
16 | return self._norm(input_windows)
17 |
18 | # former: content, to be replaced.
19 | # latter: style, source pixels.
20 | def cosine_similarity(self, former, latter, patch_size, stride, flag, with_former=False):
21 | former_windows = self._unfold(former, patch_size, stride)
22 | former = self._filter(former_windows, flag, 1)
23 |
24 | latter_windows, i_2, i_3, i_1 = self._unfold(latter, patch_size, stride, with_indexes=True)
25 | latter = self._filter(latter_windows, flag, 0)
26 |
27 | num = torch.einsum('ik,jk->ij', [former, latter])
28 | norm_latter = torch.einsum("ij,ij->i", [latter, latter])
29 | norm_former = torch.einsum("ij,ij->i", [former, former])
30 | den = torch.sqrt(torch.einsum('i,j->ij', [norm_former, norm_latter]))
31 | if not with_former:
32 | return num / den, latter_windows, i_2, i_3, i_1
33 | else:
34 | return num / den, latter_windows, former_windows, i_2, i_3, i_1
35 |
36 |
37 | def _paste(self, input_windows, transition_matrix, i_2, i_3, i_1):
38 | ## MAP SOURCE FEATURES TO THEIR NEW (SHIFTED) POSITIONS
39 | input_windows = torch.mm(transition_matrix, input_windows)
40 |
41 | ## RESIZE TO CORRECT CONV FEATURES FORMAT
42 | input_windows = input_windows.view(i_2, i_3, i_1)
43 | input_windows = input_windows.permute(2, 0, 1).unsqueeze(0)
44 | return input_windows
45 |
46 | def _unfold(self, img, patch_size, stride, with_indexes=False):
47 | n_dim = 3
48 | assert img.dim() == n_dim, 'image must be of dimension 3.'
49 |
50 | kH, kW = patch_size, patch_size
51 | dH, dW = stride, stride
52 | input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)
53 |
54 | i_1, i_2, i_3, i_4, i_5 = input_windows.size()
55 |
56 | if with_indexes:
57 | input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1)
58 | return input_windows, i_2, i_3, i_1
59 | else:
60 | input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1, i_4, i_5)
61 | return input_windows
62 |
63 | def _filter(self, input_windows, flag, value):
64 | ## EXTRACT MASK OR NOT DEPENDING ON VALUE
65 | input_window = input_windows[flag == value]
66 | return input_window.view(input_window.size(0), -1)
67 |
68 |
69 | def _norm(self, input_window):
70 | # This norm is incorrect.
71 | #return torch.norm(input_window, dim=1, keepdim=True)
72 | for i in range(input_window.size(0)):
73 | input_window[i] = input_window[i]*(1/(input_window[i].norm(2)+1e-8))
74 |
75 | return input_window
76 |
77 | class Batch_NonShift(object):
78 |
79 | def _extract_patches_from_flag(self, img, patch_size, stride, flag, value):
80 | input_windows = self._unfold(img, patch_size, stride)
81 |
82 | input_windows = self._filter(input_windows, flag, value)
83 |
84 | return self._norm(input_windows)
85 |
86 | # former: content, to be replaced.
87 | # latter: style, source pixels.
88 | def cosine_similarity(self, former, latter, patch_size, stride, flag, with_former=False):
89 | former_windows = self._unfold(former, patch_size, stride)
90 | former = self._filter(former_windows, flag, 1)
91 |
92 | latter_windows, i_2, i_3, i_1 = self._unfold(latter, patch_size, stride, with_indexes=True)
93 | latter = self._filter(latter_windows, flag, 0)
94 |
95 | num = torch.einsum('bik,bjk->bij', [former, latter])
96 | norm_latter = torch.einsum("bij,bij->bi", [latter, latter])
97 | norm_former = torch.einsum("bij,bij->bi", [former, former])
98 | den = torch.sqrt(torch.einsum('bi,bj->bij', [norm_former, norm_latter]))
99 | if not with_former:
100 | return num / den, latter_windows, i_2, i_3, i_1
101 | else:
102 | return num / den, latter_windows, former_windows, i_2, i_3, i_1
103 |
104 |
105 | # delete i_4, as i_4 is 1
106 | def _paste(self, input_windows, transition_matrix, i_2, i_3, i_1):
107 | ## MAP SOURCE FEATURES TO THEIR NEW (SHIFTED) POSITIONS
108 | bz = input_windows.size(0)
109 | input_windows = torch.bmm(transition_matrix, input_windows)
110 |
111 | ## RESIZE TO CORRECT CONV FEATURES FORMAT
112 | input_windows = input_windows.view(bz, i_2, i_3, i_1)
113 | input_windows = input_windows.permute(0, 3, 1, 2)
114 | return input_windows
115 |
116 | def _unfold(self, img, patch_size, stride, with_indexes=False):
117 | n_dim = 4
118 | assert img.dim() == n_dim, 'image must be of dimension 4.'
119 |
120 | kH, kW = patch_size, patch_size
121 | dH, dW = stride, stride
122 | input_windows = img.unfold(2, kH, dH).unfold(3, kW, dW)
123 |
124 | i_0, i_1, i_2, i_3, i_4, i_5 = input_windows.size()
125 |
126 | if with_indexes:
127 | input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1)
128 | return input_windows, i_2, i_3, i_1
129 | else:
130 | input_windows = input_windows.permute(0, 2, 3, 1, 4, 5).contiguous().view(i_0, i_2 * i_3, i_1, i_4, i_5)
131 | return input_windows
132 |
133 | def _filter(self, input_windows, flag, value):
134 | ## EXTRACT MASK OR NOT DEPENDING ON VALUE
135 | assert flag.dim() == 2, "flag should be batch version"
136 | input_window = input_windows[flag == value]
137 | bz = flag.size(0)
138 | return input_window.view(bz, input_window.size(0)//bz, -1)
139 |
140 |
141 | # Deprecated code
142 | class NonparametricShift(object):
143 | def buildAutoencoder(self, target_img, normalize, interpolate, nonmask_point_idx, patch_size=1, stride=1):
144 | nDim = 3
145 | assert target_img.dim() == nDim, 'target image must be of dimension 3.'
146 | C = target_img.size(0)
147 |
148 | self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
149 |
150 | patches_all, patches_part = self._extract_patches(target_img, patch_size, stride, nonmask_point_idx)
151 | npatches_part = patches_part.size(0)
152 | npatches_all = patches_all.size(0)
153 |
154 |
155 | conv_enc_non_mask, conv_dec_non_mask = self._build(patch_size, stride, C, patches_part, npatches_part, normalize, interpolate)
156 | conv_enc_all, conv_dec_all = self._build(patch_size, stride, C, patches_all, npatches_all, normalize, interpolate)
157 |
158 | return conv_enc_all, conv_enc_non_mask, conv_dec_all, conv_dec_non_mask
159 |
160 | def _build(self, patch_size, stride, C, target_patches, npatches, normalize, interpolate):
161 | # for each patch, divide by its L2 norm.
162 | enc_patches = target_patches.clone()
163 | for i in range(npatches):
164 | enc_patches[i] = enc_patches[i]*(1/(enc_patches[i].norm(2)+1e-8))
165 | conv_enc = nn.Conv2d(C, npatches, kernel_size=patch_size, stride=stride, bias=False)
166 | conv_enc.weight.data = enc_patches
167 |
168 | # normalize is not needed, it doesn't change the result!
169 | if normalize:
170 | raise NotImplementedError
171 |
172 | if interpolate:
173 | raise NotImplementedError
174 |
175 | conv_dec = nn.ConvTranspose2d(npatches, C, kernel_size=patch_size, stride=stride, bias=False)
176 | conv_dec.weight.data = target_patches
177 |
178 | return conv_enc, conv_dec
179 |
180 | def _extract_patches(self, img, patch_size, stride, nonmask_point_idx):
181 | n_dim = 3
182 | assert img.dim() == n_dim, 'image must be of dimension 3.'
183 |
184 | kH, kW = patch_size, patch_size
185 | dH, dW = stride, stride
186 | input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)
187 |
188 | i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(2), input_windows.size(3), input_windows.size(4)
189 | input_windows = input_windows.permute(1,2,0,3,4).contiguous().view(i_2*i_3, i_1, i_4, i_5)
190 |
191 | patches_all = input_windows
192 | patches = input_windows.index_select(0, nonmask_point_idx) #It returns a new tensor, representing patches extracted from non-masked region!
193 | return patches_all, patches
194 |
195 |
--------------------------------------------------------------------------------
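Batch_NonShift is the batched 1x1-patch variant used by the shift layers: cosine_similarity() compares the masked decoder features against the unmasked encoder features, and an argmax over the last dimension picks, for every masked position, the most similar known position. A minimal sketch, mirroring test_acc_shift.py:

# Minimal sketch of the matching step, mirroring test_acc_shift.py.
import torch
from util.NonparametricShift import Batch_NonShift

bz, c, h, w = 1, 3, 8, 8
former = torch.rand(bz, c, h, w)     # decoder features (holes to fill)
latter = torch.rand(bz, c, h, w)     # encoder features (known content)

flag = torch.zeros(bz, h, w).byte()  # 1 marks masked positions
flag[:, 2:5, 2:5] = 1
flag = flag.view(bz, h * w)

shift = Batch_NonShift()
cosine, latter_windows, i_2, i_3, i_1 = shift.cosine_similarity(former, latter, 1, 1, flag)
_, indexes = torch.max(cosine, dim=2)  # best unmasked match for each masked position
print(cosine.size(), indexes.size())   # (bz, n_masked, n_unmasked), (bz, n_masked)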
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zhaoyi-Yan/Shift-Net_pytorch/a3534315b23c2db4f37e82666b8254bf13d2698d/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 | if refresh > 0:
19 | with self.doc.head:
20 | meta(http_equiv="refresh", content=str(refresh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 |
--------------------------------------------------------------------------------
/util/png.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import zlib
3 |
4 | def encode(buf, width, height):
5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
6 | assert (width * height * 3 == len(buf))
7 | bpp = 3
8 |
9 | def raw_data():
10 | # reverse the vertical line order and add null bytes at the start
11 | row_bytes = width * bpp
12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
13 | yield b'\x00'
14 | yield buf[row_start:row_start + row_bytes]
15 |
16 | def chunk(tag, data):
17 | return [
18 | struct.pack("!I", len(data)),
19 | tag,
20 | data,
21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
22 | ]
23 |
24 | SIGNATURE = b'\x89PNG\r\n\x1a\n'
25 | COLOR_TYPE_RGB = 2
26 | COLOR_TYPE_RGBA = 6
27 | bit_depth = 8
28 | return b''.join(
29 | [ SIGNATURE ] +
30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
32 | chunk(b'IEND', b'')
33 | )
34 |
--------------------------------------------------------------------------------
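encode() packs raw RGB bytes into a complete PNG byte string; note that raw_data() emits the rows in reverse order, so the buffer is treated bottom-to-top. A small usage sketch with a hypothetical 2x2 image (the output file name is illustrative):

# Sketch: write a 2x2 PNG from raw RGB bytes.
from util.png import encode

width, height = 2, 2
rgb = bytes([255, 0, 0,   0, 255, 0,        # first buffer row -> bottom of the image
             0, 0, 255,   255, 255, 255])   # last buffer row  -> top of the image
with open('tiny.png', 'wb') as f:
    f.write(encode(rgb, width, height))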
/util/poisson_blending.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse
3 | import cv2
4 | import pyamg
5 |
6 | # pre-process the mask: collapse a multi-channel mask read by cv2.imread into a single-channel 0/1 array
7 | def prepare_mask(mask):
8 | if type(mask[0][0]) is np.ndarray:
9 | result = np.ndarray((mask.shape[0], mask.shape[1]), dtype=np.uint8)
10 | for i in range(mask.shape[0]):
11 | for j in range(mask.shape[1]):
12 | if sum(mask[i][j]) > 0:
13 | result[i][j] = 1
14 | else:
15 | result[i][j] = 0
16 | mask = result
17 | return mask
18 |
19 | def blend(img_target, img_source, img_mask, offset=(0, 0)):
20 | # compute regions to be blended
21 | region_source = (
22 | max(-offset[0], 0),
23 | max(-offset[1], 0),
24 | min(img_target.shape[0]-offset[0], img_source.shape[0]),
25 | min(img_target.shape[1]-offset[1], img_source.shape[1]))
26 | region_target = (
27 | max(offset[0], 0),
28 | max(offset[1], 0),
29 | min(img_target.shape[0], img_source.shape[0]+offset[0]),
30 | min(img_target.shape[1], img_source.shape[1]+offset[1]))
31 | region_size = (region_source[2]-region_source[0], region_source[3]-region_source[1])
32 |
33 | # clip and normalize mask image
34 | img_mask = img_mask[region_source[0]:region_source[2], region_source[1]:region_source[3]]
35 | img_mask = prepare_mask(img_mask)
36 | img_mask[img_mask==0] = False
37 | img_mask[img_mask!=False] = True
38 |
39 | # create coefficient matrix
40 | A = scipy.sparse.identity(np.prod(region_size), format='lil')
41 | for y in range(region_size[0]):
42 | for x in range(region_size[1]):
43 | if img_mask[y,x]:
44 | index = x+y*region_size[1]
45 | A[index, index] = 4
46 | if index+1 < np.prod(region_size):
47 | A[index, index+1] = -1
48 | if index-1 >= 0:
49 | A[index, index-1] = -1
50 | if index+region_size[1] < np.prod(region_size):
51 | A[index, index+region_size[1]] = -1
52 | if index-region_size[1] >= 0:
53 | A[index, index-region_size[1]] = -1
54 | A = A.tocsr()
55 |
56 | # create poisson matrix for b
57 | P = pyamg.gallery.poisson(img_mask.shape)
58 |
59 | # for each layer (ex. RGB)
60 | for num_layer in range(img_target.shape[2]):
61 | # get subimages
62 | t = img_target[region_target[0]:region_target[2], region_target[1]:region_target[3],num_layer]
63 | s = img_source[region_source[0]:region_source[2], region_source[1]:region_source[3],num_layer]
64 | t = t.flatten()
65 | s = s.flatten()
66 |
67 | # create b
68 | b = P * s
69 | for y in range(region_size[0]):
70 | for x in range(region_size[1]):
71 | if not img_mask[y,x]:
72 | index = x+y*region_size[1]
73 | b[index] = t[index]
74 |
75 | # solve Ax = b
76 | x = pyamg.solve(A,b,verb=False,tol=1e-10)
77 |
78 | # assign x to target image
79 | x = np.reshape(x, region_size)
80 | x[x>255] = 255
81 | x[x<0] = 0
82 | x = np.array(x, img_target.dtype)
83 | img_target[region_target[0]:region_target[2],region_target[1]:region_target[3],num_layer] = x
84 |
85 | return img_target
--------------------------------------------------------------------------------
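blend() solves a discrete Poisson system so that img_source is stitched into img_target with seamless gradients inside the non-zero mask region. A usage sketch with assumed file names:

# Sketch (file names are assumed): Poisson-blend a source patch into a target image.
import cv2
from util.poisson_blending import blend

img_target = cv2.imread('target.png')   # H x W x 3
img_source = cv2.imread('source.png')   # same size (or smaller, see offset)
img_mask = cv2.imread('mask.png')       # non-zero where source should replace target
result = blend(img_target, img_source, img_mask, offset=(0, 0))
cv2.imwrite('blended.png', result)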
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | import sys
6 | from subprocess import Popen, PIPE
7 | from . import util, html
8 | from scipy.misc import imresize
9 |
10 | if sys.version_info[0] == 2:
11 | VisdomExceptionBase = Exception
12 | else:
13 | VisdomExceptionBase = ConnectionError
14 |
15 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
16 | """Save images to the disk.
17 | Parameters:
18 | webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
19 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
20 | image_path (str) -- the string is used to create image paths
21 | aspect_ratio (float) -- the aspect ratio of saved images
22 | width (int) -- the images will be resized to width x width
23 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
24 | """
25 | image_dir = webpage.get_image_dir()
26 | short_path = ntpath.basename(image_path[0])
27 | name = os.path.splitext(short_path)[0]
28 |
29 | webpage.add_header(name)
30 | ims, txts, links = [], [], []
31 |
32 | for label, im_data in visuals.items():
33 | im = util.tensor2im(im_data)
34 | image_name = '%s_%s.png' % (name, label)
35 | save_path = os.path.join(image_dir, image_name)
36 | h, w, _ = im.shape
37 | if aspect_ratio > 1.0:
38 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
39 | if aspect_ratio < 1.0:
40 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
41 | util.save_image(im, save_path)
42 |
43 | ims.append(image_name)
44 | txts.append(label)
45 | links.append(image_name)
46 | webpage.add_images(ims, txts, links, width=width)
47 |
48 | class Visualizer():
49 | def __init__(self, opt):
50 | self.display_id = opt.display_id
51 | self.use_html = opt.isTrain and not opt.no_html
52 | self.win_size = opt.display_winsize
53 | self.name = opt.name
54 | self.port = opt.display_port
55 | self.opt = opt
56 | self.saved = False
57 | if self.display_id > 0:
58 | import visdom
59 | self.ncols = opt.display_ncols
60 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
61 | if not self.vis.check_connection():
62 | self.create_visdom_connections()
63 |
64 | if self.use_html:
65 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
66 | self.img_dir = os.path.join(self.web_dir, 'images')
67 | print('create web directory %s...' % self.web_dir)
68 | util.mkdirs([self.web_dir, self.img_dir])
69 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
70 | with open(self.log_name, "a") as log_file:
71 | now = time.strftime("%c")
72 | log_file.write('================ Training Loss (%s) ================\n' % now)
73 |
74 | def reset(self):
75 | self.saved = False
76 |
77 | def create_visdom_connections(self):
78 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
79 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
80 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
81 | print('Command: %s' % cmd)
82 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
83 |
84 | # |visuals|: dictionary of images to display or save
85 | def display_current_results(self, visuals, epoch, save_result):
86 | if self.display_id > 0: # show images in the browser
87 | ncols = self.ncols
88 | if ncols > 0:
89 | ncols = min(ncols, len(visuals))
90 | h, w = next(iter(visuals.values())).shape[:2]
91 | table_css = """""" % (w, h)
95 | title = self.name
96 | label_html = ''
97 | label_html_row = ''
98 | images = []
99 | idx = 0
100 | for label, image in visuals.items():
101 | image = util.rm_extra_dim(image) # remove the dummy dim
102 | image_numpy = util.tensor2im(image)
103 | label_html_row += '<td>%s</td>' % label
104 | images.append(image_numpy.transpose([2, 0, 1]))
105 | idx += 1
106 | if idx % ncols == 0:
107 | label_html += '<tr>%s</tr>' % label_html_row
108 | label_html_row = ''
109 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
110 | while idx % ncols != 0:
111 | images.append(white_image)
112 | label_html_row += '<td></td>'
113 | idx += 1
114 | if label_html_row != '':
115 | label_html += '<tr>%s</tr>' % label_html_row
116 | try:
117 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
118 | padding=2, opts=dict(title=title + ' images'))
119 | label_html = '<table>%s</table>' % label_html
120 | self.vis.text(table_css + label_html, win=self.display_id + 2,
121 | opts=dict(title=title + ' labels'))
122 | except VisdomExceptionBase:
123 | self.create_visdom_connections()
124 | else:
125 | idx = 1
126 | for label, image in visuals.items():
127 | image_numpy = util.tensor2im(image)
128 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
129 | win=self.display_id + idx)
130 | idx += 1
131 |
132 | if self.use_html and (save_result or not self.saved): # save images to a html file
133 | self.saved = True
134 | for label, image in visuals.items():
135 | image_numpy = util.tensor2im(image)
136 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
137 | util.save_image(image_numpy, img_path)
138 | # update website
139 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
140 | for n in range(epoch, 0, -1):
141 | webpage.add_header('epoch [%d]' % n)
142 | ims, txts, links = [], [], []
143 |
144 | for label, image_numpy in visuals.items():
145 | image_numpy = util.tensor2im(image_numpy)
146 | img_path = 'epoch%.3d_%s.png' % (n, label)
147 | ims.append(img_path)
148 | txts.append(label)
149 | links.append(img_path)
150 | webpage.add_images(ims, txts, links, width=self.win_size)
151 | webpage.save()
152 |
153 | # losses: dictionary of error labels and values
154 | def plot_current_losses(self, epoch, counter_ratio, opt, losses):
155 | if not hasattr(self, 'plot_data'):
156 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
157 | self.plot_data['X'].append(epoch + counter_ratio)
158 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
159 | self.vis.line(
160 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
161 | Y=np.array(self.plot_data['Y']),
162 | opts={
163 | 'title': self.name + ' loss over time',
164 | 'legend': self.plot_data['legend'],
165 | 'xlabel': 'epoch',
166 | 'ylabel': 'loss'},
167 | win=self.display_id)
168 |
169 | # losses: same format as |losses| of plot_current_losses
170 | def print_current_losses(self, epoch, i, losses, t, t_data):
171 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
172 | for k, v in losses.items():
173 | message += '%s: %.3f ' % (k, v)
174 |
175 | print(message)
176 | with open(self.log_name, "a") as log_file:
177 | log_file.write('%s\n' % message)
178 |
179 |
180 |
--------------------------------------------------------------------------------
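save_images() only needs an html.HTML page, an ordered dict of images and an image path; every label becomes a '<name>_<label>.png' file next to the index page. A small sketch, assuming util.tensor2im accepts a 1x3xHxW tensor as it does for the model visuals above, and with illustrative labels and paths:

# Sketch: build a tiny results page by hand (tensor contents are random noise).
from collections import OrderedDict
import torch
from util import html
from util.visualizer import save_images

webpage = html.HTML('./results/demo', 'demo page')
visuals = OrderedDict([('real_A', torch.rand(1, 3, 64, 64)),
                       ('fake_B', torch.rand(1, 3, 64, 64))])
save_images(webpage, visuals, ['imgs/0.png'], aspect_ratio=1.0, width=64)
webpage.save()   # writes ./results/demo/index.html plus images/0_real_A.png and 0_fake_B.png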