├── LICENSE
├── README.md
├── checkpoints
│   ├── gin
│   │   └── put_trained_model_here
│   └── warmup
│       └── put_warm_up_here
├── data
│   ├── __pycache__
│   │   ├── ade20k_dataset.cpython-36.pyc
│   │   ├── ade20k_dataset.cpython-37.pyc
│   │   ├── base_data_loader.cpython-36.pyc
│   │   ├── base_data_loader.cpython-37.pyc
│   │   ├── base_dataset.cpython-36.pyc
│   │   ├── base_dataset.cpython-37.pyc
│   │   ├── custom_dataset_data_loader.cpython-36.pyc
│   │   ├── custom_dataset_data_loader.cpython-37.pyc
│   │   ├── data_loader.cpython-36.pyc
│   │   ├── data_loader.cpython-37.pyc
│   │   ├── image_folder.cpython-36.pyc
│   │   ├── image_folder.cpython-37.pyc
│   │   ├── masks.cpython-36.pyc
│   │   └── masks.cpython-37.pyc
│   ├── ade20k_dataset.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   ├── image_folder.py
│   └── masks.py
├── datasets
│   └── ade20k
│       ├── test
│       │   ├── AIM_IC_t1_validation_0_mask.png
│       │   ├── AIM_IC_t1_validation_0_with_holes.png
│       │   ├── AIM_IC_t1_validation_18_mask.png
│       │   └── AIM_IC_t1_validation_18_with_holes.png
│       └── train
│           ├── a
│           │   └── abbey
│           │       ├── ADE_train_00000976.jpg
│           │       ├── ADE_train_00000976_seg.png
│           │       ├── ADE_train_00000985.jpg
│           │       ├── ADE_train_00000985_seg.png
│           │       ├── ADE_train_00000992.jpg
│           │       └── ADE_train_00000992_seg.png
│           └── z
│               └── zoo
│                   ├── ADE_train_00020206.jpg
│                   └── ADE_train_00020206_seg.png
├── examples
│   ├── AIM_IC_t1_validation_0.png
│   ├── AIM_IC_t1_validation_0_with_holes.png
│   ├── ablation_study.png
│   ├── arc.png
│   ├── architecture.png
│   ├── comparisons_ffhq_oxford.png
│   ├── problem_of_interest.png
│   ├── spd_resnetblk.png
│   └── visualization_seg.png
├── models
│   ├── __pycache__
│   │   ├── BoundaryVAE_model.cpython-36.pyc
│   │   ├── BoundaryVAE_model.cpython-37.pyc
│   │   ├── base_model.cpython-36.pyc
│   │   ├── base_model.cpython-37.pyc
│   │   ├── models.cpython-36.pyc
│   │   ├── models.cpython-37.pyc
│   │   ├── models_pretrain.cpython-36.pyc
│   │   ├── networks.cpython-36.pyc
│   │   ├── networks.cpython-37.pyc
│   │   └── our_model.cpython-36.pyc
│   ├── base_model.py
│   ├── models.py
│   ├── networks.py
│   └── our_model.py
├── options
│   ├── __pycache__
│   │   ├── base_options.cpython-36.pyc
│   │   ├── base_options.cpython-37.pyc
│   │   ├── test_options.cpython-36.pyc
│   │   ├── train_options.cpython-36.pyc
│   │   └── train_options.cpython-37.pyc
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── results
│   └── test
│       └── AIM_IC_t1_validation_0.png
├── test_ensemble.py
├── train.py
└── util
    ├── __pycache__
    │   ├── html.cpython-36.pyc
    │   ├── html.cpython-37.pyc
    │   ├── util.cpython-36.pyc
    │   ├── util.cpython-37.pyc
    │   ├── visualizer.cpython-36.pyc
    │   └── visualizer.cpython-37.pyc
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, ronctl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 |
28 |
29 | --------------------------- LICENSE FOR pytorch-pix2pixHD ----------------
30 | Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
31 | BSD License. All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
44 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
45 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
46 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
47 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
48 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
49 |
50 |
51 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
52 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
53 | All rights reserved.
54 |
55 | Redistribution and use in source and binary forms, with or without
56 | modification, are permitted provided that the following conditions are met:
57 |
58 | * Redistributions of source code must retain the above copyright notice, this
59 | list of conditions and the following disclaimer.
60 |
61 | * Redistributions in binary form must reproduce the above copyright notice,
62 | this list of conditions and the following disclaimer in the documentation
63 | and/or other materials provided with the distribution.
64 |
65 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
66 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
67 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
68 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
69 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
70 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
71 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
72 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
73 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
74 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
75 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Generative Inpainting Network (GIN) for Extreme Image Inpainting
2 | For the AIM 2020 ECCV Extreme Image Inpainting Challenge, Track 1 (Classic)
3 | This is the PyTorch implementation of our Deep Generative Inpainting Network (GIN) for Extreme Image Inpainting. We participated in the AIM 2020 ECCV Extreme Image Inpainting Challenge. Our GIN reconstructs a completed image with satisfactory visual quality from a randomly masked image.
4 |
5 | ## Overview
6 |
7 |
8 |
9 | Our Spatial Pyramid Dilation (SPD) block
10 |
11 |
12 |
13 |
14 | ## Example of Image Inpainting using our GIN
15 | - An example from the validation set of the AIM20 ECCV Extreme Image Inpainting Track 1 Classic
16 | - (left: masked image, right: our completed image)
17 |
18 |
19 |
20 |
21 |
22 | ## Preparation
23 | - Our solution is developed on the PyTorch 1.5.0 platform
24 | - We train our model on two NVIDIA GeForce RTX 2080 Ti GPUs (11 GB memory each)
25 | - Apart from PyTorch and related dependencies:
26 | - Install natsort
27 | ```bash
28 | pip install natsort
29 | ```
30 | - Install dominate
31 | ```bash
32 | pip install dominate
33 | ```
34 | - Install scipy 1.1.0
35 | ```bash
36 | pip install scipy==1.1.0
37 | ```
38 | - If you would like to use tensorboard for logging, please also install tensorboard and tensorflow
39 | - Please clone this project:
40 | ```bash
41 | git clone https://github.com/rlct1/gin.git
42 | cd gin
43 | ```
44 |
45 | ## Testing
46 | - An example of the validation data of this challenge is provided in the `datasets/ade20k/test` folder
47 | - Please download our trained model for this challenge [here](https://drive.google.com/file/d/1yOtMELWwTBc-PMSY69x1FH8D1anUN7tD/view?usp=sharing) (google drive link), and put it under `checkpoints/gin/`
48 | - To reproduce the test results for this challenge, please put all the testing images under `datasets/ade20k/test/`
49 | - You can test our model by typing:
50 | ```bash
51 | python test_ensemble.py --name gin
52 | ```
53 | - The test results will be stored in the `results/test` folder
54 | - If you would like to test on other datasets, please refer to the file structure in the `datasets/ade20k/test` folder
55 | - Note that the file structure is for AIM20 IC Track 1
56 | - You can download our test results for this challenge [here](https://drive.google.com/file/d/1EJgQ3neOA2WkZMmG6uG0GG14VoLYmNFg/view?usp=sharing) (google drive link)
57 |
58 | ## Training
59 | - By default, our model is trained using two GPUs
60 | - Examples of the training images from this challenge are provided in the `datasets/ade20k/train` folder
61 | - If you would like to train a model using our warm up for initialization, please download our warm up for this challenge [here](https://drive.google.com/file/d/1T3ST-ujhtDZQpWUiagICOAIvBF7CMeYz/view?usp=sharing) (google drive link), and put it under `checkpoints/warmup/`
62 | ```bash
63 | python train.py --name yourmodel --continue_train --load_pretrain './checkpoints/warmup'
64 | ```
65 | - If you would like to train a model from scratch,
66 | ```bash
67 | python train.py --name yourmodel
68 | ```
69 | - If you would like to train a model based on your own selection and resources, please refer to `options/base_options.py` and `options/train_options.py` for details
70 |
71 | ## Experiments
72 | Ablation Study
73 |
74 |
75 |
76 | Comparisons
77 |
78 |
79 |
80 | Visualization of predicted semantic segmentation map
81 |
82 |
83 |
84 |
85 | ## Citation
86 | Thanks for visiting our project page. If it is useful, please cite our paper:
87 | ```
88 | @misc{li2020deepgin,
89 | title={DeepGIN: Deep Generative Inpainting Network for Extreme Image Inpainting},
90 | author={Chu-Tak Li and Wan-Chi Siu and Zhi-Song Liu and Li-Wen Wang and Daniel Pak-Kong Lun},
91 | year={2020},
92 | eprint={2008.07173},
93 | archivePrefix={arXiv},
94 | primaryClass={cs.CV}
95 | }
96 | ```
97 |
98 | ## Acknowledgment
99 | Our code is developed based on the skeleton of the PyTorch implementation of [pix2pixHD](https://github.com/NVIDIA/pix2pixHD).
100 |
101 |
--------------------------------------------------------------------------------
/checkpoints/gin/put_trained_model_here:
--------------------------------------------------------------------------------
1 | put trained model here
2 |
--------------------------------------------------------------------------------
/checkpoints/warmup/put_warm_up_here:
--------------------------------------------------------------------------------
1 | put warm up model here
2 |
--------------------------------------------------------------------------------
/data/__pycache__/ade20k_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/ade20k_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/ade20k_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/ade20k_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/base_data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/base_data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/base_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/base_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/custom_dataset_data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/custom_dataset_data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/image_folder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/image_folder.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/image_folder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/image_folder.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/masks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/masks.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/masks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/masks.cpython-37.pyc
--------------------------------------------------------------------------------
/data/ade20k_dataset.py:
--------------------------------------------------------------------------------
1 | ####################################################################################
2 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
3 | # Licensed under the CC BY-NC-SA 4.0 license
4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
5 | ####################################################################################
6 | import os.path
7 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize
8 | from data.image_folder import make_dataset, make_dataset_with_conditions, make_dataset_with_condition_list
9 | from data.masks import Masks
10 | from PIL import Image
11 | import random
12 | import numpy as np
13 | from natsort import natsorted
14 | import torch
15 |
16 | class ADE20kDataset(BaseDataset):
17 | def initialize(self, opt):
18 | self.opt = opt
19 | self.root = opt.dataroot
20 | ################################################
21 | # input A : masked images (rgb)
22 | # input B : real images (rgb)
23 | # input mask : masks (gray)
24 | # input B_seg : real seg images (rgb)
25 | ################################################
26 |
27 | if opt.phase == 'train':
28 | # the 10,330 images in ade20k train set
29 | self.dir_A = os.path.join(opt.dataroot, opt.phase)
30 | self.B_seg_paths, _ = make_dataset_with_conditions(self.dir_A, '_seg')
31 | self.A_paths, _ = make_dataset_with_conditions(self.dir_A, '.jpg')
32 |
33 | self.B_seg_paths = natsorted(self.B_seg_paths)
34 | self.A_paths = natsorted(self.A_paths)
35 | self.B_paths = self.A_paths
36 |
37 | self.dir_mask = []
38 | self.mask_paths = []
39 |
40 | self.mask = Masks()
41 |
42 | elif opt.phase == 'test':
43 | # aim 2020 eccv validation set and test set.
44 | # no ground truth
45 | self.dir_A = os.path.join(opt.dataroot, opt.phase)
46 | self.A_paths, _ = make_dataset_with_conditions(self.dir_A, '_with_holes')
47 | self.A_paths = natsorted(self.A_paths)
48 |
49 | self.dir_B = []
50 | self.B_paths = []
51 | self.B_seg_paths = [] # depends on track 1 or 2
52 |
53 | self.dir_mask = os.path.join(opt.dataroot, opt.phase)
54 | self.mask_paths, _ = make_dataset_with_conditions(self.dir_mask, '_mask')
55 | self.mask_paths = natsorted(self.mask_paths)
56 |
57 | self.dataset_size = len(self.A_paths)
58 |
59 | def __getitem__(self, index):
60 | ################################################
61 | # input A : masked images (rgb)
62 | # input B : real images (rgb)
63 | # input mask : masks (gray)
64 | # input B_seg : real seg images (rgb)
65 | ################################################
66 |
67 | A_path = self.A_paths[index]
68 | A = Image.open(A_path).convert('RGB')
69 |
70 | params = get_params(self.opt, A.size)
71 | transform_A = get_transform(self.opt, params, normalize=False)
72 | A_tensor = transform_A(A)
73 |
74 | B_tensor = mask_tensor = 0
75 | B_seg_tensor = 0
76 |
77 | if self.opt.phase == 'train':
78 | # the 10,330 images in ade20k train set
79 | B_seg_path = self.B_seg_paths[index]
80 | B_seg = Image.open(B_seg_path).convert('RGB')
81 |
82 | new_A, new_B_seg = self.resize_or_crop(A, B_seg)
83 | ## data augmentation rotate
84 | new_A = self.rotate(new_A)
85 | new_A = self.ensemble(new_A)
86 |
87 | width, height = new_A.size
88 | f_A = np.array(new_A, np.float32)
89 | f_A1 = np.array(new_A, np.float32)
90 | f_A2 = np.array(new_A, np.float32)
91 | f_A3 = np.array(new_A, np.float32)
92 |
93 | f_mask = self.mask.get_random_mask(height, width)
94 |
95 | f_A[:, :, 0] = f_A[:, :, 0] * (1.0 - f_mask) + 255.0 * f_mask
96 | f_A[:, :, 1] = f_A[:, :, 1] * (1.0 - f_mask) + 255.0 * f_mask
97 | f_A[:, :, 2] = f_A[:, :, 2] * (1.0 - f_mask) + 255.0 * f_mask
98 | f_masked_A = f_A
99 |
100 | f_mask1 = self.mask.get_box_mask(height, width)
101 | f_mask2 = self.mask.get_ca_mask(height, width)
102 | f_mask3 = self.mask.get_ff_mask(height, width)
103 |
104 | f_A1[:, :, 0] = f_A1[:, :, 0] * (1.0 - f_mask1) + 255.0 * f_mask1
105 | f_A1[:, :, 1] = f_A1[:, :, 1] * (1.0 - f_mask1) + 255.0 * f_mask1
106 | f_A1[:, :, 2] = f_A1[:, :, 2] * (1.0 - f_mask1) + 255.0 * f_mask1
107 | f_masked_A1 = f_A1
108 |
109 | f_A2[:, :, 0] = f_A2[:, :, 0] * (1.0 - f_mask2) + 255.0 * f_mask2
110 | f_A2[:, :, 1] = f_A2[:, :, 1] * (1.0 - f_mask2) + 255.0 * f_mask2
111 | f_A2[:, :, 2] = f_A2[:, :, 2] * (1.0 - f_mask2) + 255.0 * f_mask2
112 | f_masked_A2 = f_A2
113 |
114 | f_A3[:, :, 0] = f_A3[:, :, 0] * (1.0 - f_mask3) + 255.0 * f_mask3
115 | f_A3[:, :, 1] = f_A3[:, :, 1] * (1.0 - f_mask3) + 255.0 * f_mask3
116 | f_A3[:, :, 2] = f_A3[:, :, 2] * (1.0 - f_mask3) + 255.0 * f_mask3
117 | f_masked_A3 = f_A3
118 |
119 | # masked images
120 | masked_A = Image.fromarray((f_masked_A).astype(np.uint8)).convert('RGB')
121 | mask_img = Image.fromarray((f_mask * 255.0).astype(np.uint8)).convert('L')
122 |
123 | masked_A1 = Image.fromarray((f_masked_A1).astype(np.uint8)).convert('RGB')
124 | mask_img1 = Image.fromarray((f_mask1 * 255.0).astype(np.uint8)).convert('L')
125 |
126 | masked_A2 = Image.fromarray((f_masked_A2).astype(np.uint8)).convert('RGB')
127 | mask_img2 = Image.fromarray((f_mask2 * 255.0).astype(np.uint8)).convert('L')
128 |
129 | masked_A3 = Image.fromarray((f_masked_A3).astype(np.uint8)).convert('RGB')
130 | mask_img3 = Image.fromarray((f_mask3 * 255.0).astype(np.uint8)).convert('L')
131 |
132 | A_tensor = transform_A(masked_A)
133 |
134 | ##
135 | A_tensor1 = transform_A(masked_A1)
136 | A_tensor2 = transform_A(masked_A2)
137 | A_tensor3 = transform_A(masked_A3)
138 |
139 | # real images
140 | B_path = self.B_paths[index]
141 | B = new_A
142 | transform_B = get_transform(self.opt, params, normalize=False)
143 | B_tensor = transform_B(B)
144 |
145 | B_tensor1 = transform_B(B)
146 | B_tensor2 = transform_B(B)
147 | B_tensor3 = transform_B(B)
148 |
149 | transform_B_seg = get_transform(self.opt, params, normalize=False)
150 | B_seg_tensor = transform_B_seg(new_B_seg)
151 |
152 | # masks
153 | mask_path = []
154 | transform_mask = get_transform(self.opt, params, normalize=False)
155 | mask_tensor = transform_mask(mask_img)
156 |
157 | mask_tensor1 = transform_mask(mask_img1)
158 | mask_tensor2 = transform_mask(mask_img2)
159 | mask_tensor3 = transform_mask(mask_img3)
160 |
161 | A_tensor = torch.cat((A_tensor1, A_tensor2, A_tensor3), dim=0)
162 | B_tensor = torch.cat((B_tensor1, B_tensor2, B_tensor3), dim=0)
163 | mask_tensor = torch.cat((mask_tensor1, mask_tensor2, mask_tensor3), dim=0)
164 |
165 | elif self.opt.phase == 'test':
166 | # aim 2020 eccv validation set and test set.
167 | # no ground truth
168 | B_path = []
169 | B_seg_path = []
170 |
171 | # masks
172 | mask_path = self.mask_paths[index]
173 | mask = Image.open(mask_path).convert('L')
174 | transform_mask = get_transform(self.opt, params, normalize=False)
175 | mask_tensor = transform_mask(mask)
176 |
177 | if self.opt.phase == 'test':
178 | input_dict = {
179 | 'masked_image': A_tensor,
180 | 'mask': mask_tensor,
181 | 'path_mskimg': A_path,
182 | 'path_msk': mask_path}
183 | return input_dict
184 |
185 | input_dict = {'masked_image': A_tensor,
186 | 'mask': mask_tensor,
187 | 'real_image': B_tensor,
188 | 'real_seg': B_seg_tensor}
189 | return input_dict
190 |
191 | def __len__(self):
192 | return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize
193 |
194 | def name(self):
195 | return 'ADE20kDataset'
196 |
197 | def resize_or_crop(self, img, seg_img, method=Image.BICUBIC):
198 | w, h = img.size
199 | new_w = w
200 | new_h = h
201 |
202 | if w > self.opt.loadSize and h > self.opt.loadSize:
203 | return img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method), seg_img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method)
204 | else:
205 | return img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method), seg_img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method)
206 |
207 | def rotate(self, img):
208 | bFlag = random.randint(0, 3)
209 | if bFlag == 0:
210 | return img.rotate(0)
211 | elif bFlag == 1:
212 | return img.rotate(90)
213 | elif bFlag == 2:
214 | return img.rotate(180)
215 | elif bFlag == 3:
216 | return img.rotate(270)
217 |
218 | def ensemble(self, img):
219 | bFlag = random.randint(0, 3)
220 | width, height = img.size
221 | new_w = width // 2
222 | new_h = height // 2
223 | new_img = img.resize((new_w, new_h), Image.BICUBIC)
224 | np_img = np.array(img, np.float32)
225 | np_new_img = np.array(new_img, np.float32)
226 |
227 | if bFlag == 0:
228 | for i in range(new_w):
229 | for j in range(new_h):
230 | np_new_img[i, j, 0] = np_img[2*i, 2*j, 0]
231 | np_new_img[i, j, 1] = np_img[2*i, 2*j, 1]
232 | np_new_img[i, j, 2] = np_img[2*i, 2*j, 2]
233 | elif bFlag == 1:
234 | for i in range(new_w):
235 | for j in range(new_h):
236 | np_new_img[i, j, 0] = np_img[1 + 2*i, 1 + 2*j, 0]
237 | np_new_img[i, j, 1] = np_img[1 + 2*i, 1 + 2*j, 1]
238 | np_new_img[i, j, 2] = np_img[1 + 2*i, 1 + 2*j, 2]
239 | elif bFlag == 2:
240 | for i in range(new_w):
241 | for j in range(new_h):
242 | np_new_img[i, j, 0] = np_img[1 + 2*i, 2*j, 0]
243 | np_new_img[i, j, 1] = np_img[1 + 2*i, 2*j, 1]
244 | np_new_img[i, j, 2] = np_img[1 + 2*i, 2*j, 2]
245 | else:
246 | for i in range(new_w):
247 | for j in range(new_h):
248 | np_new_img[i, j, 0] = np_img[2*i, 1 + 2*j, 0]
249 | np_new_img[i, j, 1] = np_img[2*i, 1 + 2*j, 1]
250 | np_new_img[i, j, 2] = np_img[2*i, 1 + 2*j, 2]
251 |
252 | new_A = Image.fromarray((np_new_img).astype(np.uint8)).convert('RGB')
253 | return new_A
254 |
255 |
--------------------------------------------------------------------------------
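
For reference, the per-channel masking loops in `ADE20kDataset.__getitem__` above compute `masked = image * (1 - mask) + 255 * mask`, painting masked pixels white. A minimal vectorized NumPy sketch of the same operation (not part of the repository):

```python
import numpy as np

f_A = np.random.rand(256, 256, 3).astype(np.float32) * 255.0   # example RGB image, H x W x 3
f_mask = (np.random.rand(256, 256) > 0.5).astype(np.float32)   # example binary mask, H x W

# masked pixels become white (255); unmasked pixels are kept unchanged
f_masked_A = f_A * (1.0 - f_mask[..., None]) + 255.0 * f_mask[..., None]
```
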
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 | def __init__(self):
3 | pass
4 |
5 | def initialize(self, opt):
6 | self.opt = opt
7 | pass
8 |
9 | def load_data(self):
10 | return None
11 |
12 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | ####################################################################################
2 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
3 | # Licensed under the CC BY-NC-SA 4.0 license
4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
5 | ####################################################################################
6 | import torch.utils.data as data
7 | from PIL import Image
8 | import random
9 | import numpy as np
10 | import torchvision.transforms as transforms
11 |
12 |
13 | class BaseDataset(data.Dataset):
14 | def __init__(self):
15 | super(BaseDataset, self).__init__()
16 |
17 | def name(self):
18 | return 'BaseDataset'
19 |
20 | def initialize(self, opt):
21 | pass
22 |
23 |
24 | def get_params(opt, size):
25 | w, h = size
26 | new_h = h
27 | new_w = w
28 | if opt.resize_or_crop == 'resize_and_crop':
29 | new_h = new_w = opt.loadSize
30 | elif opt.resize_or_crop == 'scale_width_and_crop':
31 | new_w = opt.loadSize
32 | new_h = opt.loadSize * h // w
33 |
34 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
35 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
36 |
37 | flip = random.random() > 0.5
38 | return {'crop_pos': (x, y), 'flip': flip}
39 |
40 |
41 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, grayscale=False):
42 | transform_list = []
43 | if 'resize' in opt.resize_or_crop:
44 | osize = [opt.loadSize, opt.loadSize]
45 | transform_list.append(transforms.Resize(osize, method))
46 | elif 'scale_width' in opt.resize_or_crop:
47 | transform_list.append(transforms.Lambda(
48 | lambda img: __scale_width(img, opt.loadSize, method)))
49 |
50 | if 'crop' in opt.resize_or_crop:
51 | transform_list.append(transforms.Lambda(
52 | lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
53 |
54 | if opt.resize_or_crop == 'none':
55 | base = float(2 ** opt.n_downsample_global)
56 | if opt.netG == 'local':
57 | base *= (2 ** opt.n_local_enhancers)
58 | transform_list.append(transforms.Lambda(
59 | lambda img: __make_power_2(img, base, method)))
60 |
61 | if opt.resize_or_crop == 'standard':
62 | transform_list.append(transforms.Lambda(
63 | lambda img: __resize(img, opt.fineSize, method)))
64 |
65 | if opt.isTrain and not opt.no_flip:
66 | transform_list.append(transforms.Lambda(
67 | lambda img: __flip(img, params['flip'])))
68 |
69 | transform_list += [transforms.ToTensor()]
70 |
71 | if normalize:
72 | if grayscale:
73 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
74 | else:
75 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
76 | (0.5, 0.5, 0.5))]
77 |
78 | return transforms.Compose(transform_list)
79 |
80 |
81 | def normalize():
82 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
83 |
84 |
85 | def __make_power_2(img, base, method=Image.BICUBIC):
86 | ow, oh = img.size
87 |
88 | h = int(round(oh / base) * base)
89 | w = int(round(ow / base) * base)
90 |
91 | if (h == oh) and (w == ow):
92 | return img
93 | return img.resize((w, h), method)
94 |
95 |
96 | def __scale_width(img, target_width, method=Image.BICUBIC):
97 | ow, oh = img.size
98 | if (ow == target_width):
99 | return img
100 | w = target_width
101 | h = int(target_width * oh / ow)
102 | return img.resize((w, h), method)
103 |
104 |
105 | def __resize(img, target_size, method=Image.BICUBIC):
106 | ow, oh = img.size
107 | if (ow == target_size) and (oh == target_size):
108 | return img
109 | return img.resize((target_size, target_size), method)
110 |
111 |
112 | def __crop(img, pos, size):
113 | ow, oh = img.size
114 | x1, y1 = pos
115 | tw = th = size
116 | if (ow > tw or oh > th):
117 | return img.crop((x1, y1, x1 + tw, y1 + th))
118 | return img
119 |
120 |
121 | def __flip(img, flip):
122 | if flip:
123 | return img.transpose(Image.FLIP_LEFT_RIGHT)
124 | return img
125 |
126 |
--------------------------------------------------------------------------------
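
A minimal sketch of how `get_params` and `get_transform` above are combined; the `resize_and_crop` settings below are illustrative assumptions, not the repository defaults (those live in `options/base_options.py`):

```python
from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

# Illustrative option values only
opt = Namespace(resize_or_crop='resize_and_crop', loadSize=286, fineSize=256,
                isTrain=True, no_flip=False)

img = Image.open('datasets/ade20k/train/a/abbey/ADE_train_00000976.jpg').convert('RGB')
params = get_params(opt, img.size)                         # shared crop position and flip flag
tensor = get_transform(opt, params, normalize=False)(img)  # 3 x fineSize x fineSize tensor
```
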
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | #from data.ade20k_dataset import ADE20kDataset
8 | from data.ade20k_dataset import ADE20kDataset
9 | dataset = ADE20kDataset()
10 |
11 | print("dataset [%s] was created" % (dataset.name()))
12 | dataset.initialize(opt)
13 | return dataset
14 |
15 |
16 | class CustomDatasetDataLoader(BaseDataLoader):
17 | def name(self):
18 | return 'CustomDatasetDataLoader'
19 |
20 | def initialize(self, opt):
21 | BaseDataLoader.initialize(self, opt)
22 | self.dataset = CreateDataset(opt)
23 | self.dataloader = torch.utils.data.DataLoader(
24 | self.dataset,
25 | batch_size=opt.batchSize,
26 | shuffle=not opt.serial_batches,
27 | num_workers=int(opt.nThreads))
28 |
29 | def load_data(self):
30 | return self.dataloader
31 |
32 | def __len__(self):
33 | return min(len(self.dataset), self.opt.max_dataset_size)
34 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
9 |
--------------------------------------------------------------------------------
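
Putting the pieces together, a self-contained test-time sketch: every option field shown is one the classes above actually read, but the values are assumptions (the real defaults come from the `options/` scripts):

```python
from argparse import Namespace
from data.data_loader import CreateDataLoader

opt = Namespace(dataroot='./datasets/ade20k', phase='test', batchSize=1,
                serial_batches=True, nThreads=0, max_dataset_size=float('inf'),
                resize_or_crop='standard', loadSize=512, fineSize=256,
                isTrain=False, no_flip=True)

data_loader = CreateDataLoader(opt)
for data in data_loader.load_data():
    # each batch is a dict with keys 'masked_image', 'mask', 'path_mskimg', 'path_msk'
    print(data['path_mskimg'], data['masked_image'].shape, data['mask'].shape)
```
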
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ####################################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ####################################################################################
7 |
8 |
9 | import torch.utils.data as data
10 | from PIL import Image
11 | import os
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 | # make own dataset
23 |
24 |
25 | def make_dataset(dir):
26 | images = []
27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
28 |
29 | for root, _, fnames in sorted(os.walk(dir)):
30 | for fname in fnames:
31 | if is_image_file(fname):
32 | path = os.path.join(root, fname)
33 | images.append(path)
34 |
35 | return images
36 |
37 |
38 | # make own dataset, with certain wordings
39 | def make_dataset_with_conditions(dir, wording='_mask'):
40 | images = []
41 | images_no_conditions = []
42 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
43 |
44 | for root, _, fnames in sorted(os.walk(dir)):
45 | for fname in fnames:
46 | if is_image_file(fname) and (wording in fname):
47 | path = os.path.join(root, fname)
48 | images.append(path)
49 | else:
50 | path = os.path.join(root, fname)
51 | images_no_conditions.append(path)
52 |
53 | return images, images_no_conditions
54 |
55 | # make own dataset, with certain wordings
56 | def make_dataset_with_condition_list(dir):
57 | images = []
58 | images_no_conditions = []
59 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
60 |
61 | for root, _, fnames in sorted(os.walk(dir)):
62 | for fname in fnames:
63 | if is_image_file(fname) and (('_seg' in fname) or ('_seg_inst' in fname) or ('_boundary' in fname)):
64 | path = os.path.join(root, fname)
65 | images.append(path)
66 | else:
67 | path = os.path.join(root, fname)
68 | images_no_conditions.append(path)
69 | #print(images)
70 | #print(images_no_conditions)
71 |
72 | return images, images_no_conditions
73 |
74 |
75 | def default_loader(path):
76 | return Image.open(path).convert('RGB')
77 |
78 |
79 | class ImageFolder(data.Dataset):
80 | def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
81 | imgs = make_dataset(root)
82 | if len(imgs) == 0:
83 | raise(RuntimeError("Found 0 images in : " + root + "\n"
84 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
85 |
86 | self.root = root
87 | self.imgs = imgs
88 | self.transform = transform
89 | self.return_paths = return_paths
90 | self.loader = loader
91 |
92 | def __getitem__(self, index):
93 | path = self.imgs[index]
94 | img = self.loader(path)
95 | if self.transform is not None:
96 | img = self.transform(img)
97 | if self.return_paths:
98 | return img, path
99 | else:
100 | return img
101 |
102 | def __len__(self):
103 | return len(self.imgs)
104 |
105 |
--------------------------------------------------------------------------------
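
A hypothetical example of the filename-based filtering implemented above, run against the sample data shipped in `datasets/ade20k/train`:

```python
from data.image_folder import make_dataset_with_conditions

# files whose names contain '_seg' vs. everything else under the directory
seg_paths, other_paths = make_dataset_with_conditions('datasets/ade20k/train', wording='_seg')
print(len(seg_paths), 'segmentation maps;', len(other_paths), 'other files')
```
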
/data/masks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import random
4 | from scipy import ndimage, misc
5 |
6 | class Masks():
7 |
8 | @staticmethod
9 | def get_ff_mask(h, w, num_v = None):
10 | #Source: Generative Inpainting https://github.com/JiahuiYu/generative_inpainting
11 |
12 | mask = np.zeros((h,w))
13 | if num_v is None:
14 | num_v = 15+np.random.randint(9) #5
15 |
16 | for i in range(num_v):
17 | start_x = np.random.randint(w)
18 | start_y = np.random.randint(h)
19 | for j in range(1+np.random.randint(5)):
20 | angle = 0.01+np.random.randint(4.0)
21 | if i % 2 == 0:
22 | angle = 2 * 3.1415926 - angle
23 | length = 10+np.random.randint(60) # 40
24 | brush_w = 10+np.random.randint(15) # 10
25 | end_x = (start_x + length * np.sin(angle)).astype(np.int32)
26 | end_y = (start_y + length * np.cos(angle)).astype(np.int32)
27 |
28 | cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
29 | start_x, start_y = end_x, end_y
30 |
31 | return mask.astype(np.float32)
32 |
33 |
34 | @staticmethod
35 | def get_box_mask(h,w):
36 | height, width = h, w
37 |
38 | mask = np.zeros((height, width))
39 |
40 | mask_width = random.randint(int(0.3 * width), int(0.7 * width))
41 | mask_height = random.randint(int(0.3 * height), int(0.7 * height))
42 |
43 | mask_x = random.randint(0, width - mask_width)
44 | mask_y = random.randint(0, height - mask_height)
45 |
46 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
47 | return mask.astype(np.float32)
48 |
49 | @staticmethod
50 | def get_ca_mask(h,w, scale = None, r = None):
51 |
52 | if scale is None:
53 | scale = random.choice([1,2,4,8])
54 | if r is None:
55 | r = random.randint(2,6) # repeat median filter r times
56 |
57 | height = h
58 | width = w
59 | mask = np.random.randint(2, size = (height//scale, width//scale))
60 |
61 | for _ in range(r):
62 | mask = ndimage.median_filter(mask, size=3, mode='constant')
63 |
64 | mask = misc.imresize(mask,(h,w),interp='nearest')
65 | if scale > 1:
66 | struct = ndimage.generate_binary_structure(2, 1)
67 | mask = ndimage.morphology.binary_dilation(mask, struct)
68 | elif scale > 3:
69 | struct = np.array([[ 0., 0., 1., 0., 0.],
70 | [ 0., 1., 1., 1., 0.],
71 | [ 1., 1., 1., 1., 1.],
72 | [ 0., 1., 1., 1., 0.],
73 | [ 0., 0., 1., 0., 0.]])
74 |
75 | return (mask > 0).astype(np.float32)
76 |
77 | @staticmethod
78 | def get_random_mask(h,w):
79 | f = random.choice([Masks.get_box_mask, Masks.get_ca_mask, Masks.get_ff_mask])
80 | return f(h,w)
--------------------------------------------------------------------------------
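
A minimal sketch (assuming NumPy, Pillow, OpenCV, and the pinned SciPy 1.1.0 are installed) of drawing one of the random training masks above and saving it for inspection:

```python
import numpy as np
from PIL import Image
from data.masks import Masks

mask = Masks.get_random_mask(256, 256)   # float32 array of 0s and 1s, shape (256, 256)
Image.fromarray((mask * 255.0).astype(np.uint8)).save('example_mask.png')
```
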
/datasets/ade20k/test/AIM_IC_t1_validation_0_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_0_mask.png
--------------------------------------------------------------------------------
/datasets/ade20k/test/AIM_IC_t1_validation_0_with_holes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_0_with_holes.png
--------------------------------------------------------------------------------
/datasets/ade20k/test/AIM_IC_t1_validation_18_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_18_mask.png
--------------------------------------------------------------------------------
/datasets/ade20k/test/AIM_IC_t1_validation_18_with_holes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_18_with_holes.png
--------------------------------------------------------------------------------
/datasets/ade20k/train/a/abbey/ADE_train_00000976.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000976.jpg
--------------------------------------------------------------------------------
/datasets/ade20k/train/a/abbey/ADE_train_00000976_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000976_seg.png
--------------------------------------------------------------------------------
/datasets/ade20k/train/a/abbey/ADE_train_00000985.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000985.jpg
--------------------------------------------------------------------------------
/datasets/ade20k/train/a/abbey/ADE_train_00000985_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000985_seg.png
--------------------------------------------------------------------------------
/datasets/ade20k/train/a/abbey/ADE_train_00000992.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000992.jpg
--------------------------------------------------------------------------------
/datasets/ade20k/train/a/abbey/ADE_train_00000992_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000992_seg.png
--------------------------------------------------------------------------------
/datasets/ade20k/train/z/zoo/ADE_train_00020206.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/z/zoo/ADE_train_00020206.jpg
--------------------------------------------------------------------------------
/datasets/ade20k/train/z/zoo/ADE_train_00020206_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/z/zoo/ADE_train_00020206_seg.png
--------------------------------------------------------------------------------
/examples/AIM_IC_t1_validation_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/AIM_IC_t1_validation_0.png
--------------------------------------------------------------------------------
/examples/AIM_IC_t1_validation_0_with_holes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/AIM_IC_t1_validation_0_with_holes.png
--------------------------------------------------------------------------------
/examples/ablation_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/ablation_study.png
--------------------------------------------------------------------------------
/examples/arc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/arc.png
--------------------------------------------------------------------------------
/examples/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/architecture.png
--------------------------------------------------------------------------------
/examples/comparisons_ffhq_oxford.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/comparisons_ffhq_oxford.png
--------------------------------------------------------------------------------
/examples/problem_of_interest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/problem_of_interest.png
--------------------------------------------------------------------------------
/examples/spd_resnetblk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/spd_resnetblk.png
--------------------------------------------------------------------------------
/examples/visualization_seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/visualization_seg.png
--------------------------------------------------------------------------------
/models/__pycache__/BoundaryVAE_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/BoundaryVAE_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/BoundaryVAE_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/BoundaryVAE_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/base_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/base_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/base_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/base_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models_pretrain.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/models_pretrain.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/networks.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/our_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/our_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import sys
5 |
6 | class BaseModel(nn.Module):
7 | def name(self):
8 | return 'BaseModel'
9 |
10 | def initialize(self, opt):
11 | self.opt = opt
12 | self.gpu_ids = opt.gpu_ids
13 | self.isTrain = opt.isTrain
14 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
16 |
17 | def set_input(self, input):
18 | self.input = input
19 |
20 | def forward(self):
21 | pass
22 |
23 | # used in test time, no backprop
24 | def test(self):
25 | pass
26 |
27 | def get_image_paths(self):
28 | pass
29 |
30 | def optimize_parameters(self):
31 | pass
32 |
33 | def get_current_visuals(self):
34 | return self.input
35 |
36 | def get_current_errors(self):
37 | return {}
38 |
39 | def save(self, label):
40 | pass
41 |
42 | # helper saving function that can be used by subclasses
43 | def save_network(self, network, network_label, epoch_label, gpu_ids):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | torch.save(network.cpu().state_dict(), save_path)
47 | if len(gpu_ids) and torch.cuda.is_available():
48 | network.cuda()
49 |
50 | # helper loading function that can be used by subclasses
51 | def load_network(self, network, network_label, epoch_label, save_dir=''):
52 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
53 | if not save_dir:
54 | save_dir = self.save_dir
55 | save_path = os.path.join(save_dir, save_filename)
56 | if not os.path.isfile(save_path):
57 | print('%s does not exist yet!' % save_path)
58 | if network_label == 'G':
59 | raise RuntimeError('Generator must exist!')
60 | else:
61 | try:
62 | network.load_state_dict(torch.load(save_path))
63 | except:
64 | pretrained_dict = torch.load(save_path)
65 | model_dict = network.state_dict()
66 | try:
67 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
68 | network.load_state_dict(pretrained_dict)
69 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
70 | except:
71 | print('Pretrained network %s has fewer layers; The following are not initialized: ' % network_label)
72 | for k, v in pretrained_dict.items():
73 | if v.size() == model_dict[k].size():
74 | model_dict[k] = v
75 |
76 | if sys.version_info >= (3,0):
77 | not_initialized = set()
78 | else:
79 | from sets import Set
80 | not_initialized = Set()
81 |
82 | for k, v in model_dict.items():
83 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
84 | not_initialized.add(k.split('.')[0])
85 |
86 | print(sorted(not_initialized))
87 | network.load_state_dict(model_dict)
88 |
89 | def update_learning_rate(self):
90 | pass
91 |
92 |
--------------------------------------------------------------------------------
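
For orientation, a short illustrative snippet of the checkpoint naming convention used by `save_network`/`load_network` above; the labels are examples, not repository defaults:

```python
import os

epoch_label, network_label = 'latest', 'G'                  # example labels only
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
# checkpoints are stored under os.path.join(opt.checkpoints_dir, opt.name),
# e.g. './checkpoints/gin/latest_net_G.pth' for a model named 'gin'
print(os.path.join('./checkpoints', 'gin', save_filename))
```
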
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def create_model(opt):
4 | if opt.model == 'Ours':
5 | from .our_model import OurModel, InferenceModel
6 | if opt.isTrain:
7 | model = OurModel()
8 | else:
9 | model = InferenceModel()
10 | else:
11 | print('Please define your model [%s]!' % opt.model)
12 | model.initialize(opt)
13 | print("model [%s] was created" % (model.name()))
14 | num_params_G, num_params_D = model.get_num_params()
15 |
16 | if opt.isTrain and len(opt.gpu_ids):
17 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
18 |
19 | return model, num_params_G, num_params_D
20 |
21 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import functools
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 |
10 | ############################################################
11 | ### Functions
12 | ############################################################
13 | def weights_init(m):
14 | classname = m.__class__.__name__
15 | if hasattr(m, 'weight') and classname.find('Conv2d') != -1:
16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 | m.weight.data *= 0.1
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif classname.find('BatchNorm2d') != -1:
21 | m.weight.data.normal_(1.0, 0.02)
22 | m.bias.data.fill_(0)
23 | elif classname.find('ConvTranspose2d') != -1:
24 | m.weight.data.normal_(0.0, 0.02)
25 | if m.bias is not None:
26 | m.bias.data.zero_()
27 | elif classname.find('Linear') != -1:
28 | m.weight.data.normal_(0.0, 0.01)
29 | if m.bias is not None:
30 | m.bias.data.zero_()
31 |
32 | def get_norm_layer(norm_type='instance'):
33 | if norm_type == 'batch':
34 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
35 | elif norm_type == 'instance':
36 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
37 | else:
38 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
39 | return norm_layer
40 |
41 | def print_network(net):
42 | num_params = 0
43 | for param in net.parameters():
44 | num_params += param.numel()
45 | print(net)
46 | print('Total number of parameters: %d' % num_params)
47 | print('--------------------------------------------------------------')
48 | return num_params
49 |
50 | def define_G(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=9, norm='instance', gpu_ids=[]):
51 | netG = ImageTinker2(input_nc, output_nc, ngf=64, n_downsampling=4, n_blocks=4, norm_layer=nn.BatchNorm2d, pad_type='replicate')
52 |
53 | num_params = print_network(netG)
54 |
55 | if len(gpu_ids) > 0:
56 | assert(torch.cuda.is_available())
57 | netG.cuda(gpu_ids[0])
58 | netG.apply(weights_init)
59 |
60 | return netG, num_params
61 |
62 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
63 | norm_layer = get_norm_layer(norm_type=norm)
64 | netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
65 | num_params = print_network(netD)
66 |
67 | if len(gpu_ids) > 0:
68 | assert(torch.cuda.is_available())
69 | netD.cuda(gpu_ids[0])
70 | netD.apply(weights_init)
71 |
72 | return netD, num_params
73 |
74 | class ImageTinker2(nn.Module):
75 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=4, n_blocks=4, norm_layer=nn.InstanceNorm2d, pad_type='replicate', activation=nn.LeakyReLU(0.2, True)):
76 | assert(n_blocks >= 0)
77 | super(ImageTinker2, self).__init__()
78 |
79 | if pad_type == 'reflect':
80 | self.pad = nn.ReflectionPad2d
81 | elif pad_type == 'zero':
82 | self.pad = nn.ZeroPad2d
83 | elif pad_type == 'replicate':
84 | self.pad = nn.ReplicationPad2d
85 |
86 | # LR coarse tinker (encoder)
87 | lr_coarse_tinker = [self.pad(3), nn.Conv2d(input_nc, ngf // 2, kernel_size=7, stride=1, padding=0), activation]
88 | lr_coarse_tinker += [self.pad(1), nn.Conv2d(ngf // 2, ngf, kernel_size=4, stride=2, padding=0), activation]
89 | lr_coarse_tinker += [self.pad(1), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=0), activation]
90 | lr_coarse_tinker += [self.pad(1), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=0), activation]
91 | # bottle neck
92 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)]
93 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)]
94 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)]
95 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)]
96 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)]
97 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)]
98 | # decoder
99 | lr_coarse_tinker += [nn.UpsamplingBilinear2d(scale_factor=2), self.pad(1), nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=0), activation]
100 | lr_coarse_tinker += [nn.UpsamplingBilinear2d(scale_factor=2), self.pad(1), nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=0), activation]
101 | lr_coarse_tinker += [nn.UpsamplingBilinear2d(scale_factor=2), self.pad(1), nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=0), activation]
102 | lr_coarse_tinker += [self.pad(3), nn.Conv2d(ngf // 2, output_nc, kernel_size=7, stride=1, padding=0)]
103 |         ### produces the coarse prediction (256x256x3)
104 | self.lr_coarse_tinker = nn.Sequential(*lr_coarse_tinker)
105 |
106 | self.r_en_padd1 = self.pad(3)
107 | self.r_en_conv1 = nn.Conv2d(input_nc, ngf // 2, kernel_size=7, stride=1, padding=0)
108 | self.r_en_acti1 = activation
109 |
110 | self.r_en_padd2 = self.pad(1)
111 | self.r_en_conv2 = nn.Conv2d(ngf // 2, ngf, kernel_size=4, stride=2, padding=0)
112 | self.r_en_acti2 = activation
113 |
114 | self.r_en_padd3 = self.pad(1)
115 | self.r_en_conv3 = nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=0)
116 | self.r_en_acti3 = activation
117 |
118 | self.r_en_skp_padd3 = self.pad(1)
119 | self.r_en_skp_conv3 = nn.Conv2d(ngf * 2, ngf * 2 // 2, kernel_size=3, stride=1, padding=0)
120 | self.r_en_skp_acti3 = activation
121 |
122 | self.r_en_padd4 = self.pad(1)
123 | self.r_en_conv4 = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=0)
124 | self.r_en_acti4 = activation
125 |
126 | self.r_en_skp_padd4 = self.pad(1)
127 | self.r_en_skp_conv4 = nn.Conv2d(ngf * 4, ngf * 4 // 2, kernel_size=3, stride=1, padding=0)
128 | self.r_en_skp_acti4 = activation
129 |
130 | self.r_en_padd5 = self.pad(1)
131 | self.r_en_conv5 = nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=0)
132 | self.r_en_acti5 = activation
133 |
134 | self.r_md_mres1 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)
135 | self.r_md_mres2 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)
136 | self.r_md_mres5 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)
137 | self.r_md_satn1 = NonLocalBlock(ngf * 8, sub_sample=False, bn_layer=False)
138 | self.r_md_mres3 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)
139 | self.r_md_mres4 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)
140 | self.r_md_mres6 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)
141 |
142 | self.r_de_upbi1 = nn.UpsamplingBilinear2d(scale_factor=2)
143 | self.r_de_padd1 = self.pad(1)
144 | self.r_de_conv1 = nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=0)
145 | self.r_de_acti1 = activation
146 |
147 | self.r_de_satn2 = NonLocalBlock(ngf * 4 // 2, sub_sample=False, bn_layer=False)
148 | self.r_de_satn3 = NonLocalBlock(ngf * 2 // 2, sub_sample=False, bn_layer=False)
149 |
150 | self.r_de_mix_padd1 = self.pad(1)
151 | self.r_de_mix_conv1 = nn.Conv2d(ngf * 4 + ngf * 4 // 2, ngf * 4, kernel_size=3, stride=1, padding=0)
152 | self.r_de_mix_acti1 = activation
153 |
154 | self.r_de_upbi2 = nn.UpsamplingBilinear2d(scale_factor=2)
155 | self.r_de_padd2 = self.pad(1)
156 | self.r_de_conv2 = nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=0)
157 | self.r_de_acti2 = activation
158 |
159 | self.r_de_mix_padd2 = self.pad(1)
160 | self.r_de_mix_conv2 = nn.Conv2d(ngf * 2 + ngf * 2 // 2, ngf * 2, kernel_size=3, stride=1, padding=0)
161 | self.r_de_mix_acti2 = activation
162 |
163 | self.r_de_padd2_lr = self.pad(1)
164 | self.r_de_conv2_lr = nn.Conv2d(ngf * 2, ngf // 2, kernel_size=3, stride=1, padding=0)
165 | self.r_de_acti2_lr = activation
166 |
167 | self.r_de_padd3_lr = self.pad(1)
168 | self.r_de_conv3_lr = nn.Conv2d(ngf // 2, output_nc, kernel_size=3, stride=1, padding=0)
169 |
170 | self.r_de_upbi3 = nn.UpsamplingBilinear2d(scale_factor=2)
171 | self.r_de_padd3 = self.pad(1)
172 | self.r_de_conv3 = nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=0)
173 | self.r_de_acti3 = activation
174 |
175 | self.r_de_upbi4 = nn.UpsamplingBilinear2d(scale_factor=2)
176 | self.r_de_padd4 = self.pad(1)
177 | self.r_de_conv4 = nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=0)
178 | self.r_de_acti4 = activation
179 |
180 | self.r_de_padd5 = self.pad(3)
181 | self.r_de_conv5 = nn.Conv2d(ngf // 2, output_nc, kernel_size=7, stride=1, padding=0)
182 |
183 | self.r_de_padd5_lr_alpha = self.pad(1)
184 | self.r_de_conv5_lr_alpha = nn.Conv2d(ngf // 2, 1, kernel_size=3, stride=1, padding=0)
185 | self.r_de_acti5_lr_alpha = nn.Sigmoid()
186 |
187 | self.up = nn.UpsamplingBilinear2d(scale_factor=4)
188 | self.down = nn.UpsamplingBilinear2d(scale_factor=0.25)
189 |
190 | def forward(self, msked_img, msk, real_img=None):
191 | if real_img is not None:
192 | rimg = real_img
193 | inp = real_img * (1 - msk) + msk
194 | else:
195 | rimg = msked_img
196 | inp = msked_img
197 |
198 | x = torch.cat((inp, msk), dim=1)
199 | lr_x = self.lr_coarse_tinker(x)
200 | hr_x = lr_x * msk + rimg * (1 - msk)
201 |
202 | y = torch.cat((hr_x, msk), dim=1)
203 | e1 = self.r_en_acti1(self.r_en_conv1(self.r_en_padd1(y)))
204 | e2 = self.r_en_acti2(self.r_en_conv2(self.r_en_padd2(e1)))
205 | e3 = self.r_en_acti3(self.r_en_conv3(self.r_en_padd3(e2)))
206 | e4 = self.r_en_acti4(self.r_en_conv4(self.r_en_padd4(e3)))
207 | e5 = self.r_en_acti5(self.r_en_conv5(self.r_en_padd5(e4)))
208 |
209 | skp_e3 = self.r_en_skp_acti3(self.r_en_skp_conv3(self.r_en_skp_padd3(e3)))
210 | skp_e4 = self.r_en_skp_acti4(self.r_en_skp_conv4(self.r_en_skp_padd4(e4)))
211 |
212 | de3 = self.r_de_satn3(skp_e3)
213 | de4 = self.r_de_satn2(skp_e4)
214 |
215 | m1 = self.r_md_mres1(e5)
216 | m2 = self.r_md_mres2(m1)
217 | m5 = self.r_md_mres5(m2)
218 | a1 = self.r_md_satn1(m5)
219 | m3 = self.r_md_mres3(a1)
220 | m4 = self.r_md_mres4(m3)
221 | m6 = self.r_md_mres6(m4)
222 |
223 | d1 = self.r_de_acti1(self.r_de_conv1((self.r_de_padd1(self.r_de_upbi1(m6))))) # 32x32x256
224 | cat1 = torch.cat((d1, de4), dim=1)
225 | md1 = self.r_de_mix_acti1(self.r_de_mix_conv1(self.r_de_mix_padd1(cat1)))
226 |
227 | d2 = self.r_de_acti2(self.r_de_conv2((self.r_de_padd2(self.r_de_upbi2(md1))))) # 64x64x128
228 | cat2 = torch.cat((d2, de3), dim=1)
229 | md2 = self.r_de_mix_acti2(self.r_de_mix_conv2(self.r_de_mix_padd2(cat2)))
230 |
231 | d2_lr = self.r_de_acti2_lr(self.r_de_conv2_lr(self.r_de_padd2_lr(md2)))
232 | d3_lr = self.r_de_conv3_lr(self.r_de_padd3_lr(d2_lr))
233 |
234 | d3 = self.r_de_acti3(self.r_de_conv3((self.r_de_padd3(self.r_de_upbi3(md2))))) # 128x128x64
235 | d4 = self.r_de_acti4(self.r_de_conv4((self.r_de_padd4(self.r_de_upbi4(d3))))) # 256x256x32
236 |
237 | d5 = self.r_de_conv5(self.r_de_padd5(d4))
238 | d5_lr_alpha = self.r_de_acti5_lr_alpha(self.r_de_conv5_lr_alpha(self.r_de_padd5_lr_alpha(d4)))
239 |
240 | ###
241 | # d5: 256x256x3
242 | # d5_lr_alpha: 256x256x1
243 | # d3_lr: 64x64x3
244 | ###
245 | lr_img = d3_lr
246 |
247 | #reconst_img = d5
248 | d5 = d5 * msk + rimg * (1 - msk)
249 | lr_d5 = self.down(d5)
250 | lr_d5_res = d3_lr - lr_d5
251 | hr_d5_res = self.up(lr_d5_res)
252 | reconst_img = d5 + hr_d5_res * d5_lr_alpha
253 | compltd_img = reconst_img * msk + rimg * (1 - msk)
254 | #out = compltd_img + hr_d5_res * d5_lr_alpha
255 |
256 | return compltd_img, reconst_img, lr_x, lr_img
257 |
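For reference, the last block of forward above first composites the full-resolution output d5 with the known pixels, then corrects it with an upsampled residual taken from the auxiliary low-resolution branch d3_lr, gated by the predicted alpha map:

    r = \mathrm{Up}_{\times 4}\big( d3\_lr - \mathrm{Down}_{\times 4}(d5) \big)
    \mathrm{reconst\_img} = d5 + \alpha \odot r
    \mathrm{compltd\_img} = m \odot \mathrm{reconst\_img} + (1 - m) \odot \mathrm{real\_img}

Here Up/Down are the bilinear self.up / self.down layers, alpha is d5_lr_alpha, and m is the hole mask.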
258 |
259 | ############################################################
260 | ### Losses
261 | ############################################################
262 | class TVLoss(nn.Module):
263 | def forward(self, x):
264 | batch_size = x.size()[0]
265 | h_x = x.size()[2]
266 | w_x = x.size()[3]
267 | count_h = self._tensor_size(x[:, :, 1:, :])
268 | count_w = self._tensor_size(x[:, :, :, 1:])
269 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
270 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
271 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
272 |
273 | def _tensor_size(self, t):
274 | return t.size()[1] * t.size()[2] * t.size()[3]
275 |
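For reference, TVLoss above is the squared anisotropic total variation, normalized by the number of difference terms and the batch size:

    \mathcal{L}_{TV}(x) = \frac{2}{B}\Big( \frac{1}{N_h}\sum_{b,c,i,j}\big(x_{b,c,i+1,j}-x_{b,c,i,j}\big)^2 + \frac{1}{N_w}\sum_{b,c,i,j}\big(x_{b,c,i,j+1}-x_{b,c,i,j}\big)^2 \Big)

with N_h = C(H-1)W and N_w = CH(W-1), exactly as computed by _tensor_size.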
276 | class VGGLoss(nn.Module):
277 | # vgg19 perceptual loss
278 | def __init__(self, gpu_ids):
279 | super(VGGLoss, self).__init__()
280 | self.vgg = Vgg19().cuda()
281 | self.criterion = nn.L1Loss()
282 | self.mse_loss = nn.MSELoss()
283 |
284 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
285 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
286 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
287 | self.register_buffer('mean', mean)
288 | self.register_buffer('std', std)
289 |
290 | def gram_matrix(self, x):
291 | (b, ch, h, w) = x.size()
292 | features = x.view(b, ch, w*h)
293 | features_t = features.transpose(1, 2)
294 | gram = features.bmm(features_t) / (ch * h * w)
295 | return gram
296 |
297 | def forward(self, x, y):
298 | x = (x - self.mean) / self.std
299 | y = (y - self.mean) / self.std
300 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
301 |
302 | loss = 0
303 | style_loss = 0
304 | for i in range(len(x_vgg)):
305 | loss += self.weights[i] * \
306 | self.criterion(x_vgg[i], y_vgg[i].detach())
307 | gm_x = self.gram_matrix(x_vgg[i])
308 | gm_y = self.gram_matrix(y_vgg[i])
309 | style_loss += self.weights[i] * self.mse_loss(gm_x, gm_y.detach())
310 | return loss, style_loss
311 |
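In VGGLoss.forward, the perceptual term is a weighted L1 distance between VGG-19 activations and the style term compares Gram matrices of the same activations:

    G(F) = \frac{F F^{\top}}{C \cdot H \cdot W}, \quad F \in \mathbb{R}^{B \times C \times HW}
    \mathcal{L}_{perc} = \sum_i w_i \, \lVert \phi_i(x) - \phi_i(y) \rVert_1
    \mathcal{L}_{style} = \sum_i w_i \, \lVert G(\phi_i(x)) - G(\phi_i(y)) \rVert_2^2

where \phi_i are the five VGG-19 slices returned by Vgg19 below and w = [1/32, 1/16, 1/8, 1/4, 1].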
312 | class GANLoss_D_v2(nn.Module):
313 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):
314 | super(GANLoss_D_v2, self).__init__()
315 | self.real_label = target_real_label
316 | self.fake_label = target_fake_label
317 | self.real_label_var = None
318 | self.fake_label_var = None
319 | self.Tensor = tensor
320 | if use_lsgan:
321 | self.loss = nn.MSELoss()
322 | else:
323 | def wgan_loss(input, target):
324 | return torch.mean(F.relu(1.-input)) if target else torch.mean(F.relu(1.+input))
325 | self.loss = wgan_loss
326 |
327 | def get_target_tensor(self, input, target_is_real):
328 | target_tensor = None
329 | if target_is_real:
330 | create_label = ((self.real_label_var is None) or (self.real_label_var.numel() != input.numel()))
331 | if create_label:
332 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
333 | self.real_label_var = Variable(real_tensor, requires_grad=False)
334 | target_tensor = self.real_label_var
335 | else:
336 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel()))
337 | if create_label:
338 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
339 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
340 | target_tensor = self.fake_label_var
341 | return target_tensor
342 |
343 | def __call__(self, input, target_is_real):
344 | if isinstance(input[0], list):
345 | loss = 0
346 | for input_i in input:
347 | pred = input_i[-1]
348 | target_tensor = self.get_target_tensor(pred, target_is_real)
349 | #loss += self.loss(pred, target_tensor)
350 |                 loss += self.loss(pred, target_is_real)  # the boolean target only works for the hinge (wgan_loss) path; the lsgan path would need target_tensor
351 | return loss
352 | else:
353 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
354 | return self.loss(input[-1], target_tensor)
355 |
356 | class GANLoss_G_v2(nn.Module):
357 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):
358 | super(GANLoss_G_v2, self).__init__()
359 | self.real_label = target_real_label
360 | self.fake_label = target_fake_label
361 | self.real_label_var = None
362 | self.fake_label_var = None
363 | self.Tensor = tensor
364 | if use_lsgan:
365 | self.loss = nn.MSELoss()
366 | else:
367 | def wgan_loss(input, target):
368 | return -1 * input.mean() if target else input.mean()
369 | self.loss = wgan_loss
370 |
371 | def get_target_tensor(self, input, target_is_real):
372 | target_tensor = None
373 | if target_is_real:
374 | create_label = ((self.real_label_var is None) or (self.real_label_var.numel() != input.numel()))
375 | if create_label:
376 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
377 | self.real_label_var = Variable(real_tensor, requires_grad=False)
378 | target_tensor = self.real_label_var
379 | else:
380 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel()))
381 | if create_label:
382 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
383 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
384 | target_tensor = self.fake_label_var
385 | return target_tensor
386 |
387 | def __call__(self, input, target_is_real):
388 | if isinstance(input[0], list):
389 | loss = 0
390 | for input_i in input:
391 | pred = input_i[-1]
392 | target_tensor = self.get_target_tensor(pred, target_is_real)
393 | #loss += self.loss(pred, target_tensor)
394 |                 loss += self.loss(pred, target_is_real)  # as in GANLoss_D_v2, the boolean target only works for the hinge (wgan_loss) path
395 | return loss
396 | else:
397 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
398 | return self.loss(input[-1], target_tensor)
399 |
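Although the inner function is named wgan_loss, the non-lsgan path of the two classes above (the default, since --no_lsgan defaults to True in train_options.py) implements the hinge objective:

    \mathcal{L}_D = \mathbb{E}\big[\mathrm{relu}(1 - D(x_{real}))\big] + \mathbb{E}\big[\mathrm{relu}(1 + D(x_{fake}))\big]
    \mathcal{L}_G = -\,\mathbb{E}\big[D(x_{fake})\big]

With use_lsgan=True, both classes instead fall back to an MSE loss against the real/fake target tensors.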
400 |
401 | # Define the PatchGAN discriminator with the specified arguments.
402 | class NLayerDiscriminator(nn.Module):
403 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d, use_sigmoid=False, getIntermFeat=False):
404 | super(NLayerDiscriminator, self).__init__()
405 | self.getIntermFeat = getIntermFeat
406 | self.n_layers = n_layers
407 |
408 | kw = 4
409 | padw = int(np.ceil((kw-1.0)/2))
410 | sequence = [[SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]]
411 |
412 | nf = ndf
413 | for n in range(1, n_layers):
414 | nf_prev = nf
415 | nf = min(nf * 2, 512)
416 | sequence += [[
417 | SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
418 | nn.LeakyReLU(0.2, True)
419 | ]]
420 |
421 | nf_prev = nf
422 | nf = min(nf * 2, 512)
423 | sequence += [[
424 | SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw)),
425 | nn.LeakyReLU(0.2, True)
426 | ]]
427 |
428 | sequence += [[SpectralNorm(nn.Conv2d(nf, nf, kernel_size=kw, stride=1, padding=padw))]]
429 |
430 | # if use_sigmoid:
431 | # sequence += [[nn.Sigmoid()]]
432 |
433 | if getIntermFeat:
434 | for n in range(len(sequence)):
435 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
436 | else:
437 | sequence_stream = []
438 | for n in range(len(sequence)):
439 | sequence_stream += sequence[n]
440 | self.model = nn.Sequential(*sequence_stream)
441 |
442 | def forward(self, input):
443 | if self.getIntermFeat:
444 | res = [input]
445 | for n in range(self.n_layers + 2):
446 | model = getattr(self, 'model'+str(n))
447 | res.append(model(res[-1]))
448 | return res[1:]
449 | else:
450 | return self.model(input)
451 |
452 |
453 | # Define the Multiscale Discriminator.
454 | class MultiscaleDiscriminator(nn.Module):
455 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, num_D=3, getIntermFeat=False):
456 | super(MultiscaleDiscriminator, self).__init__()
457 | self.num_D = num_D
458 | self.n_layers = n_layers
459 | self.getIntermFeat = getIntermFeat
460 |
461 | for i in range(num_D):
462 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
463 | if getIntermFeat:
464 | for j in range(n_layers+2):
465 | setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
466 | else:
467 | setattr(self, 'layer'+str(i), netD.model)
468 |
469 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
470 |
471 | def singleD_forward(self, model, input):
472 | if self.getIntermFeat:
473 | result = [input]
474 | for i in range(len(model)):
475 | result.append(model[i](result[-1]))
476 | return result[1:]
477 | else:
478 | return [model(input)]
479 |
480 | def forward(self, input):
481 | num_D = self.num_D
482 | result = []
483 | input_downsampled = input
484 | for i in range(num_D):
485 | if self.getIntermFeat:
486 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
487 | else:
488 | model = getattr(self, 'layer'+str(num_D-1-i))
489 | result.append(self.singleD_forward(model, input_downsampled))
490 | if i != (num_D-1):
491 | input_downsampled = self.downsample(input_downsampled)
492 | return result
493 |
494 |
495 | ### Define Vgg19 for vgg_loss
496 | class Vgg19(nn.Module):
497 | def __init__(self, requires_grad=False):
498 | super(Vgg19, self).__init__()
499 | vgg_pretrained_features = models.vgg19(pretrained=True).features
500 | self.slice1 = nn.Sequential()
501 | self.slice2 = nn.Sequential()
502 | self.slice3 = nn.Sequential()
503 | self.slice4 = nn.Sequential()
504 | self.slice5 = nn.Sequential()
505 |
506 | for x in range(1):
507 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
508 | for x in range(1, 6):
509 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
510 | for x in range(6, 11):
511 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
512 | for x in range(11, 20):
513 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
514 | for x in range(20, 29):
515 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
516 |
517 | # fixed pretrained vgg19 model for feature extraction
518 | if not requires_grad:
519 | for param in self.parameters():
520 | param.requires_grad = False
521 |
522 | def forward(self, x):
523 | h_relu1 = self.slice1(x)
524 | h_relu2 = self.slice2(h_relu1)
525 | h_relu3 = self.slice3(h_relu2)
526 | h_relu4 = self.slice4(h_relu3)
527 | h_relu5 = self.slice5(h_relu4)
528 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
529 | return out
530 |
531 | ### Multi-Dilation ResnetBlock
532 | class MultiDilationResnetBlock(nn.Module):
533 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False):
534 | super(MultiDilationResnetBlock, self).__init__()
535 |
536 | self.branch1 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=2, dilation=2, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
537 | self.branch2 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=3, dilation=3, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
538 | self.branch3 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=4, dilation=4, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
539 | self.branch4 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=5, dilation=5, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
540 | self.branch5 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=6, dilation=6, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
541 | self.branch6 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=8, dilation=8, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
542 | self.branch7 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=10, dilation=10, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
543 | self.branch8 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
544 |
545 | self.fusion9 = ConvBlock(input_nc, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type=pad_type, norm=norm, acti=None)
546 |
547 | def forward(self, x):
548 | d1 = self.branch1(x)
549 | d2 = self.branch2(x)
550 | d3 = self.branch3(x)
551 | d4 = self.branch4(x)
552 | d5 = self.branch5(x)
553 | d6 = self.branch6(x)
554 | d7 = self.branch7(x)
555 | d8 = self.branch8(x)
556 | d9 = torch.cat((d1, d2, d3, d4, d5, d6, d7, d8), dim=1)
557 | out = x + self.fusion9(d9)
558 | return out
559 |
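Each branch above is a 3x3 convolution with a different dilation rate d; a dilated kernel of size k covers an extent of (k-1)d + 1 pixels, so

    \mathrm{extent}(3, d) = 2d + 1 \;\Rightarrow\; 3, 5, 7, 9, 11, 13, 17, 21 \text{ for } d = 1, 2, 3, 4, 5, 6, 8, 10

giving the eight parallel branches contexts from 3x3 up to 21x21 before the 3x3 fusion convolution and the residual addition.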
560 | ### Multi-Dilation ResnetBlock
561 | class MultiDilationResnetBlock_v3(nn.Module):
562 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False):
563 | super(MultiDilationResnetBlock_v3, self).__init__()
564 |
565 | self.branch1 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=2, dilation=2, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
566 | self.branch2 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=3, dilation=3, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
567 | self.branch3 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=4, dilation=4, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
568 | self.branch4 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=5, dilation=5, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu')
569 |
570 | self.fusion5 = ConvBlock(input_nc, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type=pad_type, norm=norm, acti=None)
571 |
572 | def forward(self, x):
573 | d1 = self.branch1(x)
574 | d2 = self.branch2(x)
575 | d3 = self.branch3(x)
576 | d4 = self.branch4(x)
577 | d5 = torch.cat((d1, d2, d3, d4), dim=1)
578 | out = x + self.fusion5(d5)
579 | return out
580 |
581 | ### ResnetBlock
582 | class ResnetBlock(nn.Module):
583 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False):
584 | super(ResnetBlock, self).__init__()
585 | self.conv_block = self.build_conv_block(input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout)
586 |
587 |
588 | def build_conv_block(self, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout):
589 | conv_block = []
590 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti='relu')]
591 | if use_dropout:
592 | conv_block += [nn.Dropout(0.5)]
593 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti=None)]
594 |
595 | return nn.Sequential(*conv_block)
596 |
597 | def forward(self, x):
598 | out = x + self.conv_block(x)
599 | return out
600 |
601 | ### ResnetBlock
602 | class ResnetBlock_v2(nn.Module):
603 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False):
604 | super(ResnetBlock_v2, self).__init__()
605 | self.conv_block = self.build_conv_block(input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout)
606 |
607 |
608 | def build_conv_block(self, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout):
609 | conv_block = []
610 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size=3, stride=1, padding=padding, dilation=dilation, groups=groups, bias=bias, pad_type=pad_type, norm=norm, acti='elu')]
611 | if use_dropout:
612 | conv_block += [nn.Dropout(0.5)]
613 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti=None)]
614 |
615 | return nn.Sequential(*conv_block)
616 |
617 | def forward(self, x):
618 | out = x + self.conv_block(x)
619 | return out
620 |
621 | ### NonLocalBlock2D
622 | class NonLocalBlock(nn.Module):
623 | def __init__(self, input_nc, inter_nc=None, sub_sample=True, bn_layer=True):
624 | super(NonLocalBlock, self).__init__()
625 | self.input_nc = input_nc
626 | self.inter_nc = inter_nc
627 |
628 | if inter_nc is None:
629 | self.inter_nc = input_nc // 2
630 |
631 | self.g = nn.Conv2d(in_channels=self.input_nc, out_channels=self.inter_nc, kernel_size=1, stride=1, padding=0)
632 |
633 | if bn_layer:
634 | self.W = nn.Sequential(
635 | nn.Conv2d(in_channels=self.inter_nc, out_channels=self.input_nc, kernel_size=1, stride=1, padding=0),
636 | nn.BatchNorm2d(self.input_nc)
637 | )
638 | self.W[0].weight.data.zero_()
639 | self.W[0].bias.data.zero_()
640 | else:
641 | self.W = nn.Conv2d(in_channels=self.inter_nc, out_channels=self.input_nc, kernel_size=1, stride=1, padding=0)
642 | self.W.weight.data.zero_()
643 | self.W.bias.data.zero_()
644 |
645 | self.theta = nn.Conv2d(in_channels=self.input_nc, out_channels=self.inter_nc, kernel_size=1, stride=1, padding=0)
646 | self.phi = nn.Conv2d(in_channels=self.input_nc, out_channels=self.inter_nc, kernel_size=1, stride=1, padding=0)
647 |
648 | if sub_sample:
649 |             self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
650 |             self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))
651 |
652 | def forward(self, x):
653 | batch_size = x.size(0)
654 |
655 | g_x = self.g(x).view(batch_size, self.inter_nc, -1)
656 | g_x = g_x.permute(0, 2, 1)
657 |
658 | theta_x = self.theta(x).view(batch_size, self.inter_nc, -1)
659 | theta_x = theta_x.permute(0, 2, 1)
660 |
661 | phi_x = self.phi(x).view(batch_size, self.inter_nc, -1)
662 |
663 | f = torch.matmul(theta_x, phi_x)
664 | f_div_C = F.softmax(f, dim=-1)
665 |
666 | y = torch.matmul(f_div_C, g_x)
667 | y = y.permute(0, 2, 1).contiguous()
668 | y = y.view(batch_size, self.inter_nc, *x.size()[2:])
669 | W_y = self.W(y)
670 |
671 | z = W_y + x
672 | return z
673 |
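NonLocalBlock follows the embedded-Gaussian non-local operation: affinities between all spatial positions are computed from the theta/phi embeddings, softmax-normalized, and used to aggregate the g features, with a residual connection on top:

    y_i = \sum_j \mathrm{softmax}_j\big(\theta(x_i)^{\top} \phi(x_j)\big)\, g(x_j), \qquad z_i = W y_i + x_i

Because W is zero-initialized, the block starts out as an identity mapping and only gradually mixes in the attention output during training.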
674 | ### ConvBlock
675 | class ConvBlock(nn.Module):
676 | def __init__(self, input_nc, output_nc, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, pad_type='zero', norm=None, acti='lrelu'):
677 | super(ConvBlock, self).__init__()
678 | self.use_bias = bias
679 |
680 | # initialize padding
681 | if pad_type == 'reflect':
682 | self.pad = nn.ReflectionPad2d(padding)
683 | elif pad_type == 'zero':
684 | self.pad = nn.ZeroPad2d(padding)
685 | elif pad_type == 'replicate':
686 | self.pad = nn.ReplicationPad2d(padding)
687 | else:
688 | assert 0, "Unsupported padding type: {}".format(pad_type)
689 |
690 | # initialize normalization
691 | if norm == 'batch':
692 | self.norm = nn.BatchNorm2d(output_nc)
693 | elif norm == 'instance':
694 | self.norm = nn.InstanceNorm2d(output_nc)
695 | elif norm is None or norm == 'spectral':
696 | self.norm = None
697 | else:
698 | assert 0, "Unsupported normalization: {}".format(norm)
699 |
700 | # initialize activation
701 | if acti == 'relu':
702 | self.acti = nn.ReLU(inplace=True)
703 | elif acti == 'lrelu':
704 | self.acti = nn.LeakyReLU(0.2, inplace=True)
705 | elif acti == 'prelu':
706 | self.acti = nn.PReLU()
707 | elif acti == 'elu':
708 | self.acti = nn.ELU()
709 | elif acti == 'tanh':
710 | self.acti = nn.Tanh()
711 | elif acti == 'sigmoid':
712 | self.acti = nn.Sigmoid()
713 | elif acti is None:
714 | self.acti = None
715 | else:
716 | assert 0, "Unsupported activation: {}".format(acti)
717 |
718 | # initialize convolution
719 | if norm == 'spectral':
720 | self.conv = SpectralNorm(nn.Conv2d(input_nc, output_nc, kernel_size, stride, dilation=dilation, groups=groups, bias=self.use_bias))
721 | else:
722 | self.conv = nn.Conv2d(input_nc, output_nc, kernel_size, stride, dilation=dilation, groups=groups, bias=self.use_bias)
723 |
724 | def forward(self, x):
725 | x = self.conv(self.pad(x))
726 | if self.norm:
727 | x = self.norm(x)
728 | if self.acti:
729 | x = self.acti(x)
730 | return x
731 |
732 | def l2normalize(v, eps=1e-12):
733 | return v / (v.norm() + eps)
734 |
735 | ### SpectralNorm
736 | class SpectralNorm(nn.Module):
737 | """
738 | Spectral Normalization for Generative Adversarial Networks
739 | Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
740 | """
741 | def __init__(self, module, name='weight', power_iterations=1):
742 | super(SpectralNorm, self).__init__()
743 | self.module = module
744 | self.name = name
745 | self.power_iterations = power_iterations
746 | if not self._made_params():
747 | self._make_params()
748 |
749 | def _update_u_v(self):
750 | u = getattr(self.module, self.name + "_u")
751 | v = getattr(self.module, self.name + "_v")
752 | w = getattr(self.module, self.name + "_bar")
753 |
754 | height = w.data.shape[0]
755 | for _ in range(self.power_iterations):
756 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
757 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
758 |
759 | sigma = u.dot(w.view(height, -1).mv(v))
760 | setattr(self.module, self.name, w / sigma.expand_as(w))
761 |
762 | def _made_params(self):
763 | try:
764 | u = getattr(self.module, self.name + "_u")
765 | v = getattr(self.module, self.name + "_v")
766 | w = getattr(self.module, self.name + "_bar")
767 | return True
768 | except AttributeError:
769 | return False
770 |
771 | def _make_params(self):
772 | w = getattr(self.module, self.name)
773 |
774 | height = w.data.shape[0]
775 | width = w.view(height, -1).data.shape[1]
776 |
777 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
778 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
779 | u.data = l2normalize(u.data)
780 | v.data = l2normalize(v.data)
781 | w_bar = nn.Parameter(w.data)
782 |
783 | del self.module._parameters[self.name]
784 |
785 | self.module.register_parameter(self.name + "_u", u)
786 | self.module.register_parameter(self.name + "_v", v)
787 | self.module.register_parameter(self.name + "_bar", w_bar)
788 |
789 | def forward(self, *args):
790 | self._update_u_v()
791 | return self.module.forward(*args)
792 |
793 |
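_update_u_v above runs one step of power iteration per forward pass to estimate the largest singular value of the flattened weight, then rescales the weight by it:

    v \leftarrow \frac{W^{\top} u}{\lVert W^{\top} u \rVert_2}, \qquad u \leftarrow \frac{W v}{\lVert W v \rVert_2}, \qquad \sigma(W) \approx u^{\top} W v, \qquad W_{SN} = \frac{W}{\sigma(W)}

where W is the stored weight_bar viewed as an (out_channels x rest) matrix; the normalized weight is written back onto the wrapped module before its forward is called.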
--------------------------------------------------------------------------------
/models/our_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import os
5 | from torch.autograd import Variable
6 | #from util.image_pool import ImagePool
7 | # BaseModel, save_network & load_network
8 | from .base_model import BaseModel
9 | from . import networks
10 |
11 | class OurModel(BaseModel):
12 | def name(self):
13 | return 'OurModel'
14 |
15 | def get_num_params(self):
16 | return self.num_params_G, self.num_params_D
17 |
18 | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
19 | flags = (True, True, True, True, use_gan_feat_loss, use_vgg_loss, True, True, True)
20 | def loss_filter(g_gan, g_coarse_l1, g_out_l1, g_style, g_gan_feat, g_vgg, g_tv, d_real, d_fake):
21 | return [l for (l, f) in zip((g_gan, g_coarse_l1, g_out_l1, g_style, g_gan_feat, g_vgg, g_tv, d_real, d_fake), flags) if f]
22 | return loss_filter
23 |
24 | def initialize(self, opt):
25 | BaseModel.initialize(self, opt)
26 | self.isTrain = opt.isTrain
27 | input_nc = opt.input_nc ### masked_img + mask (RGB + gray) 4 channels
28 |
29 | ### define networks
30 | # Generator
31 | netG_input_nc = input_nc
32 | self.netG, self.num_params_G = networks.define_G(netG_input_nc, opt.output_nc, ngf=64, gpu_ids=self.gpu_ids)
33 |
34 | # Discriminator
35 | if self.isTrain:
36 | use_sigmoid = opt.no_lsgan
37 | netD_input_nc = opt.output_nc * 2
38 | self.netD, self.num_params_D = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
39 | # for param in self.netD.parameters():
40 | # param.requires_grad = False
41 | else:
42 | self.num_params_D = 0
43 |
44 | print('-------------------- Networks initialized --------------------')
45 |
46 | ### load networks
47 | if not self.isTrain or opt.continue_train or opt.load_pretrain:
48 | pretrained_path = '' if not self.isTrain else opt.load_pretrain
49 | print(pretrained_path)
50 |
51 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
52 | if self.isTrain:
53 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
54 |
55 |
56 | ### set loss functions and optimizers
57 | if self.isTrain:
58 | if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
59 | raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
60 | # self.fake_pool = ImagePool(opt.pool_size)
61 | self.old_lr = opt.lr
62 |
63 | # define loss functions
64 | self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
65 |
66 | self.criterionGAN_D = networks.GANLoss_D_v2(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
67 | self.criterionGAN_G = networks.GANLoss_G_v2(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
68 |
69 | self.criterionFeat = torch.nn.L1Loss()
70 | if not opt.no_vgg_loss:
71 | self.criterionVGG = networks.VGGLoss(self.gpu_ids)
72 | self.criterionL1 = torch.nn.L1Loss()
73 | self.down = nn.UpsamplingBilinear2d(scale_factor=0.25)
74 | self.resize = nn.UpsamplingBilinear2d(size=(224, 224))
75 | self.criterionTV = networks.TVLoss()
76 |
77 |             # Names so we can break out each loss term
78 | self.loss_names = self.loss_filter('G_GAN', 'G_COARSE_L1', 'G_OUT_L1', 'G_STYLE', 'G_GAN_Feat', 'G_VGG', 'G_TV', 'D_real', 'D_fake')
79 |
80 | # initialize optimizers
81 | # optimizer G
82 | if opt.niter_fix_global > 0:
83 | import sys
84 | if sys.version_info >= (3,0):
85 | finetune_list = set()
86 | else:
87 | from sets import Set
88 | finetune_list = Set()
89 |
90 | params_dict = dict(self.netG.named_parameters())
91 | params = []
92 | for key, value in params_dict.items():
93 | if key.startswith('model' + str(opt.n_local_enhancers)):
94 | params += [value]
95 | finetune_list.add(key.split('.')[0])
96 | print('--------------- Only training the local enhancer network (for %d epochs) ---------------' % opt.niter_fix_global)
97 | print('The layers that are finetuned are ', sorted(finetune_list))
98 | else:
99 | params = list(self.netG.parameters())
100 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
101 |
102 | # optimizer D
103 | params = list(self.netD.parameters())
104 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr * 4.0, betas=(opt.beta1, 0.999))
105 |
106 | def encode_input(self, masked_image, mask, real_image, infer=False):
107 | input_msked_img = masked_image.data.cuda()
108 | input_msk = mask.data.cuda()
109 |
110 | input_msked_img = Variable(input_msked_img, volatile=infer)
111 | input_msk = Variable(input_msk, volatile=infer)
112 |
113 | real_image = Variable(real_image.data.cuda(), volatile=infer)
114 |
115 | return input_msked_img, input_msk, real_image
116 |
117 | def encode_input_test(self, masked_image, mask, infer=False):
118 | input_msked_img = masked_image.data.cuda()
119 | input_msk = mask.data.cuda()
120 |
121 | input_msked_img = Variable(input_msked_img)
122 | input_msk = Variable(input_msk)
123 |
124 | return input_msked_img, input_msk
125 |
126 | def discriminate(self, input_image, test_image, use_pool=False):
127 | input_concat = torch.cat((input_image, test_image.detach()), dim=1)
128 | return self.netD.forward(input_concat)
129 |
130 | def forward(self, masked_img, mask, real_img=None, infer=False, mode=None):
131 | # inference
132 | if mode == 'inference':
133 | compltd_img, reconst_img, lr_x = self.inference(masked_img, mask)
134 | return compltd_img, reconst_img, lr_x
135 |
136 | # Encode inputs
137 | input_msked_img, input_msk, real_image = self.encode_input(masked_img, mask, real_img)
138 |
139 | # Fake Generation
140 | compltd_img, reconst_img, lr_x, lr_img = self.netG.forward(input_msked_img, input_msk, real_image)
141 | lr_msk = input_msk
142 |
143 | msk025 = self.down(input_msk)
144 | img025 = self.down(real_image)
145 |
146 | # Fake Detection and Loss
147 | pred_fake = self.discriminate(input_msked_img, compltd_img)
148 | loss_D_fake = self.criterionGAN_D(pred_fake, False)
149 |
150 | # Real Detection and Loss
151 | pred_real = self.discriminate(input_msked_img, real_image)
152 | loss_D_real = self.criterionGAN_D(pred_real, True)
153 |
154 |         # GAN loss (fake passability loss)
155 | pred_fake = self.netD.forward(torch.cat((input_msked_img, compltd_img), dim=1))
156 | loss_G_GAN = self.criterionGAN_G(pred_fake, True)
157 | loss_G_GAN *= self.opt.lambda_gan
158 |
159 | # GAN L1 Loss
160 | loss_G_GAN_L1 = self.criterionL1(reconst_img * input_msk, real_image * input_msk) * self.opt.lambda_l1
161 | loss_G_GAN_L1 += self.criterionL1(reconst_img * (1 - input_msk), real_image * (1 - input_msk))
162 | loss_G_GAN_L1 += self.criterionL1(lr_img * msk025, img025 * msk025) * self.opt.lambda_l1
163 | loss_G_GAN_L1 += self.criterionL1(lr_img * (1 - msk025), img025 * (1 - msk025))
164 |
165 | loss_G_COARSE_L1 = self.criterionL1(lr_x * lr_msk, real_image * lr_msk) * self.opt.lambda_l1
166 | loss_G_COARSE_L1 += self.criterionL1(lr_x * (1 - lr_msk), real_image * (1 - lr_msk))
167 |
168 | loss_G_TV = self.criterionTV(compltd_img) * self.opt.lambda_tv
169 |
170 | # GAN feature matching loss
171 | loss_G_GAN_Feat = 0
172 | if not self.opt.no_ganFeat_loss:
173 | feat_weights = 4.0 / (self.opt.n_layers_D + 1)
174 | D_weights = 1.0 / self.opt.num_D
175 | for i in range(self.opt.num_D):
176 | for j in range(len(pred_fake[i]) - 1):
177 | loss_G_GAN_Feat += D_weights * feat_weights * self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
178 | loss_G_GAN_Feat *= self.opt.lambda_feat
179 |
180 | # VGG feature matching loss
181 |         loss_G_VGG, loss_G_VGGStyle = 0, 0  # initialize both so the return below is defined even when --no_vgg_loss is set
182 | if not self.opt.no_vgg_loss:
183 | resized_reconst_img = self.resize(reconst_img)
184 | resized_compltd_img = self.resize(compltd_img)
185 | resized_real_img = self.resize(real_image)
186 | loss_G_VGG, loss_G_VGGStyle = self.criterionVGG(resized_reconst_img, resized_real_img)
187 | loss_G_VGG2, loss_G_VGGStyle2 = self.criterionVGG(resized_compltd_img, resized_real_img)
188 | loss_G_VGG += loss_G_VGG2
189 | loss_G_VGGStyle += loss_G_VGGStyle2
190 | loss_G_VGG *= self.opt.lambda_vgg
191 | loss_G_VGGStyle *= self.opt.lambda_style
192 |
193 | # only return the fake image if necessary
194 | return [ self.loss_filter( loss_G_GAN, loss_G_COARSE_L1, loss_G_GAN_L1, loss_G_VGGStyle, loss_G_GAN_Feat, loss_G_VGG, loss_G_TV, loss_D_real, loss_D_fake ), None if not infer else compltd_img, reconst_img, lr_x ]
195 |
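Taken together, the generator-side terms returned by forward correspond, with the default weights in train_options.py and assuming they are simply summed by the training loop in train.py (which lies outside this excerpt), to:

    \mathcal{L}_G = \lambda_{gan}\mathcal{L}_{GAN} + \mathcal{L}_{L1}^{coarse} + \mathcal{L}_{L1}^{out} + \lambda_{style}\mathcal{L}_{style} + \lambda_{feat}\mathcal{L}_{feat} + \lambda_{vgg}\mathcal{L}_{perc} + \lambda_{tv}\mathcal{L}_{TV}

where each L1 term weights hole pixels by lambda_l1 = 5 and known pixels by 1, and lambda_gan = 0.001, lambda_style = 80, lambda_feat = 0.01, lambda_vgg = 0.05, lambda_tv = 0.1. The two discriminator terms (D_real, D_fake) form the usual \mathcal{L}_D described above in networks.py.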
196 | def inference(self, masked_img, mask):
197 | # Encode inputs
198 | input_msked_img, input_msk = self.encode_input_test(Variable(masked_img), Variable(mask), infer=True)
199 |
200 | # Fake Generation
201 | if torch.__version__.startswith('0.4'):
202 | with torch.no_grad():
203 | compltd_img, reconst_img, lr_x, lr_img = self.netG.forward(input_msked_img, input_msk)
204 | else:
205 | compltd_img, reconst_img, lr_x, lr_img = self.netG.forward(input_msked_img, input_msk)
206 | return compltd_img, reconst_img, lr_x
207 |
208 | def save(self, which_epoch):
209 | self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
210 | self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
211 |
212 | def update_fixed_params(self):
213 |         # after fixing the global generator for a number of epochs, also start finetuning it
214 | params = list(self.netG.parameters())
215 | self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
216 | print('----------------- Now also finetuning global generator -----------------')
217 |
218 | def update_learning_rate(self):
219 | lrd = self.opt.lr / self.opt.niter_decay
220 | lr = self.old_lr - lrd
221 | for param_group in self.optimizer_D.param_groups:
222 | param_group['lr'] = lr * 4.0
223 | for param_group in self.optimizer_G.param_groups:
224 | param_group['lr'] = lr
225 | print('update learning rate: %f -> %f' % (self.old_lr, lr))
226 | self.old_lr = lr
227 |
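update_learning_rate applies a fixed linear decay: each call subtracts lr / niter_decay from the current generator rate, and the discriminator is kept at four times that rate (matching the 4.0 factor used when optimizer_D is created). With the defaults lr = 1e-4 and niter_decay = 90, for example:

    lr_{t} = lr_{t-1} - \frac{10^{-4}}{90} \approx lr_{t-1} - 1.1\times 10^{-6}, \qquad lr_{D,t} = 4\,lr_{t}

so the generator rate reaches zero after niter_decay calls (typically one per epoch once the initial niter epochs at constant rate are done).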
228 | class InferenceModel(OurModel):
229 | def forward(self, inp1, inp2):
230 | masked_img = inp1
231 | mask = inp2
232 | return self.inference(masked_img, mask)
233 |
234 |
--------------------------------------------------------------------------------
/options/__pycache__/base_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/base_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/base_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/base_options.cpython-37.pyc
--------------------------------------------------------------------------------
/options/__pycache__/test_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/test_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/train_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/train_options.cpython-37.pyc
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | ####################################################################################
2 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
3 | # Licensed under the CC BY-NC-SA 4.0 license
4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
5 | ####################################################################################
6 | import argparse
7 | import os
8 | import torch
9 | from util import util
10 |
11 |
12 | class BaseOptions():
13 | def __init__(self):
14 | self.parser = argparse.ArgumentParser(description='Inpainting')
15 | self.initialized = False
16 |
17 | def initialize(self):
18 | # experiment specifics
19 | self.parser.add_argument('--name', type=str, default='experiment', help='name of the experiment. It decides where to store samples and models')
20 | self.parser.add_argument('--gpu_ids', type=str, default='0,1', help='gpu ids: e.g. 0 0,1,2 0,2. use -1 for cpu')
21 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
22 | self.parser.add_argument('--model', type=str, default='Ours', help='which model to use')
23 | self.parser.add_argument('--norm', type=str, default='instance', help='instance or batch normalization')
24 |
25 |
26 | # input/output sizes
27 | self.parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
28 | self.parser.add_argument('--loadSize', type=int, default=768, help='scale images to this size')
29 | self.parser.add_argument('--fineSize', type=int, default=256, help='image size to the model')
30 | self.parser.add_argument('--input_nc', type=int, default=4, help='# of input image channels')
31 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
32 |
33 | # for setting inputs
34 | self.parser.add_argument('--dataroot', type=str, default='./datasets/ade20k/')
35 | self.parser.add_argument('--resize_or_crop', type=str, default='standard', help='scaling and/or cropping of images at load time')
36 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, no shuffle')
37 | self.parser.add_argument('--no_flip', action='store_true', help='if true, no flip')
38 | self.parser.add_argument('--nThreads', type=int, default=6, help='# of threads for loading data')
39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='max # of images allowed per dataset')
40 |
41 | # for displays
42 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
43 |
44 | # for generator
45 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
46 | self.parser.add_argument('--ngf', type=int, default=64, help='# of generator filters in the first conv layer')
47 | self.parser.add_argument('--n_downsample_global', type=int, default=5, help='# of downsampling layers in netG')
48 | self.parser.add_argument('--n_blocks_global', type=int, default=6, help='# of resnet blocks in the global generator network')
49 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='# of resnet blocks in the local enhancer network')
50 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='# of local enhancers to use')
51 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='# of epochs that we only train the outmost local enhancer')
52 |
53 | self.initialized = True
54 |
55 | def parse(self, save=True):
56 | if not self.initialized:
57 | self.initialize()
58 | self.opt = self.parser.parse_args()
59 | self.opt.isTrain = self.isTrain
60 |
61 | str_ids = self.opt.gpu_ids.split(',')
62 | self.opt.gpu_ids = []
63 | for str_id in str_ids:
64 | id = int(str_id)
65 | if id >= 0:
66 | self.opt.gpu_ids.append(id)
67 |
68 | # set gpu ids
69 | if len(self.opt.gpu_ids) > 0:
70 | torch.cuda.set_device(self.opt.gpu_ids[0])
71 |
72 | args = vars(self.opt)
73 |
74 | # print options
75 | print('-------------------- Options --------------------')
76 | for k, v in sorted(args.items()):
77 | print('%s: %s' % (str(k), str(v)))
78 | print('---------------------- End ----------------------')
79 |
80 | # save the options to disk
81 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
82 | util.mkdirs(expr_dir)
83 |
84 | if save and not self.opt.continue_train:
85 | file_name = os.path.join(expr_dir, 'opt.txt')
86 | with open(file_name, 'wt') as opt_file:
87 | opt_file.write('-------------------- Options --------------------\n')
88 | for k, v in sorted(args.items()):
89 | opt_file.write('%s: %s\n' % (str(k), str(v)))
90 | opt_file.write('---------------------- End ----------------------\n')
91 |
92 | return self.opt
93 |
94 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TestOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 |
7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here')
8 | self.parser.add_argument('--phase', type=str, default='test', help='train or test')
9 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load?')
10 |
11 | self.isTrain = False
12 |
13 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TrainOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 |
7 | # for displays
8 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
9 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
10 | self.parser.add_argument('--save_latest_freq', type=int, default=3000, help='frequency of saving the latest results')
11 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
12 | self.parser.add_argument('--no_html', action='store_true', help='if true, do not save intermediate training results to web')
13 | self.parser.add_argument('--tf_log', action='store_true', help='if true, use tensorboard logging')
14 |
15 | # for training
16 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
17 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
18 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load?')
19 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, or test')
20 | self.parser.add_argument('--niter', type=int, default=10, help='# of iter at starting learning rate')
21 | self.parser.add_argument('--niter_decay', type=int, default=90, help='# of iter to linearly decay learning rate to zero')
22 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
23 | self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
24 |
25 | # for discriminators
26 | self.parser.add_argument('--num_D', type=int, default=2, help='# of discriminators to use')
27 | self.parser.add_argument('--n_layers_D', type=int, default=2, help='only used if which_model_netD==n_layers')
28 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discriminator filters in first conv layer')
29 | self.parser.add_argument('--lambda_feat', type=float, default=0.01, help='weight for feature matching loss')
30 | self.parser.add_argument('--lambda_vgg', type=float, default=0.05, help='weight for vgg feature matching loss')
31 | self.parser.add_argument('--lambda_l1', type=float, default=5.0, help='weight for l1 loss')
32 | self.parser.add_argument('--lambda_tv', type=float, default=0.1, help='weight for tv loss')
33 | self.parser.add_argument('--lambda_style', type=float, default=80.0, help='weight for style loss')
34 | self.parser.add_argument('--lambda_gan', type=float, default=0.001, help='weight for g gan loss')
35 | self.parser.add_argument('--no_ganFeat_loss', action='store_false', help='if specified, do *not* use discriminator feature matching loss')
36 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
37 | # use least square GAN
38 |         self.parser.add_argument('--no_lsgan', action='store_true', default=True, help='if true, do *not* use least-squares GAN (the hinge variant in networks.py is used instead)')
39 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
40 |
41 | self.isTrain = True
42 |
43 |
--------------------------------------------------------------------------------
/results/test/AIM_IC_t1_validation_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/results/test/AIM_IC_t1_validation_0.png
--------------------------------------------------------------------------------
/test_ensemble.py:
--------------------------------------------------------------------------------
1 | from util.visualizer import Visualizer
2 | import util.util as util
3 | from data.data_loader import CreateDataLoader
4 | from options.test_options import TestOptions
5 | import torchvision.transforms as transforms
6 | import time
7 | import os
8 | import numpy as np
9 | from PIL import Image
10 | import torch
11 | from torch.autograd import Variable
12 | from collections import OrderedDict
13 |
14 | from models.models import create_model
15 | from util import html
16 |
17 | import copy
18 |
19 | opt = TestOptions().parse(save=False)
20 | opt.nThreads = 1
21 | opt.batchSize = 1
22 | opt.serial_batches = True
23 | opt.no_flip = True
24 |
25 | data_loader = CreateDataLoader(opt)
26 | dataset = data_loader.load_data()
27 | dataset_size = len(data_loader)
28 |
29 | def __make_power_2(img, base=256, method=Image.BICUBIC):  # despite the name, this rounds each side to the nearest multiple of base
30 | ow, oh = img.size
31 |
32 | h = int(round(oh / base) * base)
33 | w = int(round(ow / base) * base)
34 |
35 | if h == 0:
36 | h = base
37 | if w == 0:
38 | w = base
39 |
40 | if (h == oh) and (w == ow):
41 | return img
42 | return img.resize((w, h), method)
43 |
44 | model, num_params_G, num_params_D = create_model(opt)
45 | model.eval()
46 |
47 | rlt_dir = os.path.join(opt.results_dir, 'test')
48 | util.mkdirs([rlt_dir])
49 |
50 | transform_list = []
51 | transform_list += [transforms.ToTensor()]
52 | mskimg_transform = transforms.Compose(transform_list)
53 | transform_list = []
54 | transform_list += [transforms.ToTensor()]
55 | msk_transform = transforms.Compose(transform_list)
56 |
57 | start_time = time.time()
58 | for i, data in enumerate(dataset):
59 | with torch.no_grad():
60 | msk_img_path = data['path_mskimg'][0]
61 | filename = os.path.basename(msk_img_path)
62 | msk_path = data['path_msk'][0]
63 |
64 | oimg = Image.open(msk_img_path).convert('RGB')
65 | omsk = Image.open(msk_path).convert('L')
66 | ow, oh = oimg.size
67 |
68 | ###
69 | resized_img = __make_power_2(oimg)
70 | resized_msk = __make_power_2(omsk, method=Image.BILINEAR)
71 | rw, rh = resized_img.size
72 |
73 | hori_ver = rw // 256
74 | vert_ver = rh // 256
75 |
76 | tmp_img = oimg.resize((256, 256), Image.BICUBIC)
77 | tmp_msk = omsk.resize((256, 256), Image.BICUBIC)
78 |
79 | np_tmp_img = np.array(tmp_img, np.uint8)
80 | np_tmp_msk = np.array(tmp_msk, np.uint8)
81 |
82 | np_resized_img = np.array(resized_img, np.uint8)
83 | np_resized_msk = np.array(resized_msk, np.uint8)
84 | np_resized_msk = np_resized_msk > 0
85 | np_resized_img[:,:,0] = np_resized_img[:,:,0] * (1 - np_resized_msk) + 255 * np_resized_msk
86 | np_resized_img[:,:,1] = np_resized_img[:,:,1] * (1 - np_resized_msk) + 255 * np_resized_msk
87 | np_resized_img[:,:,2] = np_resized_img[:,:,2] * (1 - np_resized_msk) + 255 * np_resized_msk
88 | np_resized_msk = np_resized_msk * 255
89 |
90 | img_arr = []
91 | msk_arr = []
92 |
93 | ###
94 | for hv in range(hori_ver):
95 | for vv in range(vert_ver):
96 | for i in range(256):
97 | for j in range(256):
98 | np_tmp_img[i, j, 0] = np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 0]
99 | np_tmp_img[i, j, 1] = np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 1]
100 | np_tmp_img[i, j, 2] = np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 2]
101 | np_tmp_msk[i, j] = np_resized_msk[vv + vert_ver*j, hv + hori_ver*i]
102 | img_arr.append(np.copy(np_tmp_img))
103 | msk_arr.append(np.copy(np_tmp_msk))
104 |
105 | ###
106 | compltd_arr = []
107 | for i in range(len(img_arr)):
108 | img = Image.fromarray(img_arr[i])
109 | msk = Image.fromarray(msk_arr[i])
110 |
111 | img_90 = img.rotate(90)
112 | msk_90 = msk.rotate(90)
113 | img_180 = img.rotate(180)
114 | msk_180 = msk.rotate(180)
115 | img_270 = img.rotate(270)
116 | msk_270 = msk.rotate(270)
117 | img_flp = img.transpose(method=Image.FLIP_LEFT_RIGHT)
118 | msk_flp = msk.transpose(method=Image.FLIP_LEFT_RIGHT)
119 |
120 | compltd_img, reconst_img, lr_x = model(mskimg_transform(img).unsqueeze(0), msk_transform(msk).unsqueeze(0))
121 | compltd_img_90, reconst_img_90, lr_x_90 = model(mskimg_transform(img_90).unsqueeze(0), msk_transform(msk_90).unsqueeze(0))
122 | compltd_img_180, reconst_img_180, lr_x_180 = model(mskimg_transform(img_180).unsqueeze(0), msk_transform(msk_180).unsqueeze(0))
123 | compltd_img_270, reconst_img_270, lr_x_270 = model(mskimg_transform(img_270).unsqueeze(0), msk_transform(msk_270).unsqueeze(0))
124 | compltd_img_flp, reconst_img_flp, lr_x_flp = model(mskimg_transform(img_flp).unsqueeze(0), msk_transform(msk_flp).unsqueeze(0))
125 | np_compltd_img = util.tensor2im(reconst_img.data[0], normalize=False)
126 | np_compltd_img_90 = util.tensor2im(reconst_img_90.data[0], normalize=False)
127 | np_compltd_img_180 = util.tensor2im(reconst_img_180.data[0], normalize=False)
128 | np_compltd_img_270 = util.tensor2im(reconst_img_270.data[0], normalize=False)
129 | np_compltd_img_flp = util.tensor2im(reconst_img_flp.data[0], normalize=False)
130 |
131 | new_img_90 = Image.fromarray(np_compltd_img_90)
132 | new_img_90 = new_img_90.rotate(270)
133 | np_new_img_90 = np.array(new_img_90, np.float)
134 |
135 | new_img_180 = Image.fromarray(np_compltd_img_180)
136 | new_img_180 = new_img_180.rotate(180)
137 | np_new_img_180 = np.array(new_img_180, np.float)
138 |
139 | new_img_270 = Image.fromarray(np_compltd_img_270)
140 | new_img_270 = new_img_270.rotate(90)
141 | np_new_img_270 = np.array(new_img_270, np.float)
142 |
143 | new_img_flp = Image.fromarray(np_compltd_img_flp)
144 | new_img_flp = new_img_flp.transpose(method=Image.FLIP_LEFT_RIGHT)
145 | np_new_img_flp = np.array(new_img_flp, np.float)
146 |
147 | np_compltd_img = (np_compltd_img + np_new_img_90 + np_new_img_180 + np_new_img_270 + np_new_img_flp) / 5.0
148 | np_compltd_img = np.array(np.round(np_compltd_img), np.uint8)
149 | final_img = Image.fromarray(np_compltd_img, mode="RGB")
150 | np_compltd_img = np.array(final_img, np.uint8)
151 |
152 | compltd_arr.append(np.copy(np_compltd_img))
153 |
154 | ### scatter the completed sub-images back into the full resized image
155 | ver_idx = 0
156 | for hv in range(hori_ver):
157 | for vv in range(vert_ver):
158 | #np_tmp_img = compltd_arr[ver_idx]
159 | for i in range(256):
160 | for j in range(256):
161 | np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 0] = compltd_arr[ver_idx][i, j, 0]
162 | np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 1] = compltd_arr[ver_idx][i, j, 1]
163 | np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 2] = compltd_arr[ver_idx][i, j, 2]
164 | ver_idx += 1
165 |
166 | ### resize back to the original resolution and composite: model output inside the holes, original pixels elsewhere
167 | new_compltd_img = Image.fromarray(np_resized_img)
168 | new_compltd_img = new_compltd_img.resize((ow, oh), Image.BICUBIC)
169 | new_compltd_img = new_compltd_img.resize((int(ow*0.5), int(oh*0.5)), Image.BICUBIC)
170 | new_compltd_img = new_compltd_img.resize((ow, oh), Image.BICUBIC)
171 | np_new_compltd_img = np.array(new_compltd_img)
172 | np_oimg = np.array(oimg)
173 | np_omsk = np.array(omsk)
174 |
175 | np_new_compltd_img[:, :, 0] = np_new_compltd_img[:, :, 0] * (np_omsk / 255.0) + ((255.0 - np_omsk) / 255.0) * np_oimg[:, :, 0]
176 | np_new_compltd_img[:, :, 1] = np_new_compltd_img[:, :, 1] * (np_omsk / 255.0) + ((255.0 - np_omsk) / 255.0) * np_oimg[:, :, 1]
177 | np_new_compltd_img[:, :, 2] = np_new_compltd_img[:, :, 2] * (np_omsk / 255.0) + ((255.0 - np_omsk) / 255.0) * np_oimg[:, :, 2]
178 |
179 | newfilename = filename.replace("_with_holes", "")
180 | compltd_path = os.path.join(rlt_dir, newfilename)
181 | util.save_image(np_new_compltd_img, compltd_path)
182 |
183 | print(compltd_path)
184 |
185 | elapsed_time = time.time() - start_time
186 | print('Avg Time Taken: %.3f sec' % (elapsed_time / dataset_size))
187 |
188 | print('done')
189 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from util.visualizer import Visualizer
2 | import util.util as util
3 | from data.data_loader import CreateDataLoader
4 | from options.train_options import TrainOptions
5 | import time
6 | import os
7 | import numpy as np
8 | import torch
9 | from torch.autograd import Variable
10 | from collections import OrderedDict
11 |
12 | import math
13 | def lcm(a, b):
14 | return abs(a*b) // math.gcd(a, b) if a and b else 0
15 |
16 | from models.models import create_model
17 |
18 | opt = TrainOptions().parse()
19 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
20 | param_path = os.path.join(opt.checkpoints_dir, opt.name, 'param.txt')
21 |
22 | # continue training or start from scratch
23 | if opt.continue_train:
24 | try:
25 | start_epoch, epoch_iter = np.loadtxt(
26 | iter_path, delimiter=',', dtype=int)
27 | except Exception:
28 | start_epoch, epoch_iter = 1, 0
29 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
30 | else:
31 | start_epoch, epoch_iter = 1, 0
32 |
33 | opt.print_freq = lcm(opt.print_freq, opt.batchSize)
34 |
35 |
36 | ########################################
37 | # load dataset
38 | ########################################
39 | data_loader = CreateDataLoader(opt)
40 | dataset = data_loader.load_data()
41 | dataset_size = len(data_loader)
42 | print('# of training images = %d' % dataset_size)
43 |
44 | ########################################
45 | # define model and optimizer
46 | ########################################
47 | # define own model
48 | model, num_params_G, num_params_D = create_model(opt)
49 | optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D
50 | # write the model size (num_params_G and num_params_D) to a txt file
51 | np.savetxt(param_path, (num_params_G, num_params_D), delimiter=',', fmt='%d')
52 |
53 | ########################################
54 | # define visualizer
55 | ########################################
56 | visualizer = Visualizer(opt)
57 |
58 | ########################################
59 | # define the train/val loop
60 | ########################################
61 | total_steps = (start_epoch - 1) * dataset_size + epoch_iter
62 |
63 | display_delta = total_steps % opt.display_freq
64 | print_delta = total_steps % opt.print_freq
65 | save_delta = total_steps % opt.save_latest_freq
66 |
67 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
68 | model.train()
69 | # by default, start_epoch, epoch_iter = 1, 0
70 | epoch_start_time = time.time()
71 | if epoch != start_epoch:
72 | epoch_iter = epoch_iter % dataset_size
73 | for i, data in enumerate(dataset, start=epoch_iter):
74 | if total_steps % opt.print_freq == print_delta:
75 | iter_start_time = time.time()
76 | total_steps += opt.batchSize
77 | epoch_iter += opt.batchSize
78 |
79 | # whether to collect output images
80 | save_fake = total_steps % opt.display_freq == display_delta
81 |
82 | ############### Forward pass ###############
83 | # get pred and calculate loss
84 | B, C, H, W = data['masked_image'].shape
85 | msk_img = data['masked_image'].view(-1, 3, H, W)
86 | msk = data['mask'].view(-1, 1, H, W)
87 | real_img = data['real_image'].view(-1, 3, H, W)
88 |
89 | losses, compltd_img, reconst_img, lr_x = model(Variable(msk_img), Variable(msk),
90 | Variable(real_img), infer=save_fake)
91 |
92 | # sum per device losses
93 | losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
94 | loss_dict = dict(zip(model.module.loss_names, losses))
95 |
96 | # calculate final loss scalar
97 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
98 | loss_G = loss_dict['G_GAN'] + loss_dict['G_COARSE_L1'] + loss_dict['G_OUT_L1'] + loss_dict['G_TV'] + loss_dict['G_STYLE'] + loss_dict.get('G_VGG', 0)
99 |
100 | ############### Backward pass ###############
101 | # update Generator parameters
102 | optimizer_G.zero_grad()
103 | loss_G.backward()
104 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3, norm_type=2)
105 | optimizer_G.step()
106 |
107 | # update Discriminator parameters
108 | optimizer_D.zero_grad()
109 | loss_D.backward()
110 | optimizer_D.step()
111 |
112 | ############### Display results and losses ###############
113 | # print out losses
114 | if total_steps % opt.print_freq == print_delta:
115 | errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
116 | t = (time.time() - iter_start_time) / opt.print_freq
117 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
118 | visualizer.plot_current_errors(errors, total_steps)
119 |
120 | # display completed images
121 | if save_fake:
122 | visuals = OrderedDict([('real_image', util.tensor2im(real_img[0], normalize=False)),
123 | ('masked_image_1', util.tensor2im(msk_img[0], normalize=False)),
124 | ('coarse_reconst_1', util.tensor2im(lr_x.data[0], normalize=False)),
125 | ('output_image_1', util.tensor2im(reconst_img.data[0], normalize=False)),
126 | ('completed_image_1', util.tensor2im(compltd_img.data[0], normalize=False)),
127 | ('masked_image_2', util.tensor2im(msk_img[1], normalize=False)),
128 | ('coarse_reconst_2', util.tensor2im(lr_x.data[1], normalize=False)),
129 | ('output_image_2', util.tensor2im(reconst_img.data[1], normalize=False)),
130 | ('completed_image_2', util.tensor2im(compltd_img.data[1], normalize=False)),
131 | ('masked_image_3', util.tensor2im(msk_img[2], normalize=False)),
132 | ('coarse_reconst_3', util.tensor2im(lr_x.data[2], normalize=False)),
133 | ('output_image_3', util.tensor2im(reconst_img.data[2], normalize=False)),
134 | ('completed_image_3', util.tensor2im(compltd_img.data[2], normalize=False))])
135 | visualizer.display_current_results(visuals, epoch, total_steps)
136 |
137 | # save the latest model
138 | if total_steps % opt.save_latest_freq == save_delta:
139 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
140 | model.module.save('latest')
141 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
142 |
143 | if epoch_iter >= dataset_size:
144 | break
145 |
146 | # end of epoch
147 | iter_end_time = time.time()
148 | print('End of epoch %d / %d \t Time Taken: %d sec' %
149 | (epoch, opt.niter + opt.niter_decay, iter_end_time - epoch_start_time))
150 |
151 | # save model for this epoch
152 | if epoch % opt.save_epoch_freq == 0:
153 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
154 | model.module.save('latest')
155 | model.module.save(epoch)
156 | np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
157 |
158 | # instead of only training the local enhancer, train the entire network after certain iterations
159 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
160 | model.module.update_fixed_params()
161 |
162 | # linearly decay the learning rate after opt.niter epochs
163 | if epoch > opt.niter:
164 | print('update learning rate')
165 | model.module.update_learning_rate()
166 |
--------------------------------------------------------------------------------
/util/__pycache__/html.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/html.cpython-36.pyc
--------------------------------------------------------------------------------
/util/__pycache__/html.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/html.cpython-37.pyc
--------------------------------------------------------------------------------
/util/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/util/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/util/__pycache__/visualizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/visualizer.cpython-36.pyc
--------------------------------------------------------------------------------
/util/__pycache__/visualizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/visualizer.cpython-37.pyc
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 |
16 | self.doc = dominate.document(title=title)
17 | if refresh > 0:
18 | with self.doc.head:
19 | meta(http_equiv="refresh", content=str(refresh))
20 |
21 | def get_image_dir(self):
22 | return self.img_dir
23 |
24 | def add_header(self, text):
25 | with self.doc:
26 | h3(text)
27 |
28 | def add_table(self, border=1):
29 | self.t = table(border=border, style="table-layout: fixed;")
30 | self.doc.add(self.t)
31 |
32 | def add_images(self, ims, txts, links, width=512):
33 | self.add_table()
34 | with self.t:
35 | with tr():
36 | for im, txt, link in zip(ims, txts, links):
37 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
38 | with p():
39 | with a(href=os.path.join('images', link)):
40 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
41 | br()
42 | p(txt)
43 |
44 | def save(self):
45 | html_file = '%s/index.html' % self.web_dir
46 | f = open(html_file, 'wt')
47 | f.write(self.doc.render())
48 | f.close()
49 |
50 |
51 | if __name__ == '__main__':
52 | html = HTML('web/', 'test_html')
53 | html.add_header('hello world')
54 |
55 | ims = []
56 | txts = []
57 | links = []
58 | for n in range(4):
59 | ims.append('image_%d.jpg' % n)
60 | txts.append('text_%d.jpg' % n)
61 | links.append('image_%d.jpg' % n)
62 | html.add_images(ims, txts, links)
63 | html.save()
64 |
65 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 |
5 | # not in use at this stage
6 | class ImagePool():
7 | def __init__(self, pool_size):
8 | self.pool_size = pool_size
9 | if self.pool_size > 0:
10 | self.num_imgs = 0
11 | self.images = []
12 |
13 | def query(self, images):
14 | if self.pool_size == 0:
15 | return images
16 | return_images = []
17 | for image in images.data:
18 | image = torch.unsqueeze(image, 0)
19 | if self.num_imgs < self.pool_size:
20 | self.num_imgs = self.num_imgs + 1
21 | self.images.append(image)
22 | return_images.append(image)
23 | else:
24 | p = random.uniform(0, 1)
25 | if p > 0.5:
26 | random_id = random.randint(0, self.pool_size - 1)
27 | tmp = self.images[random_id].clone()
28 | self.images[random_id] = image
29 | return_images.append(tmp)
30 | else:
31 | return_images.append(image)
32 | return_images = Variable(torch.cat(return_images, 0))
33 | return return_images
34 |
35 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | def mkdir(path):
8 | if not os.path.exists(path):
9 | os.makedirs(path)
10 |
11 | def mkdirs(paths):
12 | if isinstance(paths, list) and not isinstance(paths, str):
13 | for path in paths:
14 | mkdir(path)
15 | else:
16 | mkdir(paths)
17 |
18 | def save_image(image_numpy, image_path):
19 | image_pil = Image.fromarray(image_numpy)
20 | image_pil.save(image_path)
21 |
22 |
23 | # converts a tensor into a numpy array
24 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
25 | if isinstance(image_tensor, list):
26 | image_numpy = []
27 | for i in range(len(image_tensor)):
28 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
29 | return image_numpy
30 | image_numpy = image_tensor.cpu().float().numpy()
31 | if normalize:
32 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
33 | else:
34 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
35 | image_numpy = np.clip(image_numpy, 0, 255)
36 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
37 | image_numpy = image_numpy[:, :, 0]
38 | return image_numpy.astype(imtype)
39 |
40 |
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | import scipy.misc # scipy==1.1.0
5 | from . import html
6 | from . import util
7 | import ntpath
8 |
9 | try:
10 | from StringIO import StringIO # Python 2.7
11 | except ImportError:
12 | from io import BytesIO # Python 3.x
13 |
14 | class Visualizer():
15 | def __init__(self, opt):
16 | self.opt = opt
17 | # tf_log: log training summaries to TensorBoard
18 | self.tf_log = opt.tf_log
19 | # save intermediate training results to an HTML page
20 | self.use_html = opt.isTrain and not opt.no_html
21 | self.win_size = opt.display_winsize
22 | self.name = opt.name
23 |
24 | if self.tf_log:
25 | import tensorflow as tf
26 | self.tf = tf
27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
28 | self.writer = tf.summary.FileWriter(self.log_dir)
29 |
30 | if self.use_html:
31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
32 | self.img_dir = os.path.join(self.web_dir, 'images')
33 | print('Create web directory %s ...' % self.web_dir)
34 | util.mkdirs([self.web_dir, self.img_dir])
35 |
36 | # a txt file to record the losses
37 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
38 | with open(self.log_name, "a") as log_file:
39 | now = time.strftime("%c")
40 | log_file.write('==================== Training loss (%s) ====================\n' % now)
41 |
42 |
43 | # dictionary of images to display or save
44 | def display_current_results(self, visuals, epoch, step):
45 | if self.tf_log: # show images in tensorboard output
46 | img_summaries = []
47 | for label, image_numpy in visuals.items():
48 | # Write the image to a string
49 | try:
50 | s = StringIO()
51 | except:
52 | s = BytesIO()
53 | scipy.misc.toimage(image_numpy).save(s, format="png")
54 | # Create an Image object
55 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
56 | # Create a Summary value
57 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
58 |
59 | # Create and write Summary
60 | summary = self.tf.Summary(value=img_summaries)
61 | self.writer.add_summary(summary, step)
62 |
63 | if self.use_html: # save images to an HTML file
64 | for label, image_numpy in visuals.items():
65 | if isinstance(image_numpy, list):
66 | for i in range(len(image_numpy)):
67 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.png' % (epoch, label, i))
68 | util.save_image(image_numpy[i], img_path)
69 | else:
70 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
71 | util.save_image(image_numpy, img_path)
72 |
73 | # update the website
74 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30)
75 | for n in range(epoch, 0, -1):
76 | webpage.add_header('epoch [%d]' % n)
77 | ims = []
78 | txts = []
79 | links = []
80 |
81 | for label, image_numpy in visuals.items():
82 | if isinstance(image_numpy, list):
83 | for i in range(len(image_numpy)):
84 | img_path = 'epoch%.3d_%s_%d.png' % (n, label, i)
85 | ims.append(img_path)
86 | txts.append(label+str(i))
87 | links.append(img_path)
88 | else:
89 | img_path = 'epoch%.3d_%s.png' % (n, label)
90 | ims.append(img_path)
91 | txts.append(label)
92 | links.append(img_path)
93 |
94 | if len(ims) < 10:
95 | webpage.add_images(ims, txts, links, width=self.win_size)
96 | else:
97 | num = int(round(len(ims)/2.0))
98 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
99 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
100 |
101 | webpage.save()
102 |
103 | # dictionary of loss labels and values
104 | def plot_current_errors(self, errors, step):
105 | if self.tf_log:
106 | for tag, value in errors.items():
107 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
108 | self.writer.add_summary(summary, step)
109 |
110 | # print loss labels and values
111 | def print_current_errors(self, epoch, i, errors, t):
112 | message = '(epoch: %d, iters: %d, time: %.4f) ' % (epoch, i, t)
113 | for k, v in errors.items():
114 | if v != 0:
115 | message += '%s: %.4f ' % (k, v)
116 |
117 | print(message)
118 | with open(self.log_name, "a") as log_file:
119 | log_file.write('%s\n' % message)
120 |
121 | # save images to disk
122 | def save_images(self, webpage, visuals, image_path):
123 | image_dir = webpage.get_image_dir()
124 | short_path = ntpath.basename(image_path[0])
125 | name = os.path.splitext(short_path)[0]
126 |
127 | webpage.add_header(name)
128 | ims = []
129 | txts = []
130 | links = []
131 |
132 | for label, image_numpy in visuals.items():
133 | image_name = '%s_%s.png' % (name, label)
134 | save_path = os.path.join(image_dir, image_name)
135 | util.save_image(image_numpy, save_path)
136 |
137 | ims.append(image_name)
138 | txts.append(label)
139 | links.append(image_name)
140 |
141 | webpage.add_images(ims, txts, links, width=self.win_size)
142 |
143 |
--------------------------------------------------------------------------------