├── LICENSE ├── README.md ├── checkpoints ├── gin │ └── put_trained_model_here └── warmup │ └── put_warm_up_here ├── data ├── __pycache__ │ ├── ade20k_dataset.cpython-36.pyc │ ├── ade20k_dataset.cpython-37.pyc │ ├── base_data_loader.cpython-36.pyc │ ├── base_data_loader.cpython-37.pyc │ ├── base_dataset.cpython-36.pyc │ ├── base_dataset.cpython-37.pyc │ ├── custom_dataset_data_loader.cpython-36.pyc │ ├── custom_dataset_data_loader.cpython-37.pyc │ ├── data_loader.cpython-36.pyc │ ├── data_loader.cpython-37.pyc │ ├── image_folder.cpython-36.pyc │ ├── image_folder.cpython-37.pyc │ ├── masks.cpython-36.pyc │ └── masks.cpython-37.pyc ├── ade20k_dataset.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── data_loader.py ├── image_folder.py └── masks.py ├── datasets └── ade20k │ ├── test │ ├── AIM_IC_t1_validation_0_mask.png │ ├── AIM_IC_t1_validation_0_with_holes.png │ ├── AIM_IC_t1_validation_18_mask.png │ └── AIM_IC_t1_validation_18_with_holes.png │ └── train │ ├── a │ └── abbey │ │ ├── ADE_train_00000976.jpg │ │ ├── ADE_train_00000976_seg.png │ │ ├── ADE_train_00000985.jpg │ │ ├── ADE_train_00000985_seg.png │ │ ├── ADE_train_00000992.jpg │ │ └── ADE_train_00000992_seg.png │ └── z │ └── zoo │ ├── ADE_train_00020206.jpg │ └── ADE_train_00020206_seg.png ├── examples ├── AIM_IC_t1_validation_0.png ├── AIM_IC_t1_validation_0_with_holes.png ├── ablation_study.png ├── arc.png ├── architecture.png ├── comparisons_ffhq_oxford.png ├── problem_of_interest.png ├── spd_resnetblk.png └── visualization_seg.png ├── models ├── __pycache__ │ ├── BoundaryVAE_model.cpython-36.pyc │ ├── BoundaryVAE_model.cpython-37.pyc │ ├── base_model.cpython-36.pyc │ ├── base_model.cpython-37.pyc │ ├── models.cpython-36.pyc │ ├── models.cpython-37.pyc │ ├── models_pretrain.cpython-36.pyc │ ├── networks.cpython-36.pyc │ ├── networks.cpython-37.pyc │ └── our_model.cpython-36.pyc ├── base_model.py ├── models.py ├── networks.py └── our_model.py ├── options ├── __pycache__ │ ├── base_options.cpython-36.pyc │ ├── base_options.cpython-37.pyc │ ├── test_options.cpython-36.pyc │ ├── train_options.cpython-36.pyc │ └── train_options.cpython-37.pyc ├── base_options.py ├── test_options.py └── train_options.py ├── results └── test │ └── AIM_IC_t1_validation_0.png ├── test_ensemble.py ├── train.py └── util ├── __pycache__ ├── html.cpython-36.pyc ├── html.cpython-37.pyc ├── util.cpython-36.pyc ├── util.cpython-37.pyc ├── visualizer.cpython-36.pyc └── visualizer.cpython-37.pyc ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, ronctl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | 28 | 29 | --------------------------- LICENSE FOR pytorch-pix2pixHD ---------------- 30 | Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. 31 | BSD License. All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 44 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 45 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 46 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 47 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 48 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 49 | 50 | 51 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ---------------- 52 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 53 | All rights reserved. 54 | 55 | Redistribution and use in source and binary forms, with or without 56 | modification, are permitted provided that the following conditions are met: 57 | 58 | * Redistributions of source code must retain the above copyright notice, this 59 | list of conditions and the following disclaimer. 60 | 61 | * Redistributions in binary form must reproduce the above copyright notice, 62 | this list of conditions and the following disclaimer in the documentation 63 | and/or other materials provided with the distribution. 64 | 65 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 66 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 67 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 68 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 69 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 70 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 71 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 72 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 73 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 74 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Generative Inpainting Network (GIN) for Extreme Image Inpainting 2 | For AIM2020 ECCV Extreme Image Inpainting Track 1 Classic
3 | This is the PyTorch implementation of our Deep Generative Inpainting Network (GIN) for Extreme Image Inpainting. We participated in the AIM 2020 ECCV Extreme Image Inpainting Challenge; our GIN reconstructs a completed image with satisfactory visual quality from a randomly masked image.

4 | 5 | ## Overview 6 |

7 | 8 |
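The overview figure illustrates the two-stage design: a coarse generator first fills the missing regions, and a refinement generator with SPD blocks and self-attention sharpens the result, while pixels outside the holes are always copied from the input. Below is a minimal sketch of that compositing step (mirroring the mask blending in `ImageTinker2.forward` of `models/networks.py`; the helper name is ours):

```python
import torch

def composite(pred: torch.Tensor, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask == 1 inside the holes: keep the network prediction there and
    # copy the untouched input pixels everywhere else.
    return pred * mask + img * (1.0 - mask)
```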

9 | Our Spatial Pyramid Dilation (SPD) block 10 |

11 | 12 |
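Conceptually, an SPD block runs several dilated convolutions in parallel so that a single residual block aggregates context over multiple receptive-field sizes. The actual layers are `MultiDilationResnetBlock` / `MultiDilationResnetBlock_v3` in `models/networks.py`; the snippet below is only an illustrative sketch of the idea (the class name, branch widths, and dilation rates are our assumptions, not the exact implementation):

```python
import torch
import torch.nn as nn

class SPDBlockSketch(nn.Module):
    """Illustrative spatial-pyramid-dilation residual block (not the repo's exact layer)."""
    def __init__(self, channels: int, dilations=(1, 2, 4, 8)):
        super().__init__()
        branch_ch = channels // len(dilations)
        # One 3x3 convolution per dilation rate; padding keeps the spatial size unchanged.
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.ReplicationPad2d(d),
                nn.Conv2d(channels, branch_ch, kernel_size=3, dilation=d),
                nn.LeakyReLU(0.2, inplace=True),
            )
            for d in dilations
        )
        self.fuse = nn.Conv2d(branch_ch * len(dilations), channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pyramid = torch.cat([branch(x) for branch in self.branches], dim=1)
        return x + self.fuse(pyramid)  # residual connection

# shape check: SPDBlockSketch(256)(torch.randn(1, 256, 64, 64)) keeps the 1x256x64x64 shape
```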

13 | 14 | ## Example of Image Inpainting using our GIN 15 | - An example from the validation set of the AIM20 ECCV Extreme Image Inpainting Track 1 Classic 16 | - (left: masked image, right: our completed image) 17 |

18 | 19 | 20 |
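For reference, the masked input shown on the left is produced by compositing a binary hole mask onto the image and filling the holes with white, as done per channel in `data/ade20k_dataset.py`. A small self-contained sketch (the helper name and file names below are hypothetical):

```python
import numpy as np
from PIL import Image

def make_masked_input(image_path: str, mask_path: str) -> Image.Image:
    """Fill the hole region (mask value 255) of an RGB image with white."""
    img = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32)
    mask = np.asarray(Image.open(mask_path).convert('L'), dtype=np.float32) / 255.0
    masked = img * (1.0 - mask[..., None]) + 255.0 * mask[..., None]
    return Image.fromarray(masked.astype(np.uint8))

# e.g. make_masked_input('photo.jpg', 'hole_mask.png').save('photo_with_holes.png')
```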

21 |
22 | ## Preparation
23 | - Our solution is developed on the PyTorch 1.5.0 platform
24 | - We train our model on two NVIDIA GeForce RTX 2080 Ti GPUs (each with 11GB of memory)
25 | - Apart from PyTorch and related dependencies, install the following:
26 | - Install natsort
27 | ```bash
28 | pip install natsort
29 | ```
30 | - Install dominate
31 | ```bash
32 | pip install dominate
33 | ```
34 | - Install scipy 1.1.0
35 | ```bash
36 | pip install scipy==1.1.0
37 | ```
38 | - If you would like to use TensorBoard for logging, please also install tensorboard and tensorflow
39 | - Please clone this project:
40 | ```bash
41 | git clone https://github.com/rlct1/gin.git
42 | cd gin
43 | ```
44 |
45 | ## Testing
46 | - An example of the validation data of this challenge is provided in the `datasets/ade20k/test` folder
47 | - Please download our trained model for this challenge [here](https://drive.google.com/file/d/1yOtMELWwTBc-PMSY69x1FH8D1anUN7tD/view?usp=sharing) (Google Drive link) and put it under `checkpoints/gin/`
48 | - To reproduce the test results for this challenge, put all the testing images under `datasets/ade20k/test/`
49 | - You can test our model by typing:
50 | ```bash
51 | python test_ensemble.py --name gin
52 | ```
53 | - The test results will be stored in the `results/test` folder
54 | - If you would like to test on other datasets, please refer to the file structure in the `datasets/ade20k/test` folder
55 | - Note that this file structure is for AIM20 IC Track 1
56 | - You can download our test results for this challenge [here](https://drive.google.com/file/d/1EJgQ3neOA2WkZMmG6uG0GG14VoLYmNFg/view?usp=sharing) (Google Drive link)
57 |
58 | ## Training
59 | - By default, our model is trained using two GPUs
60 | - Examples of the training images from this challenge are provided in the `datasets/ade20k/train` folder
61 | - If you would like to train a model using our warm-up model for initialization, please download it for this challenge [here](https://drive.google.com/file/d/1T3ST-ujhtDZQpWUiagICOAIvBF7CMeYz/view?usp=sharing) (Google Drive link), put it under `checkpoints/warmup/`, and run:
62 | ```bash
63 | python train.py --name yourmodel --continue_train --load_pretrain './checkpoints/warmup'
64 | ```
65 | - If you would like to train a model from scratch:
66 | ```bash
67 | python train.py --name yourmodel
68 | ```
69 | - If you would like to train a model tailored to your own selection and resources, please refer to `options/base_options.py` and `options/train_options.py` for details
70 |
71 | ## Experiments
72 | Ablation Study
73 |

74 | 75 |

76 | Comparisons 77 |

78 | 79 |

80 | Visualization of predicted semantic segmentation map 81 |

82 | 83 |

84 | 85 | ## Citation 86 | Thanks for visiting our project page, if it is useful, please cite our paper, 87 | ``` 88 | @misc{li2020deepgin, 89 | title={DeepGIN: Deep Generative Inpainting Network for Extreme Image Inpainting}, 90 | author={Chu-Tak Li and Wan-Chi Siu and Zhi-Song Liu and Li-Wen Wang and Daniel Pak-Kong Lun}, 91 | year={2020}, 92 | eprint={2008.07173}, 93 | archivePrefix={arXiv}, 94 | primaryClass={cs.CV} 95 | } 96 | ``` 97 | 98 | ## Acknowledgment 99 | Our code is developed based on the skeleton of the Pytorch implementation of [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) 100 | 101 | -------------------------------------------------------------------------------- /checkpoints/gin/put_trained_model_here: -------------------------------------------------------------------------------- 1 | put trained model here 2 | -------------------------------------------------------------------------------- /checkpoints/warmup/put_warm_up_here: -------------------------------------------------------------------------------- 1 | put warm up model here 2 | -------------------------------------------------------------------------------- /data/__pycache__/ade20k_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/ade20k_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/ade20k_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/ade20k_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/custom_dataset_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/__pycache__/custom_dataset_data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/custom_dataset_data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/image_folder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/image_folder.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/masks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/masks.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/masks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/data/__pycache__/masks.cpython-37.pyc -------------------------------------------------------------------------------- /data/ade20k_dataset.py: -------------------------------------------------------------------------------- 1 | #################################################################################### 2 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 3 | # Licensed under the CC BY-NC-SA 4.0 license 4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
5 | #################################################################################### 6 | import os.path 7 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize 8 | from data.image_folder import make_dataset, make_dataset_with_conditions, make_dataset_with_condition_list 9 | from data.masks import Masks 10 | from PIL import Image 11 | import random 12 | import numpy as np 13 | from natsort import natsorted 14 | import torch 15 | 16 | class ADE20kDataset(BaseDataset): 17 | def initialize(self, opt): 18 | self.opt = opt 19 | self.root = opt.dataroot 20 | ################################################ 21 | # input A : masked images (rgb) 22 | # input B : real images (rgb) 23 | # input mask : masks (gray) 24 | # input B_seg : real seg images (rgb) 25 | ################################################ 26 | 27 | if opt.phase == 'train': 28 | # the 10,330 images in ade20k train set 29 | self.dir_A = os.path.join(opt.dataroot, opt.phase) 30 | self.B_seg_paths, _ = make_dataset_with_conditions(self.dir_A, '_seg') 31 | self.A_paths, _ = make_dataset_with_conditions(self.dir_A, '.jpg') 32 | 33 | self.B_seg_paths = natsorted(self.B_seg_paths) 34 | self.A_paths = natsorted(self.A_paths) 35 | self.B_paths = self.A_paths 36 | 37 | self.dir_mask = [] 38 | self.mask_paths = [] 39 | 40 | self.mask = Masks() 41 | 42 | elif opt.phase == 'test': 43 | # aim 2020 eccv validation set and test set. 44 | # no ground truth 45 | self.dir_A = os.path.join(opt.dataroot, opt.phase) 46 | self.A_paths, _ = make_dataset_with_conditions(self.dir_A, '_with_holes') 47 | self.A_paths = natsorted(self.A_paths) 48 | 49 | self.dir_B = [] 50 | self.B_paths = [] 51 | self.B_seg_paths = [] # depends on track 1 or 2 52 | 53 | self.dir_mask = os.path.join(opt.dataroot, opt.phase) 54 | self.mask_paths, _ = make_dataset_with_conditions(self.dir_mask, '_mask') 55 | self.mask_paths = natsorted(self.mask_paths) 56 | 57 | self.dataset_size = len(self.A_paths) 58 | 59 | def __getitem__(self, index): 60 | ################################################ 61 | # input A : masked images (rgb) 62 | # input B : real images (rgb) 63 | # input mask : masks (gray) 64 | # input B_seg : real seg images (rgb) 65 | ################################################ 66 | 67 | A_path = self.A_paths[index] 68 | A = Image.open(A_path).convert('RGB') 69 | 70 | params = get_params(self.opt, A.size) 71 | transform_A = get_transform(self.opt, params, normalize=False) 72 | A_tensor = transform_A(A) 73 | 74 | B_tensor = mask_tensor = 0 75 | B_seg_tensor = 0 76 | 77 | if self.opt.phase == 'train': 78 | # the 10,330 images in ade20k train set 79 | B_seg_path = self.B_seg_paths[index] 80 | B_seg = Image.open(B_seg_path).convert('RGB') 81 | 82 | new_A, new_B_seg = self.resize_or_crop(A, B_seg) 83 | ## data augmentation rotate 84 | new_A = self.rotate(new_A) 85 | new_A = self.ensemble(new_A) 86 | 87 | width, height = new_A.size 88 | f_A = np.array(new_A, np.float32) 89 | f_A1 = np.array(new_A, np.float32) 90 | f_A2 = np.array(new_A, np.float32) 91 | f_A3 = np.array(new_A, np.float32) 92 | 93 | f_mask = self.mask.get_random_mask(height, width) 94 | 95 | f_A[:, :, 0] = f_A[:, :, 0] * (1.0 - f_mask) + 255.0 * f_mask 96 | f_A[:, :, 1] = f_A[:, :, 1] * (1.0 - f_mask) + 255.0 * f_mask 97 | f_A[:, :, 2] = f_A[:, :, 2] * (1.0 - f_mask) + 255.0 * f_mask 98 | f_masked_A = f_A 99 | 100 | f_mask1 = self.mask.get_box_mask(height, width) 101 | f_mask2 = self.mask.get_ca_mask(height, width) 102 | f_mask3 = self.mask.get_ff_mask(height, width) 103 | 104 
| f_A1[:, :, 0] = f_A1[:, :, 0] * (1.0 - f_mask1) + 255.0 * f_mask1 105 | f_A1[:, :, 1] = f_A1[:, :, 1] * (1.0 - f_mask1) + 255.0 * f_mask1 106 | f_A1[:, :, 2] = f_A1[:, :, 2] * (1.0 - f_mask1) + 255.0 * f_mask1 107 | f_masked_A1 = f_A1 108 | 109 | f_A2[:, :, 0] = f_A2[:, :, 0] * (1.0 - f_mask2) + 255.0 * f_mask2 110 | f_A2[:, :, 1] = f_A2[:, :, 1] * (1.0 - f_mask2) + 255.0 * f_mask2 111 | f_A2[:, :, 2] = f_A2[:, :, 2] * (1.0 - f_mask2) + 255.0 * f_mask2 112 | f_masked_A2 = f_A2 113 | 114 | f_A3[:, :, 0] = f_A3[:, :, 0] * (1.0 - f_mask3) + 255.0 * f_mask3 115 | f_A3[:, :, 1] = f_A3[:, :, 1] * (1.0 - f_mask3) + 255.0 * f_mask3 116 | f_A3[:, :, 2] = f_A3[:, :, 2] * (1.0 - f_mask3) + 255.0 * f_mask3 117 | f_masked_A3 = f_A3 118 | 119 | # masked images 120 | masked_A = Image.fromarray((f_masked_A).astype(np.uint8)).convert('RGB') 121 | mask_img = Image.fromarray((f_mask * 255.0).astype(np.uint8)).convert('L') 122 | 123 | masked_A1 = Image.fromarray((f_masked_A1).astype(np.uint8)).convert('RGB') 124 | mask_img1 = Image.fromarray((f_mask1 * 255.0).astype(np.uint8)).convert('L') 125 | 126 | masked_A2 = Image.fromarray((f_masked_A2).astype(np.uint8)).convert('RGB') 127 | mask_img2 = Image.fromarray((f_mask2 * 255.0).astype(np.uint8)).convert('L') 128 | 129 | masked_A3 = Image.fromarray((f_masked_A3).astype(np.uint8)).convert('RGB') 130 | mask_img3 = Image.fromarray((f_mask3 * 255.0).astype(np.uint8)).convert('L') 131 | 132 | A_tensor = transform_A(masked_A) 133 | 134 | ## 135 | A_tensor1 = transform_A(masked_A1) 136 | A_tensor2 = transform_A(masked_A2) 137 | A_tensor3 = transform_A(masked_A3) 138 | 139 | # real images 140 | B_path = self.B_paths[index] 141 | B = new_A 142 | transform_B = get_transform(self.opt, params, normalize=False) 143 | B_tensor = transform_B(B) 144 | 145 | B_tensor1 = transform_B(B) 146 | B_tensor2 = transform_B(B) 147 | B_tensor3 = transform_B(B) 148 | 149 | transform_B_seg = get_transform(self.opt, params, normalize=False) 150 | B_seg_tensor = transform_B_seg(new_B_seg) 151 | 152 | # masks 153 | mask_path = [] 154 | transform_mask = get_transform(self.opt, params, normalize=False) 155 | mask_tensor = transform_mask(mask_img) 156 | 157 | mask_tensor1 = transform_mask(mask_img1) 158 | mask_tensor2 = transform_mask(mask_img2) 159 | mask_tensor3 = transform_mask(mask_img3) 160 | 161 | A_tensor = torch.cat((A_tensor1, A_tensor2, A_tensor3), dim=0) 162 | B_tensor = torch.cat((B_tensor1, B_tensor2, B_tensor3), dim=0) 163 | mask_tensor = torch.cat((mask_tensor1, mask_tensor2, mask_tensor3), dim=0) 164 | 165 | elif self.opt.phase == 'test': 166 | # aim 2020 eccv validation set and test set. 
167 | # no ground truth 168 | B_path = [] 169 | B_seg_path = [] 170 | 171 | # masks 172 | mask_path = self.mask_paths[index] 173 | mask = Image.open(mask_path).convert('L') 174 | transform_mask = get_transform(self.opt, params, normalize=False) 175 | mask_tensor = transform_mask(mask) 176 | 177 | if self.opt.phase == 'test': 178 | input_dict = { 179 | 'masked_image': A_tensor, 180 | 'mask': mask_tensor, 181 | 'path_mskimg': A_path, 182 | 'path_msk': mask_path} 183 | return input_dict 184 | 185 | input_dict = {'masked_image': A_tensor, 186 | 'mask': mask_tensor, 187 | 'real_image': B_tensor, 188 | 'real_seg': B_seg_tensor} 189 | return input_dict 190 | 191 | def __len__(self): 192 | return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize 193 | 194 | def name(self): 195 | return 'ADE20kDataset' 196 | 197 | def resize_or_crop(self, img, seg_img, method=Image.BICUBIC): 198 | w, h = img.size 199 | new_w = w 200 | new_h = h 201 | 202 | if w > self.opt.loadSize and h > self.opt.loadSize: 203 | return img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method), seg_img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method) 204 | else: 205 | return img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method), seg_img.resize((self.opt.fineSize * 2, self.opt.fineSize * 2), method) 206 | 207 | def rotate(self, img): 208 | bFlag = random.randint(0, 3) 209 | if bFlag == 0: 210 | return img.rotate(0) 211 | elif bFlag == 1: 212 | return img.rotate(90) 213 | elif bFlag == 2: 214 | return img.rotate(180) 215 | elif bFlag == 3: 216 | return img.rotate(270) 217 | 218 | def ensemble(self, img): 219 | bFlag = random.randint(0, 3) 220 | width, height = img.size 221 | new_w = width // 2 222 | new_h = height // 2 223 | new_img = img.resize((new_w, new_h), Image.BICUBIC) 224 | np_img = np.array(img, np.float32) 225 | np_new_img = np.array(new_img, np.float32) 226 | 227 | if bFlag == 0: 228 | for i in range(new_w): 229 | for j in range(new_h): 230 | np_new_img[i, j, 0] = np_img[2*i, 2*j, 0] 231 | np_new_img[i, j, 1] = np_img[2*i, 2*j, 1] 232 | np_new_img[i, j, 2] = np_img[2*i, 2*j, 2] 233 | elif bFlag == 1: 234 | for i in range(new_w): 235 | for j in range(new_h): 236 | np_new_img[i, j, 0] = np_img[1 + 2*i, 1 + 2*j, 0] 237 | np_new_img[i, j, 1] = np_img[1 + 2*i, 1 + 2*j, 1] 238 | np_new_img[i, j, 2] = np_img[1 + 2*i, 1 + 2*j, 2] 239 | elif bFlag == 2: 240 | for i in range(new_w): 241 | for j in range(new_h): 242 | np_new_img[i, j, 0] = np_img[1 + 2*i, 2*j, 0] 243 | np_new_img[i, j, 1] = np_img[1 + 2*i, 2*j, 1] 244 | np_new_img[i, j, 2] = np_img[1 + 2*i, 2*j, 2] 245 | else: 246 | for i in range(new_w): 247 | for j in range(new_h): 248 | np_new_img[i, j, 0] = np_img[2*i, 1 + 2*j, 0] 249 | np_new_img[i, j, 1] = np_img[2*i, 1 + 2*j, 1] 250 | np_new_img[i, j, 2] = np_img[2*i, 1 + 2*j, 2] 251 | 252 | new_A = Image.fromarray((np_new_img).astype(np.uint8)).convert('RGB') 253 | return new_A 254 | 255 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | 12 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | 
#################################################################################### 2 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 3 | # Licensed under the CC BY-NC-SA 4.0 license 4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 5 | #################################################################################### 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import random 9 | import numpy as np 10 | import torchvision.transforms as transforms 11 | 12 | 13 | class BaseDataset(data.Dataset): 14 | def __init__(self): 15 | super(BaseDataset, self).__init__() 16 | 17 | def name(self): 18 | return 'BaseDataset' 19 | 20 | def initialize(self, opt): 21 | pass 22 | 23 | 24 | def get_params(opt, size): 25 | w, h = size 26 | new_h = h 27 | new_w = w 28 | if opt.resize_or_crop == 'resize_and_crop': 29 | new_h = new_w = opt.loadSize 30 | elif opt.resize_or_crop == 'scale_width_and_crop': 31 | new_w = opt.loadSize 32 | new_h = opt.loadSize * h // w 33 | 34 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 35 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 36 | 37 | flip = random.random() > 0.5 38 | return {'crop_pos': (x, y), 'flip': flip} 39 | 40 | 41 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, grayscale=False): 42 | transform_list = [] 43 | if 'resize' in opt.resize_or_crop: 44 | osize = [opt.loadSize, opt.loadSize] 45 | transform_list.append(transforms.Scale(osize, method)) 46 | elif 'scale_width' in opt.resize_or_crop: 47 | transform_list.append(transforms.Lambda( 48 | lambda img: __scale_width(img, opt.loadSize, method))) 49 | 50 | if 'crop' in opt.resize_or_crop: 51 | transform_list.append(transforms.Lambda( 52 | lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 53 | 54 | if opt.resize_or_crop == 'none': 55 | base = float(2 ** opt.n_downsample_global) 56 | if opt.netG == 'local': 57 | base *= (2 ** opt.n_local_enhancers) 58 | transform_list.append(transforms.Lambda( 59 | lambda img: __make_power_2(img, base, method))) 60 | 61 | if opt.resize_or_crop == 'standard': 62 | transform_list.append(transforms.Lambda( 63 | lambda img: __resize(img, opt.fineSize, method))) 64 | 65 | if opt.isTrain and not opt.no_flip: 66 | transform_list.append(transforms.Lambda( 67 | lambda img: __flip(img, params['flip']))) 68 | 69 | transform_list += [transforms.ToTensor()] 70 | 71 | if normalize: 72 | if grayscale: 73 | transform_list += [transforms.Normalize((0.5), (0.5))] 74 | else: 75 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 76 | (0.5, 0.5, 0.5))] 77 | 78 | return transforms.Compose(transform_list) 79 | 80 | 81 | def normalize(): 82 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 83 | 84 | 85 | def __make_power_2(img, base, method=Image.BICUBIC): 86 | ow, oh = img.size 87 | 88 | h = int(round(oh / base) * base) 89 | w = int(round(ow / base) * base) 90 | 91 | if (h == oh) and (w == ow): 92 | return img 93 | return img.resize((w, h), method) 94 | 95 | 96 | def __scale_width(img, target_width, method=Image.BICUBIC): 97 | ow, oh = img.size 98 | if (ow == target_width): 99 | return img 100 | w = target_width 101 | h = int(target_width * oh / ow) 102 | return img.resize((w, h), method) 103 | 104 | 105 | def __resize(img, target_size, method=Image.BICUBIC): 106 | ow, oh = img.size 107 | if (ow == target_size) and (oh == target_size): 108 | return img 109 | return img.resize((target_size, target_size), method) 110 | 111 | 112 | def __crop(img, pos, size): 113 | ow, oh = 
img.size 114 | x1, y1 = pos 115 | tw = th = size 116 | if (ow > tw or oh > th): 117 | return img.crop((x1, y1, x1 + tw, y1 + th)) 118 | return img 119 | 120 | 121 | def __flip(img, flip): 122 | if flip: 123 | return img.transpose(Image.FLIP_LEFT_RIGHT) 124 | return img 125 | 126 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | #from data.ade20k_dataset import ADE20kDataset 8 | from data.ade20k_dataset import ADE20kDataset 9 | dataset = ADE20kDataset() 10 | 11 | print("dataset [%s] was created" % (dataset.name())) 12 | dataset.initialize(opt) 13 | return dataset 14 | 15 | 16 | class CustomDatasetDataLoader(BaseDataLoader): 17 | def name(self): 18 | return 'CustomDatasetDataLoader' 19 | 20 | def initialize(self, opt): 21 | BaseDataLoader.initialize(self, opt) 22 | self.dataset = CreateDataset(opt) 23 | self.dataloader = torch.utils.data.DataLoader( 24 | self.dataset, 25 | batch_size=opt.batchSize, 26 | shuffle=not opt.serial_batches, 27 | num_workers=int(opt.nThreads)) 28 | 29 | def load_data(self): 30 | return self.dataloader 31 | 32 | def __len__(self): 33 | return min(len(self.dataset), self.opt.max_dataset_size) 34 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | 9 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | #################################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | #################################################################################### 7 | 8 | 9 | import torch.utils.data as data 10 | from PIL import Image 11 | import os 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | # make own dataset 23 | 24 | 25 | def make_dataset(dir): 26 | images = [] 27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 28 | 29 | for root, _, fnames in sorted(os.walk(dir)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | 35 | return images 36 | 37 | 38 | # make own dataset, with certain wordings 39 | def make_dataset_with_conditions(dir, wording='_mask'): 40 | images = [] 41 | images_no_conditions = [] 42 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 43 | 44 | for root, _, fnames in sorted(os.walk(dir)): 45 | for fname in fnames: 46 | if is_image_file(fname) and (wording in fname): 47 | path = os.path.join(root, fname) 48 | 
images.append(path) 49 | else: 50 | path = os.path.join(root, fname) 51 | images_no_conditions.append(path) 52 | 53 | return images, images_no_conditions 54 | 55 | # make own dataset, with certain wordings 56 | def make_dataset_with_condition_list(dir): 57 | images = [] 58 | images_no_conditions = [] 59 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 60 | 61 | for root, _, fnames in sorted(os.walk(dir)): 62 | for fname in fnames: 63 | if is_image_file(fname) and (('_seg' in fname) or ('_seg_inst' in fname) or ('_boundary' in fname)): 64 | path = os.path.join(root, fname) 65 | images.append(path) 66 | else: 67 | path = os.path.join(root, fname) 68 | images_no_conditions.append(path) 69 | #print(images) 70 | #print(images_no_conditions) 71 | 72 | return images, images_no_conditions 73 | 74 | 75 | def default_loader(path): 76 | return Image.open(path).convert('RGB') 77 | 78 | 79 | class ImageFolder(data.Dataset): 80 | def __init__(self, root, transform=None, return_paths=False, loader=default_loader): 81 | imgs = make_dataset(root) 82 | if len(imgs) == 0: 83 | raise(RuntimeError("Found 0 images in : " + root + "\n" 84 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 85 | 86 | self.root = root 87 | self.imgs = imgs 88 | self.transform = transform 89 | self.return_paths = return_paths 90 | self.loader = loader 91 | 92 | def __getitem__(self, index): 93 | path = self.imgs[index] 94 | img = self.loader(path) 95 | if self.transform is not None: 96 | img = self.transform(img) 97 | if self.return_paths: 98 | return img, path 99 | else: 100 | return img 101 | 102 | def __len__(self): 103 | return len(self.imgs) 104 | 105 | -------------------------------------------------------------------------------- /data/masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import random 4 | from scipy import ndimage, misc 5 | 6 | class Masks(): 7 | 8 | @staticmethod 9 | def get_ff_mask(h, w, num_v = None): 10 | #Source: Generative Inpainting https://github.com/JiahuiYu/generative_inpainting 11 | 12 | mask = np.zeros((h,w)) 13 | if num_v is None: 14 | num_v = 15+np.random.randint(9) #5 15 | 16 | for i in range(num_v): 17 | start_x = np.random.randint(w) 18 | start_y = np.random.randint(h) 19 | for j in range(1+np.random.randint(5)): 20 | angle = 0.01+np.random.randint(4.0) 21 | if i % 2 == 0: 22 | angle = 2 * 3.1415926 - angle 23 | length = 10+np.random.randint(60) # 40 24 | brush_w = 10+np.random.randint(15) # 10 25 | end_x = (start_x + length * np.sin(angle)).astype(np.int32) 26 | end_y = (start_y + length * np.cos(angle)).astype(np.int32) 27 | 28 | cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w) 29 | start_x, start_y = end_x, end_y 30 | 31 | return mask.astype(np.float32) 32 | 33 | 34 | @staticmethod 35 | def get_box_mask(h,w): 36 | height, width = h, w 37 | 38 | mask = np.zeros((height, width)) 39 | 40 | mask_width = random.randint(int(0.3 * width), int(0.7 * width)) 41 | mask_height = random.randint(int(0.3 * height), int(0.7 * height)) 42 | 43 | mask_x = random.randint(0, width - mask_width) 44 | mask_y = random.randint(0, height - mask_height) 45 | 46 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1 47 | return mask.astype(np.float32) 48 | 49 | @staticmethod 50 | def get_ca_mask(h,w, scale = None, r = None): 51 | 52 | if scale is None: 53 | scale = random.choice([1,2,4,8]) 54 | if r is None: 55 | r = random.randint(2,6) # repeat median filter r times 56 | 57 | 
height = h 58 | width = w 59 | mask = np.random.randint(2, size = (height//scale, width//scale)) 60 | 61 | for _ in range(r): 62 | mask = ndimage.median_filter(mask, size=3, mode='constant') 63 | 64 | mask = misc.imresize(mask,(h,w),interp='nearest') 65 | if scale > 1: 66 | struct = ndimage.generate_binary_structure(2, 1) 67 | mask = ndimage.morphology.binary_dilation(mask, struct) 68 | elif scale > 3: 69 | struct = np.array([[ 0., 0., 1., 0., 0.], 70 | [ 0., 1., 1., 1., 0.], 71 | [ 1., 1., 1., 1., 1.], 72 | [ 0., 1., 1., 1., 0.], 73 | [ 0., 0., 1., 0., 0.]]) 74 | 75 | return (mask > 0).astype(np.float32) 76 | 77 | @staticmethod 78 | def get_random_mask(h,w): 79 | f = random.choice([Masks.get_box_mask, Masks.get_ca_mask, Masks.get_ff_mask]) 80 | return f(h,w) -------------------------------------------------------------------------------- /datasets/ade20k/test/AIM_IC_t1_validation_0_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_0_mask.png -------------------------------------------------------------------------------- /datasets/ade20k/test/AIM_IC_t1_validation_0_with_holes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_0_with_holes.png -------------------------------------------------------------------------------- /datasets/ade20k/test/AIM_IC_t1_validation_18_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_18_mask.png -------------------------------------------------------------------------------- /datasets/ade20k/test/AIM_IC_t1_validation_18_with_holes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/test/AIM_IC_t1_validation_18_with_holes.png -------------------------------------------------------------------------------- /datasets/ade20k/train/a/abbey/ADE_train_00000976.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000976.jpg -------------------------------------------------------------------------------- /datasets/ade20k/train/a/abbey/ADE_train_00000976_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000976_seg.png -------------------------------------------------------------------------------- /datasets/ade20k/train/a/abbey/ADE_train_00000985.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000985.jpg -------------------------------------------------------------------------------- /datasets/ade20k/train/a/abbey/ADE_train_00000985_seg.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000985_seg.png -------------------------------------------------------------------------------- /datasets/ade20k/train/a/abbey/ADE_train_00000992.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000992.jpg -------------------------------------------------------------------------------- /datasets/ade20k/train/a/abbey/ADE_train_00000992_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/a/abbey/ADE_train_00000992_seg.png -------------------------------------------------------------------------------- /datasets/ade20k/train/z/zoo/ADE_train_00020206.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/z/zoo/ADE_train_00020206.jpg -------------------------------------------------------------------------------- /datasets/ade20k/train/z/zoo/ADE_train_00020206_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/datasets/ade20k/train/z/zoo/ADE_train_00020206_seg.png -------------------------------------------------------------------------------- /examples/AIM_IC_t1_validation_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/AIM_IC_t1_validation_0.png -------------------------------------------------------------------------------- /examples/AIM_IC_t1_validation_0_with_holes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/AIM_IC_t1_validation_0_with_holes.png -------------------------------------------------------------------------------- /examples/ablation_study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/ablation_study.png -------------------------------------------------------------------------------- /examples/arc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/arc.png -------------------------------------------------------------------------------- /examples/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/architecture.png -------------------------------------------------------------------------------- /examples/comparisons_ffhq_oxford.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/comparisons_ffhq_oxford.png -------------------------------------------------------------------------------- 
/examples/problem_of_interest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/problem_of_interest.png -------------------------------------------------------------------------------- /examples/spd_resnetblk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/spd_resnetblk.png -------------------------------------------------------------------------------- /examples/visualization_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/examples/visualization_seg.png -------------------------------------------------------------------------------- /models/__pycache__/BoundaryVAE_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/BoundaryVAE_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/BoundaryVAE_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/BoundaryVAE_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/models_pretrain.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/models_pretrain.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/our_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/models/__pycache__/our_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import sys 5 | 6 | class BaseModel(nn.Module): 7 | def name(self): 8 | return 'BaseModel' 9 | 10 | def initialize(self, opt): 11 | self.opt = opt 12 | self.gpu_ids = opt.gpu_ids 13 | self.isTrain = opt.isTrain 14 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 16 | 17 | def set_input(self, input): 18 | self.input = input 19 | 20 | def forward(self): 21 | pass 22 | 23 | # used in test time, no backprop 24 | def test(self): 25 | pass 26 | 27 | def get_image_paths(self): 28 | pass 29 | 30 | def optimize_parameters(self): 31 | pass 32 | 33 | def get_current_visuals(self): 34 | return self.input 35 | 36 | def get_current_errors(self): 37 | return {} 38 | 39 | def save(self, label): 40 | pass 41 | 42 | # helper saving function that can be used by subclasses 43 | def save_network(self, network, network_label, epoch_label, gpu_ids): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | torch.save(network.cpu().state_dict(), save_path) 47 | if len(gpu_ids) and torch.cuda.is_available(): 48 | network.cuda() 49 | 50 | # helper loading function that can be used by subclasses 51 | def load_network(self, network, network_label, epoch_label, save_dir=''): 52 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 53 | if not save_dir: 54 | save_dir = self.save_dir 55 | save_path = os.path.join(save_dir, save_filename) 56 | if not os.path.isfile(save_path): 57 | print('%s not exists yet!' 
% save_path) 58 | if network_label == 'G': 59 | raise('Generator must exist!') 60 | else: 61 | try: 62 | network.load_state_dict(torch.load(save_path)) 63 | except: 64 | pretrained_dict = torch.load(save_path) 65 | model_dict = network.state_dict() 66 | try: 67 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 68 | network.load_state_dict(pretrained_dict) 69 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 70 | except: 71 | print('Pretrained network %s has fewer layers; The following are not initialized: ' % network_label) 72 | for k, v in pretrained_dict.items(): 73 | if v.size() == model_dict[k].size(): 74 | model_dict[k] = v 75 | 76 | if sys.version_info >= (3,0): 77 | not_initialized = set() 78 | else: 79 | from sets import Set 80 | not_initialized = Set() 81 | 82 | for k, v in model_dict.items(): 83 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 84 | not_initialized.add(k.split('.')[0]) 85 | 86 | print(sorted(not_initialized)) 87 | network.load_state_dict(model_dict) 88 | 89 | def update_learning_rate(): 90 | pass 91 | 92 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def create_model(opt): 4 | if opt.model == 'Ours': 5 | from .our_model import OurModel, InferenceModel 6 | if opt.isTrain: 7 | model = OurModel() 8 | else: 9 | model = InferenceModel() 10 | else: 11 | print('Please define your model [%s]!'.format(opt.model)) 12 | model.initialize(opt) 13 | print("model [%s] was created" % (model.name())) 14 | num_params_G, num_params_D = model.get_num_params() 15 | 16 | if opt.isTrain and len(opt.gpu_ids): 17 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 18 | 19 | return model, num_params_G, num_params_D 20 | 21 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import functools 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | 10 | ############################################################ 11 | ### Functions 12 | ############################################################ 13 | def weights_init(m): 14 | classname = m.__class__.__name__ 15 | if hasattr(m, 'weight') and classname.find('Conv2d') != -1: 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 17 | m.weight.data *= 0.1 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif classname.find('BatchNorm2d') != -1: 21 | m.weight.data.normal_(1.0, 0.02) 22 | m.bias.data.fill_(0) 23 | elif classname.find('ConvTranspose2d') != -1: 24 | m.weight.data.normal_(0.0, 0.02) 25 | if m.bias is not None: 26 | m.bias.data.zero_() 27 | elif classname.find('Linear') != -1: 28 | m.weight.data.normal_(0.0, 0.01) 29 | if m.bias is not None: 30 | m.bias.data.zero_() 31 | 32 | def get_norm_layer(norm_type='instance'): 33 | if norm_type == 'batch': 34 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 35 | elif norm_type == 'instance': 36 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 37 | else: 38 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 39 | return norm_layer 40 | 41 | def print_network(net): 42 | num_params = 0 43 
| for param in net.parameters(): 44 | num_params += param.numel() 45 | print(net) 46 | print('Total number of parameters: %d' % num_params) 47 | print('--------------------------------------------------------------') 48 | return num_params 49 | 50 | def define_G(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=9, norm='instance', gpu_ids=[]): 51 | netG = ImageTinker2(input_nc, output_nc, ngf=64, n_downsampling=4, n_blocks=4, norm_layer=nn.BatchNorm2d, pad_type='replicate') 52 | 53 | num_params = print_network(netG) 54 | 55 | if len(gpu_ids) > 0: 56 | assert(torch.cuda.is_available()) 57 | netG.cuda(gpu_ids[0]) 58 | netG.apply(weights_init) 59 | 60 | return netG, num_params 61 | 62 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): 63 | norm_layer = get_norm_layer(norm_type=norm) 64 | netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) 65 | num_params = print_network(netD) 66 | 67 | if len(gpu_ids) > 0: 68 | assert(torch.cuda.is_available()) 69 | netD.cuda(gpu_ids[0]) 70 | netD.apply(weights_init) 71 | 72 | return netD, num_params 73 | 74 | class ImageTinker2(nn.Module): 75 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=4, n_blocks=4, norm_layer=nn.InstanceNorm2d, pad_type='replicate', activation=nn.LeakyReLU(0.2, True)): 76 | assert(n_blocks >= 0) 77 | super(ImageTinker2, self).__init__() 78 | 79 | if pad_type == 'reflect': 80 | self.pad = nn.ReflectionPad2d 81 | elif pad_type == 'zero': 82 | self.pad = nn.ZeroPad2d 83 | elif pad_type == 'replicate': 84 | self.pad = nn.ReplicationPad2d 85 | 86 | # LR coarse tinker (encoder) 87 | lr_coarse_tinker = [self.pad(3), nn.Conv2d(input_nc, ngf // 2, kernel_size=7, stride=1, padding=0), activation] 88 | lr_coarse_tinker += [self.pad(1), nn.Conv2d(ngf // 2, ngf, kernel_size=4, stride=2, padding=0), activation] 89 | lr_coarse_tinker += [self.pad(1), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=0), activation] 90 | lr_coarse_tinker += [self.pad(1), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=0), activation] 91 | # bottle neck 92 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)] 93 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)] 94 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)] 95 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)] 96 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)] 97 | lr_coarse_tinker += [MultiDilationResnetBlock(ngf * 4, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None)] 98 | # decoder 99 | lr_coarse_tinker += [nn.UpsamplingBilinear2d(scale_factor=2), self.pad(1), nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=0), activation] 100 | lr_coarse_tinker += [nn.UpsamplingBilinear2d(scale_factor=2), self.pad(1), nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=0), activation] 101 | lr_coarse_tinker += [nn.UpsamplingBilinear2d(scale_factor=2), self.pad(1), nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=0), activation] 102 | lr_coarse_tinker += [self.pad(3), nn.Conv2d(ngf // 2, output_nc, kernel_size=7, stride=1, padding=0)] 
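        # Note: lr_coarse_tinker above is a symmetric encoder-decoder. A 7x7 stem plus three
        # stride-2 convolutions reduce the spatial resolution 8x; six multi-dilation (SPD)
        # residual blocks enlarge the receptive field at the bottleneck; three
        # bilinear-upsample + 3x3 conv stages restore the resolution before the final 7x7
        # projection back to output_nc channels.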
103 | ### get a coarse (256x256x3) 104 | self.lr_coarse_tinker = nn.Sequential(*lr_coarse_tinker) 105 | 106 | self.r_en_padd1 = self.pad(3) 107 | self.r_en_conv1 = nn.Conv2d(input_nc, ngf // 2, kernel_size=7, stride=1, padding=0) 108 | self.r_en_acti1 = activation 109 | 110 | self.r_en_padd2 = self.pad(1) 111 | self.r_en_conv2 = nn.Conv2d(ngf // 2, ngf, kernel_size=4, stride=2, padding=0) 112 | self.r_en_acti2 = activation 113 | 114 | self.r_en_padd3 = self.pad(1) 115 | self.r_en_conv3 = nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=0) 116 | self.r_en_acti3 = activation 117 | 118 | self.r_en_skp_padd3 = self.pad(1) 119 | self.r_en_skp_conv3 = nn.Conv2d(ngf * 2, ngf * 2 // 2, kernel_size=3, stride=1, padding=0) 120 | self.r_en_skp_acti3 = activation 121 | 122 | self.r_en_padd4 = self.pad(1) 123 | self.r_en_conv4 = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=0) 124 | self.r_en_acti4 = activation 125 | 126 | self.r_en_skp_padd4 = self.pad(1) 127 | self.r_en_skp_conv4 = nn.Conv2d(ngf * 4, ngf * 4 // 2, kernel_size=3, stride=1, padding=0) 128 | self.r_en_skp_acti4 = activation 129 | 130 | self.r_en_padd5 = self.pad(1) 131 | self.r_en_conv5 = nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=0) 132 | self.r_en_acti5 = activation 133 | 134 | self.r_md_mres1 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None) 135 | self.r_md_mres2 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None) 136 | self.r_md_mres5 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None) 137 | self.r_md_satn1 = NonLocalBlock(ngf * 8, sub_sample=False, bn_layer=False) 138 | self.r_md_mres3 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None) 139 | self.r_md_mres4 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None) 140 | self.r_md_mres6 = MultiDilationResnetBlock_v3(ngf * 8, kernel_size=3, stride=1, padding=1, pad_type='replicate', norm=None) 141 | 142 | self.r_de_upbi1 = nn.UpsamplingBilinear2d(scale_factor=2) 143 | self.r_de_padd1 = self.pad(1) 144 | self.r_de_conv1 = nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=0) 145 | self.r_de_acti1 = activation 146 | 147 | self.r_de_satn2 = NonLocalBlock(ngf * 4 // 2, sub_sample=False, bn_layer=False) 148 | self.r_de_satn3 = NonLocalBlock(ngf * 2 // 2, sub_sample=False, bn_layer=False) 149 | 150 | self.r_de_mix_padd1 = self.pad(1) 151 | self.r_de_mix_conv1 = nn.Conv2d(ngf * 4 + ngf * 4 // 2, ngf * 4, kernel_size=3, stride=1, padding=0) 152 | self.r_de_mix_acti1 = activation 153 | 154 | self.r_de_upbi2 = nn.UpsamplingBilinear2d(scale_factor=2) 155 | self.r_de_padd2 = self.pad(1) 156 | self.r_de_conv2 = nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=0) 157 | self.r_de_acti2 = activation 158 | 159 | self.r_de_mix_padd2 = self.pad(1) 160 | self.r_de_mix_conv2 = nn.Conv2d(ngf * 2 + ngf * 2 // 2, ngf * 2, kernel_size=3, stride=1, padding=0) 161 | self.r_de_mix_acti2 = activation 162 | 163 | self.r_de_padd2_lr = self.pad(1) 164 | self.r_de_conv2_lr = nn.Conv2d(ngf * 2, ngf // 2, kernel_size=3, stride=1, padding=0) 165 | self.r_de_acti2_lr = activation 166 | 167 | self.r_de_padd3_lr = self.pad(1) 168 | self.r_de_conv3_lr = nn.Conv2d(ngf // 2, output_nc, kernel_size=3, stride=1, padding=0) 169 | 170 | self.r_de_upbi3 = nn.UpsamplingBilinear2d(scale_factor=2) 
171 | self.r_de_padd3 = self.pad(1) 172 | self.r_de_conv3 = nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=0) 173 | self.r_de_acti3 = activation 174 | 175 | self.r_de_upbi4 = nn.UpsamplingBilinear2d(scale_factor=2) 176 | self.r_de_padd4 = self.pad(1) 177 | self.r_de_conv4 = nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=0) 178 | self.r_de_acti4 = activation 179 | 180 | self.r_de_padd5 = self.pad(3) 181 | self.r_de_conv5 = nn.Conv2d(ngf // 2, output_nc, kernel_size=7, stride=1, padding=0) 182 | 183 | self.r_de_padd5_lr_alpha = self.pad(1) 184 | self.r_de_conv5_lr_alpha = nn.Conv2d(ngf // 2, 1, kernel_size=3, stride=1, padding=0) 185 | self.r_de_acti5_lr_alpha = nn.Sigmoid() 186 | 187 | self.up = nn.UpsamplingBilinear2d(scale_factor=4) 188 | self.down = nn.UpsamplingBilinear2d(scale_factor=0.25) 189 | 190 | def forward(self, msked_img, msk, real_img=None): 191 | if real_img is not None: 192 | rimg = real_img 193 | inp = real_img * (1 - msk) + msk 194 | else: 195 | rimg = msked_img 196 | inp = msked_img 197 | 198 | x = torch.cat((inp, msk), dim=1) 199 | lr_x = self.lr_coarse_tinker(x) 200 | hr_x = lr_x * msk + rimg * (1 - msk) 201 | 202 | y = torch.cat((hr_x, msk), dim=1) 203 | e1 = self.r_en_acti1(self.r_en_conv1(self.r_en_padd1(y))) 204 | e2 = self.r_en_acti2(self.r_en_conv2(self.r_en_padd2(e1))) 205 | e3 = self.r_en_acti3(self.r_en_conv3(self.r_en_padd3(e2))) 206 | e4 = self.r_en_acti4(self.r_en_conv4(self.r_en_padd4(e3))) 207 | e5 = self.r_en_acti5(self.r_en_conv5(self.r_en_padd5(e4))) 208 | 209 | skp_e3 = self.r_en_skp_acti3(self.r_en_skp_conv3(self.r_en_skp_padd3(e3))) 210 | skp_e4 = self.r_en_skp_acti4(self.r_en_skp_conv4(self.r_en_skp_padd4(e4))) 211 | 212 | de3 = self.r_de_satn3(skp_e3) 213 | de4 = self.r_de_satn2(skp_e4) 214 | 215 | m1 = self.r_md_mres1(e5) 216 | m2 = self.r_md_mres2(m1) 217 | m5 = self.r_md_mres5(m2) 218 | a1 = self.r_md_satn1(m5) 219 | m3 = self.r_md_mres3(a1) 220 | m4 = self.r_md_mres4(m3) 221 | m6 = self.r_md_mres6(m4) 222 | 223 | d1 = self.r_de_acti1(self.r_de_conv1((self.r_de_padd1(self.r_de_upbi1(m6))))) # 32x32x256 224 | cat1 = torch.cat((d1, de4), dim=1) 225 | md1 = self.r_de_mix_acti1(self.r_de_mix_conv1(self.r_de_mix_padd1(cat1))) 226 | 227 | d2 = self.r_de_acti2(self.r_de_conv2((self.r_de_padd2(self.r_de_upbi2(md1))))) # 64x64x128 228 | cat2 = torch.cat((d2, de3), dim=1) 229 | md2 = self.r_de_mix_acti2(self.r_de_mix_conv2(self.r_de_mix_padd2(cat2))) 230 | 231 | d2_lr = self.r_de_acti2_lr(self.r_de_conv2_lr(self.r_de_padd2_lr(md2))) 232 | d3_lr = self.r_de_conv3_lr(self.r_de_padd3_lr(d2_lr)) 233 | 234 | d3 = self.r_de_acti3(self.r_de_conv3((self.r_de_padd3(self.r_de_upbi3(md2))))) # 128x128x64 235 | d4 = self.r_de_acti4(self.r_de_conv4((self.r_de_padd4(self.r_de_upbi4(d3))))) # 256x256x32 236 | 237 | d5 = self.r_de_conv5(self.r_de_padd5(d4)) 238 | d5_lr_alpha = self.r_de_acti5_lr_alpha(self.r_de_conv5_lr_alpha(self.r_de_padd5_lr_alpha(d4))) 239 | 240 | ### 241 | # d5: 256x256x3 242 | # d5_lr_alpha: 256x256x1 243 | # d3_lr: 64x64x3 244 | ### 245 | lr_img = d3_lr 246 | 247 | #reconst_img = d5 248 | d5 = d5 * msk + rimg * (1 - msk) 249 | lr_d5 = self.down(d5) 250 | lr_d5_res = d3_lr - lr_d5 251 | hr_d5_res = self.up(lr_d5_res) 252 | reconst_img = d5 + hr_d5_res * d5_lr_alpha 253 | compltd_img = reconst_img * msk + rimg * (1 - msk) 254 | #out = compltd_img + hr_d5_res * d5_lr_alpha 255 | 256 | return compltd_img, reconst_img, lr_x, lr_img 257 | 258 | 259 | ############################################################ 260 | ### Losses 261 | 
############################################################ 262 | class TVLoss(nn.Module): 263 | def forward(self, x): 264 | batch_size = x.size()[0] 265 | h_x = x.size()[2] 266 | w_x = x.size()[3] 267 | count_h = self._tensor_size(x[:, :, 1:, :]) 268 | count_w = self._tensor_size(x[:, :, :, 1:]) 269 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 270 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 271 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_size 272 | 273 | def _tensor_size(self, t): 274 | return t.size()[1] * t.size()[2] * t.size()[3] 275 | 276 | class VGGLoss(nn.Module): 277 | # vgg19 perceptual loss 278 | def __init__(self, gpu_ids): 279 | super(VGGLoss, self).__init__() 280 | self.vgg = Vgg19().cuda() 281 | self.criterion = nn.L1Loss() 282 | self.mse_loss = nn.MSELoss() 283 | 284 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 285 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 286 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 287 | self.register_buffer('mean', mean) 288 | self.register_buffer('std', std) 289 | 290 | def gram_matrix(self, x): 291 | (b, ch, h, w) = x.size() 292 | features = x.view(b, ch, w*h) 293 | features_t = features.transpose(1, 2) 294 | gram = features.bmm(features_t) / (ch * h * w) 295 | return gram 296 | 297 | def forward(self, x, y): 298 | x = (x - self.mean) / self.std 299 | y = (y - self.mean) / self.std 300 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 301 | 302 | loss = 0 303 | style_loss = 0 304 | for i in range(len(x_vgg)): 305 | loss += self.weights[i] * \ 306 | self.criterion(x_vgg[i], y_vgg[i].detach()) 307 | gm_x = self.gram_matrix(x_vgg[i]) 308 | gm_y = self.gram_matrix(y_vgg[i]) 309 | style_loss += self.weights[i] * self.mse_loss(gm_x, gm_y.detach()) 310 | return loss, style_loss 311 | 312 | class GANLoss_D_v2(nn.Module): 313 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor): 314 | super(GANLoss_D_v2, self).__init__() 315 | self.real_label = target_real_label 316 | self.fake_label = target_fake_label 317 | self.real_label_var = None 318 | self.fake_label_var = None 319 | self.Tensor = tensor 320 | if use_lsgan: 321 | self.loss = nn.MSELoss() 322 | else: 323 | def wgan_loss(input, target): 324 | return torch.mean(F.relu(1.-input)) if target else torch.mean(F.relu(1.+input)) 325 | self.loss = wgan_loss 326 | 327 | def get_target_tensor(self, input, target_is_real): 328 | target_tensor = None 329 | if target_is_real: 330 | create_label = ((self.real_label_var is None) or (self.real_label_var.numel() != input.numel())) 331 | if create_label: 332 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 333 | self.real_label_var = Variable(real_tensor, requires_grad=False) 334 | target_tensor = self.real_label_var 335 | else: 336 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())) 337 | if create_label: 338 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 339 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 340 | target_tensor = self.fake_label_var 341 | return target_tensor 342 | 343 | def __call__(self, input, target_is_real): 344 | if isinstance(input[0], list): 345 | loss = 0 346 | for input_i in input: 347 | pred = input_i[-1] 348 | target_tensor = self.get_target_tensor(pred, target_is_real) 349 | #loss += self.loss(pred, target_tensor) 350 | loss += self.loss(pred, target_is_real) 351 | return loss 352 
| else: 353 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 354 | return self.loss(input[-1], target_tensor) 355 | 356 | class GANLoss_G_v2(nn.Module): 357 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor): 358 | super(GANLoss_G_v2, self).__init__() 359 | self.real_label = target_real_label 360 | self.fake_label = target_fake_label 361 | self.real_label_var = None 362 | self.fake_label_var = None 363 | self.Tensor = tensor 364 | if use_lsgan: 365 | self.loss = nn.MSELoss() 366 | else: 367 | def wgan_loss(input, target): 368 | return -1 * input.mean() if target else input.mean() 369 | self.loss = wgan_loss 370 | 371 | def get_target_tensor(self, input, target_is_real): 372 | target_tensor = None 373 | if target_is_real: 374 | create_label = ((self.real_label_var is None) or (self.real_label_var.numel() != input.numel())) 375 | if create_label: 376 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 377 | self.real_label_var = Variable(real_tensor, requires_grad=False) 378 | target_tensor = self.real_label_var 379 | else: 380 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())) 381 | if create_label: 382 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 383 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 384 | target_tensor = self.fake_label_var 385 | return target_tensor 386 | 387 | def __call__(self, input, target_is_real): 388 | if isinstance(input[0], list): 389 | loss = 0 390 | for input_i in input: 391 | pred = input_i[-1] 392 | target_tensor = self.get_target_tensor(pred, target_is_real) 393 | #loss += self.loss(pred, target_tensor) 394 | loss += self.loss(pred, target_is_real) 395 | return loss 396 | else: 397 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 398 | return self.loss(input[-1], target_tensor) 399 | 400 | 401 | # Define the PatchGAN discriminator with the specified arguments. 
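# Each NLayerDiscriminator below is a spectrally normalized, PatchGAN-style stack of 4x4
# convolutions (n_layers stride-2 downsampling convs followed by two stride-1 convs), and
# MultiscaleDiscriminator runs num_D copies of it on progressively average-pooled inputs,
# optionally keeping the intermediate features for the feature-matching loss.
# Minimal usage sketch (hypothetical values; in our_model.py the discriminator input is the
# masked image concatenated with a real or generated image, i.e. output_nc * 2 channels):
#   netD = MultiscaleDiscriminator(input_nc=6, ndf=64, n_layers=2, num_D=2, getIntermFeat=True)
#   preds = netD(torch.cat((masked_img, fake_img), dim=1))  # list with one entry per scale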
402 | class NLayerDiscriminator(nn.Module): 403 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d, use_sigmoid=False, getIntermFeat=False): 404 | super(NLayerDiscriminator, self).__init__() 405 | self.getIntermFeat = getIntermFeat 406 | self.n_layers = n_layers 407 | 408 | kw = 4 409 | padw = int(np.ceil((kw-1.0)/2)) 410 | sequence = [[SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]] 411 | 412 | nf = ndf 413 | for n in range(1, n_layers): 414 | nf_prev = nf 415 | nf = min(nf * 2, 512) 416 | sequence += [[ 417 | SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)), 418 | nn.LeakyReLU(0.2, True) 419 | ]] 420 | 421 | nf_prev = nf 422 | nf = min(nf * 2, 512) 423 | sequence += [[ 424 | SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw)), 425 | nn.LeakyReLU(0.2, True) 426 | ]] 427 | 428 | sequence += [[SpectralNorm(nn.Conv2d(nf, nf, kernel_size=kw, stride=1, padding=padw))]] 429 | 430 | # if use_sigmoid: 431 | # sequence += [[nn.Sigmoid()]] 432 | 433 | if getIntermFeat: 434 | for n in range(len(sequence)): 435 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 436 | else: 437 | sequence_stream = [] 438 | for n in range(len(sequence)): 439 | sequence_stream += sequence[n] 440 | self.model = nn.Sequential(*sequence_stream) 441 | 442 | def forward(self, input): 443 | if self.getIntermFeat: 444 | res = [input] 445 | for n in range(self.n_layers + 2): 446 | model = getattr(self, 'model'+str(n)) 447 | res.append(model(res[-1])) 448 | return res[1:] 449 | else: 450 | return self.model(input) 451 | 452 | 453 | # Define the Multiscale Discriminator. 454 | class MultiscaleDiscriminator(nn.Module): 455 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, num_D=3, getIntermFeat=False): 456 | super(MultiscaleDiscriminator, self).__init__() 457 | self.num_D = num_D 458 | self.n_layers = n_layers 459 | self.getIntermFeat = getIntermFeat 460 | 461 | for i in range(num_D): 462 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 463 | if getIntermFeat: 464 | for j in range(n_layers+2): 465 | setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) 466 | else: 467 | setattr(self, 'layer'+str(i), netD.model) 468 | 469 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 470 | 471 | def singleD_forward(self, model, input): 472 | if self.getIntermFeat: 473 | result = [input] 474 | for i in range(len(model)): 475 | result.append(model[i](result[-1])) 476 | return result[1:] 477 | else: 478 | return [model(input)] 479 | 480 | def forward(self, input): 481 | num_D = self.num_D 482 | result = [] 483 | input_downsampled = input 484 | for i in range(num_D): 485 | if self.getIntermFeat: 486 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] 487 | else: 488 | model = getattr(self, 'layer'+str(num_D-1-i)) 489 | result.append(self.singleD_forward(model, input_downsampled)) 490 | if i != (num_D-1): 491 | input_downsampled = self.downsample(input_downsampled) 492 | return result 493 | 494 | 495 | ### Define Vgg19 for vgg_loss 496 | class Vgg19(nn.Module): 497 | def __init__(self, requires_grad=False): 498 | super(Vgg19, self).__init__() 499 | vgg_pretrained_features = models.vgg19(pretrained=True).features 500 | self.slice1 = nn.Sequential() 501 | self.slice2 = nn.Sequential() 502 | self.slice3 = 
nn.Sequential() 503 | self.slice4 = nn.Sequential() 504 | self.slice5 = nn.Sequential() 505 | 506 | for x in range(1): 507 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 508 | for x in range(1, 6): 509 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 510 | for x in range(6, 11): 511 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 512 | for x in range(11, 20): 513 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 514 | for x in range(20, 29): 515 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 516 | 517 | # fixed pretrained vgg19 model for feature extraction 518 | if not requires_grad: 519 | for param in self.parameters(): 520 | param.requires_grad = False 521 | 522 | def forward(self, x): 523 | h_relu1 = self.slice1(x) 524 | h_relu2 = self.slice2(h_relu1) 525 | h_relu3 = self.slice3(h_relu2) 526 | h_relu4 = self.slice4(h_relu3) 527 | h_relu5 = self.slice5(h_relu4) 528 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 529 | return out 530 | 531 | ### Multi-Dilation ResnetBlock 532 | class MultiDilationResnetBlock(nn.Module): 533 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False): 534 | super(MultiDilationResnetBlock, self).__init__() 535 | 536 | self.branch1 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=2, dilation=2, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 537 | self.branch2 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=3, dilation=3, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 538 | self.branch3 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=4, dilation=4, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 539 | self.branch4 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=5, dilation=5, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 540 | self.branch5 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=6, dilation=6, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 541 | self.branch6 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=8, dilation=8, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 542 | self.branch7 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=10, dilation=10, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 543 | self.branch8 = ConvBlock(input_nc, input_nc // 8, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 544 | 545 | self.fusion9 = ConvBlock(input_nc, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type=pad_type, norm=norm, acti=None) 546 | 547 | def forward(self, x): 548 | d1 = self.branch1(x) 549 | d2 = self.branch2(x) 550 | d3 = self.branch3(x) 551 | d4 = self.branch4(x) 552 | d5 = self.branch5(x) 553 | d6 = self.branch6(x) 554 | d7 = self.branch7(x) 555 | d8 = self.branch8(x) 556 | d9 = torch.cat((d1, d2, d3, d4, d5, d6, d7, d8), dim=1) 557 | out = x + self.fusion9(d9) 558 | return out 559 | 560 | ### Multi-Dilation ResnetBlock 561 | class MultiDilationResnetBlock_v3(nn.Module): 562 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False): 563 | super(MultiDilationResnetBlock_v3, self).__init__() 564 | 565 
| self.branch1 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=2, dilation=2, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 566 | self.branch2 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=3, dilation=3, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 567 | self.branch3 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=4, dilation=4, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 568 | self.branch4 = ConvBlock(input_nc, input_nc // 4, kernel_size=3, stride=1, padding=5, dilation=5, groups=1, bias=True, pad_type=pad_type, norm=norm, acti='relu') 569 | 570 | self.fusion5 = ConvBlock(input_nc, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type=pad_type, norm=norm, acti=None) 571 | 572 | def forward(self, x): 573 | d1 = self.branch1(x) 574 | d2 = self.branch2(x) 575 | d3 = self.branch3(x) 576 | d4 = self.branch4(x) 577 | d5 = torch.cat((d1, d2, d3, d4), dim=1) 578 | out = x + self.fusion5(d5) 579 | return out 580 | 581 | ### ResnetBlock 582 | class ResnetBlock(nn.Module): 583 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False): 584 | super(ResnetBlock, self).__init__() 585 | self.conv_block = self.build_conv_block(input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout) 586 | 587 | 588 | def build_conv_block(self, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout): 589 | conv_block = [] 590 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti='relu')] 591 | if use_dropout: 592 | conv_block += [nn.Dropout(0.5)] 593 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti=None)] 594 | 595 | return nn.Sequential(*conv_block) 596 | 597 | def forward(self, x): 598 | out = x + self.conv_block(x) 599 | return out 600 | 601 | ### ResnetBlock 602 | class ResnetBlock_v2(nn.Module): 603 | def __init__(self, input_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti='relu', use_dropout=False): 604 | super(ResnetBlock_v2, self).__init__() 605 | self.conv_block = self.build_conv_block(input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout) 606 | 607 | 608 | def build_conv_block(self, input_nc, kernel_size, stride, padding, dilation, groups, bias, pad_type, norm, acti, use_dropout): 609 | conv_block = [] 610 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size=3, stride=1, padding=padding, dilation=dilation, groups=groups, bias=bias, pad_type=pad_type, norm=norm, acti='elu')] 611 | if use_dropout: 612 | conv_block += [nn.Dropout(0.5)] 613 | conv_block += [ConvBlock(input_nc, input_nc, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, pad_type='reflect', norm='instance', acti=None)] 614 | 615 | return nn.Sequential(*conv_block) 616 | 617 | def forward(self, x): 618 | out = x + self.conv_block(x) 619 | return out 620 | 621 | ### NonLocalBlock2D 622 | class NonLocalBlock(nn.Module): 623 | def __init__(self, input_nc, inter_nc=None, sub_sample=True, bn_layer=True): 624 | super(NonLocalBlock, self).__init__() 625 | self.input_nc = input_nc 626 | self.inter_nc = inter_nc 627 | 628 | if inter_nc is 
None: 629 | self.inter_nc = input_nc // 2 630 | 631 | self.g = nn.Conv2d(in_channels=self.input_nc, out_channels=self.inter_nc, kernel_size=1, stride=1, padding=0) 632 | 633 | if bn_layer: 634 | self.W = nn.Sequential( 635 | nn.Conv2d(in_channels=self.inter_nc, out_channels=self.input_nc, kernel_size=1, stride=1, padding=0), 636 | nn.BatchNorm2d(self.input_nc) 637 | ) 638 | self.W[0].weight.data.zero_() 639 | self.W[0].bias.data.zero_() 640 | else: 641 | self.W = nn.Conv2d(in_channels=self.inter_nc, out_channels=self.input_nc, kernel_size=1, stride=1, padding=0) 642 | self.W.weight.data.zero_() 643 | self.W.bias.data.zero_() 644 | 645 | self.theta = nn.Conv2d(in_channels=self.input_nc, out_channels=self.inter_nc, kernel_size=1, stride=1, padding=0) 646 | self.phi = nn.Conv2d(in_channels=self.input_nc, out_channels=self.inter_nc, kernel_size=1, stride=1, padding=0) 647 | 648 | if sub_sample: 649 | self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size(2, 2))) 650 | self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size(2, 2))) 651 | 652 | def forward(self, x): 653 | batch_size = x.size(0) 654 | 655 | g_x = self.g(x).view(batch_size, self.inter_nc, -1) 656 | g_x = g_x.permute(0, 2, 1) 657 | 658 | theta_x = self.theta(x).view(batch_size, self.inter_nc, -1) 659 | theta_x = theta_x.permute(0, 2, 1) 660 | 661 | phi_x = self.phi(x).view(batch_size, self.inter_nc, -1) 662 | 663 | f = torch.matmul(theta_x, phi_x) 664 | f_div_C = F.softmax(f, dim=-1) 665 | 666 | y = torch.matmul(f_div_C, g_x) 667 | y = y.permute(0, 2, 1).contiguous() 668 | y = y.view(batch_size, self.inter_nc, *x.size()[2:]) 669 | W_y = self.W(y) 670 | 671 | z = W_y + x 672 | return z 673 | 674 | ### ConvBlock 675 | class ConvBlock(nn.Module): 676 | def __init__(self, input_nc, output_nc, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, pad_type='zero', norm=None, acti='lrelu'): 677 | super(ConvBlock, self).__init__() 678 | self.use_bias = bias 679 | 680 | # initialize padding 681 | if pad_type == 'reflect': 682 | self.pad = nn.ReflectionPad2d(padding) 683 | elif pad_type == 'zero': 684 | self.pad = nn.ZeroPad2d(padding) 685 | elif pad_type == 'replicate': 686 | self.pad = nn.ReplicationPad2d(padding) 687 | else: 688 | assert 0, "Unsupported padding type: {}".format(pad_type) 689 | 690 | # initialize normalization 691 | if norm == 'batch': 692 | self.norm = nn.BatchNorm2d(output_nc) 693 | elif norm == 'instance': 694 | self.norm = nn.InstanceNorm2d(output_nc) 695 | elif norm is None or norm == 'spectral': 696 | self.norm = None 697 | else: 698 | assert 0, "Unsupported normalization: {}".format(norm) 699 | 700 | # initialize activation 701 | if acti == 'relu': 702 | self.acti = nn.ReLU(inplace=True) 703 | elif acti == 'lrelu': 704 | self.acti = nn.LeakyReLU(0.2, inplace=True) 705 | elif acti == 'prelu': 706 | self.acti = nn.PReLU() 707 | elif acti == 'elu': 708 | self.acti = nn.ELU() 709 | elif acti == 'tanh': 710 | self.acti = nn.Tanh() 711 | elif acti == 'sigmoid': 712 | self.acti = nn.Sigmoid() 713 | elif acti is None: 714 | self.acti = None 715 | else: 716 | assert 0, "Unsupported activation: {}".format(acti) 717 | 718 | # initialize convolution 719 | if norm == 'spectral': 720 | self.conv = SpectralNorm(nn.Conv2d(input_nc, output_nc, kernel_size, stride, dilation=dilation, groups=groups, bias=self.use_bias)) 721 | else: 722 | self.conv = nn.Conv2d(input_nc, output_nc, kernel_size, stride, dilation=dilation, groups=groups, bias=self.use_bias) 723 | 724 | def forward(self, x): 725 | x = 
self.conv(self.pad(x)) 726 | if self.norm: 727 | x = self.norm(x) 728 | if self.acti: 729 | x = self.acti(x) 730 | return x 731 | 732 | def l2normalize(v, eps=1e-12): 733 | return v / (v.norm() + eps) 734 | 735 | ### SpectralNorm 736 | class SpectralNorm(nn.Module): 737 | """ 738 | Spectral Normalization for Generative Adversarial Networks 739 | Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan 740 | """ 741 | def __init__(self, module, name='weight', power_iterations=1): 742 | super(SpectralNorm, self).__init__() 743 | self.module = module 744 | self.name = name 745 | self.power_iterations = power_iterations 746 | if not self._made_params(): 747 | self._make_params() 748 | 749 | def _update_u_v(self): 750 | u = getattr(self.module, self.name + "_u") 751 | v = getattr(self.module, self.name + "_v") 752 | w = getattr(self.module, self.name + "_bar") 753 | 754 | height = w.data.shape[0] 755 | for _ in range(self.power_iterations): 756 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 757 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 758 | 759 | sigma = u.dot(w.view(height, -1).mv(v)) 760 | setattr(self.module, self.name, w / sigma.expand_as(w)) 761 | 762 | def _made_params(self): 763 | try: 764 | u = getattr(self.module, self.name + "_u") 765 | v = getattr(self.module, self.name + "_v") 766 | w = getattr(self.module, self.name + "_bar") 767 | return True 768 | except AttributeError: 769 | return False 770 | 771 | def _make_params(self): 772 | w = getattr(self.module, self.name) 773 | 774 | height = w.data.shape[0] 775 | width = w.view(height, -1).data.shape[1] 776 | 777 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 778 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 779 | u.data = l2normalize(u.data) 780 | v.data = l2normalize(v.data) 781 | w_bar = nn.Parameter(w.data) 782 | 783 | del self.module._parameters[self.name] 784 | 785 | self.module.register_parameter(self.name + "_u", u) 786 | self.module.register_parameter(self.name + "_v", v) 787 | self.module.register_parameter(self.name + "_bar", w_bar) 788 | 789 | def forward(self, *args): 790 | self._update_u_v() 791 | return self.module.forward(*args) 792 | 793 | -------------------------------------------------------------------------------- /models/our_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | from torch.autograd import Variable 6 | #from util.image_pool import ImagePool 7 | # BaseModel, save_network & load_network 8 | from .base_model import BaseModel 9 | from . 
import networks 10 | 11 | class OurModel(BaseModel): 12 | def name(self): 13 | return 'OurModel' 14 | 15 | def get_num_params(self): 16 | return self.num_params_G, self.num_params_D 17 | 18 | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss): 19 | flags = (True, True, True, True, use_gan_feat_loss, use_vgg_loss, True, True, True) 20 | def loss_filter(g_gan, g_coarse_l1, g_out_l1, g_style, g_gan_feat, g_vgg, g_tv, d_real, d_fake): 21 | return [l for (l, f) in zip((g_gan, g_coarse_l1, g_out_l1, g_style, g_gan_feat, g_vgg, g_tv, d_real, d_fake), flags) if f] 22 | return loss_filter 23 | 24 | def initialize(self, opt): 25 | BaseModel.initialize(self, opt) 26 | self.isTrain = opt.isTrain 27 | input_nc = opt.input_nc ### masked_img + mask (RGB + gray) 4 channels 28 | 29 | ### define networks 30 | # Generator 31 | netG_input_nc = input_nc 32 | self.netG, self.num_params_G = networks.define_G(netG_input_nc, opt.output_nc, ngf=64, gpu_ids=self.gpu_ids) 33 | 34 | # Discriminator 35 | if self.isTrain: 36 | use_sigmoid = opt.no_lsgan 37 | netD_input_nc = opt.output_nc * 2 38 | self.netD, self.num_params_D = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) 39 | # for param in self.netD.parameters(): 40 | # param.requires_grad = False 41 | else: 42 | self.num_params_D = 0 43 | 44 | print('-------------------- Networks initialized --------------------') 45 | 46 | ### load networks 47 | if not self.isTrain or opt.continue_train or opt.load_pretrain: 48 | pretrained_path = '' if not self.isTrain else opt.load_pretrain 49 | print(pretrained_path) 50 | 51 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) 52 | if self.isTrain: 53 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) 54 | 55 | 56 | ### set loss functions and optimizers 57 | if self.isTrain: 58 | if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: 59 | raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") 60 | # self.fake_pool = ImagePool(opt.pool_size) 61 | self.old_lr = opt.lr 62 | 63 | # define loss functions 64 | self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) 65 | 66 | self.criterionGAN_D = networks.GANLoss_D_v2(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) 67 | self.criterionGAN_G = networks.GANLoss_G_v2(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) 68 | 69 | self.criterionFeat = torch.nn.L1Loss() 70 | if not opt.no_vgg_loss: 71 | self.criterionVGG = networks.VGGLoss(self.gpu_ids) 72 | self.criterionL1 = torch.nn.L1Loss() 73 | self.down = nn.UpsamplingBilinear2d(scale_factor=0.25) 74 | self.resize = nn.UpsamplingBilinear2d(size=(224, 224)) 75 | self.criterionTV = networks.TVLoss() 76 | 77 | # Names so we can breakout loss 78 | self.loss_names = self.loss_filter('G_GAN', 'G_COARSE_L1', 'G_OUT_L1', 'G_STYLE', 'G_GAN_Feat', 'G_VGG', 'G_TV', 'D_real', 'D_fake') 79 | 80 | # initialize optimizers 81 | # optimizer G 82 | if opt.niter_fix_global > 0: 83 | import sys 84 | if sys.version_info >= (3,0): 85 | finetune_list = set() 86 | else: 87 | from sets import Set 88 | finetune_list = Set() 89 | 90 | params_dict = dict(self.netG.named_parameters()) 91 | params = [] 92 | for key, value in params_dict.items(): 93 | if key.startswith('model' + str(opt.n_local_enhancers)): 94 | params += [value] 95 | finetune_list.add(key.split('.')[0]) 96 | print('--------------- Only training the local enhancer network (for %d epochs) ---------------' % opt.niter_fix_global) 97 
| print('The layers that are finetuned are ', sorted(finetune_list)) 98 | else: 99 | params = list(self.netG.parameters()) 100 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 101 | 102 | # optimizer D 103 | params = list(self.netD.parameters()) 104 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr * 4.0, betas=(opt.beta1, 0.999)) 105 | 106 | def encode_input(self, masked_image, mask, real_image, infer=False): 107 | input_msked_img = masked_image.data.cuda() 108 | input_msk = mask.data.cuda() 109 | 110 | input_msked_img = Variable(input_msked_img, volatile=infer) 111 | input_msk = Variable(input_msk, volatile=infer) 112 | 113 | real_image = Variable(real_image.data.cuda(), volatile=infer) 114 | 115 | return input_msked_img, input_msk, real_image 116 | 117 | def encode_input_test(self, masked_image, mask, infer=False): 118 | input_msked_img = masked_image.data.cuda() 119 | input_msk = mask.data.cuda() 120 | 121 | input_msked_img = Variable(input_msked_img) 122 | input_msk = Variable(input_msk) 123 | 124 | return input_msked_img, input_msk 125 | 126 | def discriminate(self, input_image, test_image, use_pool=False): 127 | input_concat = torch.cat((input_image, test_image.detach()), dim=1) 128 | return self.netD.forward(input_concat) 129 | 130 | def forward(self, masked_img, mask, real_img=None, infer=False, mode=None): 131 | # inference 132 | if mode == 'inference': 133 | compltd_img, reconst_img, lr_x = self.inference(masked_img, mask) 134 | return compltd_img, reconst_img, lr_x 135 | 136 | # Encode inputs 137 | input_msked_img, input_msk, real_image = self.encode_input(masked_img, mask, real_img) 138 | 139 | # Fake Generation 140 | compltd_img, reconst_img, lr_x, lr_img = self.netG.forward(input_msked_img, input_msk, real_image) 141 | lr_msk = input_msk 142 | 143 | msk025 = self.down(input_msk) 144 | img025 = self.down(real_image) 145 | 146 | # Fake Detection and Loss 147 | pred_fake = self.discriminate(input_msked_img, compltd_img) 148 | loss_D_fake = self.criterionGAN_D(pred_fake, False) 149 | 150 | # Real Detection and Loss 151 | pred_real = self.discriminate(input_msked_img, real_image) 152 | loss_D_real = self.criterionGAN_D(pred_real, True) 153 | 154 | # GAN Loss (Fake pass-ability loss) 155 | pred_fake = self.netD.forward(torch.cat((input_msked_img, compltd_img), dim=1)) 156 | loss_G_GAN = self.criterionGAN_G(pred_fake, True) 157 | loss_G_GAN *= self.opt.lambda_gan 158 | 159 | # GAN L1 Loss 160 | loss_G_GAN_L1 = self.criterionL1(reconst_img * input_msk, real_image * input_msk) * self.opt.lambda_l1 161 | loss_G_GAN_L1 += self.criterionL1(reconst_img * (1 - input_msk), real_image * (1 - input_msk)) 162 | loss_G_GAN_L1 += self.criterionL1(lr_img * msk025, img025 * msk025) * self.opt.lambda_l1 163 | loss_G_GAN_L1 += self.criterionL1(lr_img * (1 - msk025), img025 * (1 - msk025)) 164 | 165 | loss_G_COARSE_L1 = self.criterionL1(lr_x * lr_msk, real_image * lr_msk) * self.opt.lambda_l1 166 | loss_G_COARSE_L1 += self.criterionL1(lr_x * (1 - lr_msk), real_image * (1 - lr_msk)) 167 | 168 | loss_G_TV = self.criterionTV(compltd_img) * self.opt.lambda_tv 169 | 170 | # GAN feature matching loss 171 | loss_G_GAN_Feat = 0 172 | if not self.opt.no_ganFeat_loss: 173 | feat_weights = 4.0 / (self.opt.n_layers_D + 1) 174 | D_weights = 1.0 / self.opt.num_D 175 | for i in range(self.opt.num_D): 176 | for j in range(len(pred_fake[i]) - 1): 177 | loss_G_GAN_Feat += D_weights * feat_weights * self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) 178 | 
loss_G_GAN_Feat *= self.opt.lambda_feat 179 | 180 | # VGG feature matching loss 181 | loss_G_VGG = 0 182 | if not self.opt.no_vgg_loss: 183 | resized_reconst_img = self.resize(reconst_img) 184 | resized_compltd_img = self.resize(compltd_img) 185 | resized_real_img = self.resize(real_image) 186 | loss_G_VGG, loss_G_VGGStyle = self.criterionVGG(resized_reconst_img, resized_real_img) 187 | loss_G_VGG2, loss_G_VGGStyle2 = self.criterionVGG(resized_compltd_img, resized_real_img) 188 | loss_G_VGG += loss_G_VGG2 189 | loss_G_VGGStyle += loss_G_VGGStyle2 190 | loss_G_VGG *= self.opt.lambda_vgg 191 | loss_G_VGGStyle *= self.opt.lambda_style 192 | 193 | # only return the fake image if necessary 194 | return [ self.loss_filter( loss_G_GAN, loss_G_COARSE_L1, loss_G_GAN_L1, loss_G_VGGStyle, loss_G_GAN_Feat, loss_G_VGG, loss_G_TV, loss_D_real, loss_D_fake ), None if not infer else compltd_img, reconst_img, lr_x ] 195 | 196 | def inference(self, masked_img, mask): 197 | # Encode inputs 198 | input_msked_img, input_msk = self.encode_input_test(Variable(masked_img), Variable(mask), infer=True) 199 | 200 | # Fake Generation 201 | if torch.__version__.startswith('0.4'): 202 | with torch.no_grad(): 203 | compltd_img, reconst_img, lr_x, lr_img = self.netG.forward(input_msked_img, input_msk) 204 | else: 205 | compltd_img, reconst_img, lr_x, lr_img = self.netG.forward(input_msked_img, input_msk) 206 | return compltd_img, reconst_img, lr_x 207 | 208 | def save(self, which_epoch): 209 | self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) 210 | self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) 211 | 212 | def update_fixed_params(self): 213 | # after fixing the global generator for a # of iterations, also start finetuning it 214 | params = list(self.netG.parameters()) 215 | self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 216 | print('----------------- Now also finetuning global generator -----------------') 217 | 218 | def update_learning_rate(self): 219 | lrd = self.opt.lr / self.opt.niter_decay 220 | lr = self.old_lr - lrd 221 | for param_group in self.optimizer_D.param_groups: 222 | param_group['lr'] = lr * 4.0 223 | for param_group in self.optimizer_G.param_groups: 224 | param_group['lr'] = lr 225 | print('update learning rate: %f -> %f' % (self.old_lr, lr)) 226 | self.old_lr = lr 227 | 228 | class InferenceModel(OurModel): 229 | def forward(self, inp1, inp2): 230 | masked_img = inp1 231 | mask = inp2 232 | return self.inference(masked_img, mask) 233 | 234 | -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/train_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/options/__pycache__/train_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | #################################################################################### 2 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 3 | # Licensed under the CC BY-NC-SA 4.0 license 4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 5 | #################################################################################### 6 | import argparse 7 | import os 8 | import torch 9 | from util import util 10 | 11 | 12 | class BaseOptions(): 13 | def __init__(self): 14 | self.parser = argparse.ArgumentParser(description='Inpainting') 15 | self.initialized = False 16 | 17 | def initialize(self): 18 | # experiment specifics 19 | self.parser.add_argument('--name', type=str, default='experiment', help='name of the experiment. It decides where to store samples and models') 20 | self.parser.add_argument('--gpu_ids', type=str, default='0,1', help='gpu ids: e.g. 0 0,1,2 0,2. 
use -1 for cpu') 21 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 22 | self.parser.add_argument('--model', type=str, default='Ours', help='which model to use') 23 | self.parser.add_argument('--norm', type=str, default='instance', help='instance or batch normalization') 24 | 25 | 26 | # input/output sizes 27 | self.parser.add_argument('--batchSize', type=int, default=4, help='input batch size') 28 | self.parser.add_argument('--loadSize', type=int, default=768, help='scale images to this size') 29 | self.parser.add_argument('--fineSize', type=int, default=256, help='image size to the model') 30 | self.parser.add_argument('--input_nc', type=int, default=4, help='# of input image channels') 31 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 32 | 33 | # for setting inputs 34 | self.parser.add_argument('--dataroot', type=str, default='./datasets/ade20k/') 35 | self.parser.add_argument('--resize_or_crop', type=str, default='standard', help='scaling and/or cropping of images at load time') 36 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, no shuffle') 37 | self.parser.add_argument('--no_flip', action='store_true', help='if true, no flip') 38 | self.parser.add_argument('--nThreads', type=int, default=6, help='# of threads for loading data') 39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='max # of images allowed per dataset') 40 | 41 | # for displays 42 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 43 | 44 | # for generator 45 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG') 46 | self.parser.add_argument('--ngf', type=int, default=64, help='# of generator filters in the first conv layer') 47 | self.parser.add_argument('--n_downsample_global', type=int, default=5, help='# of downsampling layers in netG') 48 | self.parser.add_argument('--n_blocks_global', type=int, default=6, help='# of resnet blocks in the global generator network') 49 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='# of resnet blocks in the local enhancer network') 50 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='# of local enhancers to use') 51 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='# of epochs that we only train the outmost local enhancer') 52 | 53 | self.initialized = True 54 | 55 | def parse(self, save=True): 56 | if not self.initialized: 57 | self.initialize() 58 | self.opt = self.parser.parse_args() 59 | self.opt.isTrain = self.isTrain 60 | 61 | str_ids = self.opt.gpu_ids.split(',') 62 | self.opt.gpu_ids = [] 63 | for str_id in str_ids: 64 | id = int(str_id) 65 | if id >= 0: 66 | self.opt.gpu_ids.append(id) 67 | 68 | # set gpu ids 69 | if len(self.opt.gpu_ids) > 0: 70 | torch.cuda.set_device(self.opt.gpu_ids[0]) 71 | 72 | args = vars(self.opt) 73 | 74 | # print options 75 | print('-------------------- Options --------------------') 76 | for k, v in sorted(args.items()): 77 | print('%s: %s' % (str(k), str(v))) 78 | print('---------------------- End ----------------------') 79 | 80 | # save the options to disk 81 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 82 | util.mkdirs(expr_dir) 83 | 84 | if save and not self.opt.continue_train: 85 | file_name = os.path.join(expr_dir, 'opt.txt') 86 | with open(file_name, 'wt') as 
opt_file: 87 | opt_file.write('-------------------- Options --------------------\n') 88 | for k, v in sorted(args.items()): 89 | opt_file.write('%s: %s\n' % (str(k), str(v))) 90 | opt_file.write('---------------------- End ----------------------\n') 91 | 92 | return self.opt 93 | 94 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | 7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here') 8 | self.parser.add_argument('--phase', type=str, default='test', help='train or test') 9 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load?') 10 | 11 | self.isTrain = False 12 | 13 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | 7 | # for displays 8 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 9 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 10 | self.parser.add_argument('--save_latest_freq', type=int, default=3000, help='frequency of saving the latest results') 11 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 12 | self.parser.add_argument('--no_html', action='store_true', help='if true, do not save intermediate training results to web') 13 | self.parser.add_argument('--tf_log', action='store_true', help='if true, use tensorboard logging') 14 | 15 | # for training 16 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 17 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 18 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load?') 19 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, or test') 20 | self.parser.add_argument('--niter', type=int, default=10, help='# of iter at starting learning rate') 21 | self.parser.add_argument('--niter_decay', type=int, default=90, help='# of iter to linearly decay learning rate to zero') 22 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 23 | self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 24 | 25 | # for discriminators 26 | self.parser.add_argument('--num_D', type=int, default=2, help='# of discriminators to use') 27 | self.parser.add_argument('--n_layers_D', type=int, default=2, help='only used if which_model_netD==n_layers') 28 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discriminator filters in first conv layer') 29 | self.parser.add_argument('--lambda_feat', type=float, default=0.01, help='weight for feature matching loss') 30 | self.parser.add_argument('--lambda_vgg', type=float, default=0.05, help='weight for vgg feature matching loss') 
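# Note: the lambda_* weights in this block scale the corresponding generator loss terms
# computed in models/our_model.py (adversarial, L1 on the coarse and refined outputs,
# style, TV, VGG and feature-matching losses).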
31 | self.parser.add_argument('--lambda_l1', type=float, default=5.0, help='weight for l1 loss') 32 | self.parser.add_argument('--lambda_tv', type=float, default=0.1, help='weight for tv loss') 33 | self.parser.add_argument('--lambda_style', type=float, default=80.0, help='weight for style loss') 34 | self.parser.add_argument('--lambda_gan', type=float, default=0.001, help='weight for g gan loss') 35 | self.parser.add_argument('--no_ganFeat_loss', action='store_false', help='if specified, do *not* use discriminator feature matching loss') 36 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 37 | # use least square GAN 38 | self.parser.add_argument('--no_lsgan', action='store_true', default=True, help='do *not* use least square GAN, if false, use vanilla GAN') 39 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 40 | 41 | self.isTrain = True 42 | 43 | -------------------------------------------------------------------------------- /results/test/AIM_IC_t1_validation_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/results/test/AIM_IC_t1_validation_0.png -------------------------------------------------------------------------------- /test_ensemble.py: -------------------------------------------------------------------------------- 1 | from util.visualizer import Visualizer 2 | import util.util as util 3 | from data.data_loader import CreateDataLoader 4 | from options.test_options import TestOptions 5 | import torchvision.transforms as transforms 6 | import time 7 | import os 8 | import numpy as np 9 | from PIL import Image 10 | import torch 11 | from torch.autograd import Variable 12 | from collections import OrderedDict 13 | 14 | from models.models import create_model 15 | from util import html 16 | 17 | import copy 18 | 19 | opt = TestOptions().parse(save=False) 20 | opt.nThreads = 1 21 | opt.batchSize = 1 22 | opt.serial_batches = True 23 | opt.no_flip = True 24 | 25 | data_loader = CreateDataLoader(opt) 26 | dataset = data_loader.load_data() 27 | dataset_size = len(data_loader) 28 | 29 | def __make_power_2(img, base=256, method=Image.BICUBIC): 30 | ow, oh = img.size 31 | 32 | h = int(round(oh / base) * base) 33 | w = int(round(ow / base) * base) 34 | 35 | if h == 0: 36 | h = base 37 | if w == 0: 38 | w = base 39 | 40 | if (h == oh) and (w == ow): 41 | return img 42 | return img.resize((w, h), method) 43 | 44 | model, num_params_G, num_params_D = create_model(opt) 45 | model.eval() 46 | 47 | rlt_dir = os.path.join(opt.results_dir, 'test') 48 | util.mkdirs([rlt_dir]) 49 | 50 | transform_list = [] 51 | transform_list += [transforms.ToTensor()] 52 | mskimg_transform = transforms.Compose(transform_list) 53 | transform_list = [] 54 | transform_list += [transforms.ToTensor()] 55 | msk_transform = transforms.Compose(transform_list) 56 | 57 | start_time = time.time() 58 | for i, data in enumerate(dataset): 59 | with torch.no_grad(): 60 | msk_img_path = data['path_mskimg'][0] 61 | filename = os.path.basename(msk_img_path) 62 | msk_path = data['path_msk'][0] 63 | 64 | oimg = Image.open(msk_img_path).convert('RGB') 65 | omsk = Image.open(msk_path).convert('L') 66 | ow, oh = oimg.size 67 | 68 | ### 69 | resized_img = __make_power_2(oimg) 70 | resized_msk = __make_power_2(omsk, method=Image.BILINEAR) 71 | rw, rh = 
resized_img.size 72 | 73 | hori_ver = rw // 256 74 | vert_ver = rh // 256 75 | 76 | tmp_img = oimg.resize((256, 256), Image.BICUBIC) 77 | tmp_msk = omsk.resize((256, 256), Image.BICUBIC) 78 | 79 | np_tmp_img = np.array(tmp_img, np.uint8) 80 | np_tmp_msk = np.array(tmp_msk, np.uint8) 81 | 82 | np_resized_img = np.array(resized_img, np.uint8) 83 | np_resized_msk = np.array(resized_msk, np.uint8) 84 | np_resized_msk = np_resized_msk > 0 85 | np_resized_img[:,:,0] = np_resized_img[:,:,0] * (1 - np_resized_msk) + 255 * np_resized_msk 86 | np_resized_img[:,:,1] = np_resized_img[:,:,1] * (1 - np_resized_msk) + 255 * np_resized_msk 87 | np_resized_img[:,:,2] = np_resized_img[:,:,2] * (1 - np_resized_msk) + 255 * np_resized_msk 88 | np_resized_msk = np_resized_msk * 255 89 | 90 | img_arr = [] 91 | msk_arr = [] 92 | 93 | ### 94 | for hv in range(hori_ver): 95 | for vv in range(vert_ver): 96 | for i in range(256): 97 | for j in range(256): 98 | np_tmp_img[i, j, 0] = np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 0] 99 | np_tmp_img[i, j, 1] = np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 1] 100 | np_tmp_img[i, j, 2] = np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 2] 101 | np_tmp_msk[i, j] = np_resized_msk[vv + vert_ver*j, hv + hori_ver*i] 102 | img_arr.append(np.copy(np_tmp_img)) 103 | msk_arr.append(np.copy(np_tmp_msk)) 104 | 105 | ### 106 | compltd_arr = [] 107 | for i in range(len(img_arr)): 108 | img = Image.fromarray(img_arr[i]) 109 | msk = Image.fromarray(msk_arr[i]) 110 | 111 | img_90 = img.rotate(90) 112 | msk_90 = msk.rotate(90) 113 | img_180 = img.rotate(180) 114 | msk_180 = msk.rotate(180) 115 | img_270 = img.rotate(270) 116 | msk_270 = msk.rotate(270) 117 | img_flp = img.transpose(method=Image.FLIP_LEFT_RIGHT) 118 | msk_flp = msk.transpose(method=Image.FLIP_LEFT_RIGHT) 119 | 120 | compltd_img, reconst_img, lr_x = model(mskimg_transform(img).unsqueeze(0), msk_transform(msk).unsqueeze(0)) 121 | compltd_img_90, reconst_img_90, lr_x_90 = model(mskimg_transform(img_90).unsqueeze(0), msk_transform(msk_90).unsqueeze(0)) 122 | compltd_img_180, reconst_img_180, lr_x_180 = model(mskimg_transform(img_180).unsqueeze(0), msk_transform(msk_180).unsqueeze(0)) 123 | compltd_img_270, reconst_img_270, lr_x_270 = model(mskimg_transform(img_270).unsqueeze(0), msk_transform(msk_270).unsqueeze(0)) 124 | compltd_img_flp, reconst_img_flp, lr_x_flp = model(mskimg_transform(img_flp).unsqueeze(0), msk_transform(msk_flp).unsqueeze(0)) 125 | np_compltd_img = util.tensor2im(reconst_img.data[0], normalize=False) 126 | np_compltd_img_90 = util.tensor2im(reconst_img_90.data[0], normalize=False) 127 | np_compltd_img_180 = util.tensor2im(reconst_img_180.data[0], normalize=False) 128 | np_compltd_img_270 = util.tensor2im(reconst_img_270.data[0], normalize=False) 129 | np_compltd_img_flp = util.tensor2im(reconst_img_flp.data[0], normalize=False) 130 | 131 | new_img_90 = Image.fromarray(np_compltd_img_90) 132 | new_img_90 = new_img_90.rotate(270) 133 | np_new_img_90 = np.array(new_img_90, np.float) 134 | 135 | new_img_180 = Image.fromarray(np_compltd_img_180) 136 | new_img_180 = new_img_180.rotate(180) 137 | np_new_img_180 = np.array(new_img_180, np.float) 138 | 139 | new_img_270 = Image.fromarray(np_compltd_img_270) 140 | new_img_270 = new_img_270.rotate(90) 141 | np_new_img_270 = np.array(new_img_270, np.float) 142 | 143 | new_img_flp = Image.fromarray(np_compltd_img_flp) 144 | new_img_flp = new_img_flp.transpose(method=Image.FLIP_LEFT_RIGHT) 145 | np_new_img_flp = np.array(new_img_flp, np.float) 146 | 147 | 
np_compltd_img = (np_compltd_img + np_new_img_90 + np_new_img_180 + np_new_img_270 + np_new_img_flp) / 5.0 148 | np_compltd_img = np.array(np.round(np_compltd_img), np.uint8) 149 | final_img = Image.fromarray(np_compltd_img, mode="RGB") 150 | np_compltd_img = np.array(final_img, np.uint8) 151 | 152 | compltd_arr.append(np.copy(np_compltd_img)) 153 | 154 | ### 155 | ver_idx = 0 156 | for hv in range(hori_ver): 157 | for vv in range(vert_ver): 158 | #np_tmp_img = compltd_arr[ver_idx] 159 | for i in range(256): 160 | for j in range(256): 161 | np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 0] = compltd_arr[ver_idx][i, j, 0] 162 | np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 1] = compltd_arr[ver_idx][i, j, 1] 163 | np_resized_img[vv + vert_ver*j, hv + hori_ver*i, 2] = compltd_arr[ver_idx][i, j, 2] 164 | ver_idx += 1 165 | 166 | ### 167 | new_compltd_img = Image.fromarray(np_resized_img) 168 | new_compltd_img = new_compltd_img.resize((ow, oh), Image.BICUBIC) 169 | new_compltd_img = new_compltd_img.resize((int(ow*0.5), int(oh*0.5)), Image.BICUBIC) 170 | new_compltd_img = new_compltd_img.resize((ow, oh), Image.BICUBIC) 171 | np_new_compltd_img = np.array(new_compltd_img) 172 | np_oimg = np.array(oimg) 173 | np_omsk = np.array(omsk) 174 | 175 | np_new_compltd_img[:, :, 0] = np_new_compltd_img[:, :, 0] * (np_omsk / 255.0) + ((255.0 - np_omsk) / 255.0) * np_oimg[:, :, 0] 176 | np_new_compltd_img[:, :, 1] = np_new_compltd_img[:, :, 1] * (np_omsk / 255.0) + ((255.0 - np_omsk) / 255.0) * np_oimg[:, :, 1] 177 | np_new_compltd_img[:, :, 2] = np_new_compltd_img[:, :, 2] * (np_omsk / 255.0) + ((255.0 - np_omsk) / 255.0) * np_oimg[:, :, 2] 178 | 179 | newfilename = filename.replace("_with_holes", "") 180 | compltd_path = os.path.join(rlt_dir, newfilename) 181 | util.save_image(np_new_compltd_img, compltd_path) 182 | 183 | print(compltd_path) 184 | 185 | end_time = time.time() - start_time 186 | print('Avg Time Taken: %.3f sec' % (end_time / dataset_size)) 187 | 188 | print('done') 189 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from util.visualizer import Visualizer 2 | import util.util as util 3 | from data.data_loader import CreateDataLoader 4 | from options.train_options import TrainOptions 5 | import time 6 | import os 7 | import numpy as np 8 | import torch 9 | from torch.autograd import Variable 10 | from collections import OrderedDict 11 | 12 | import fractions 13 | def lcm(a, b): 14 | return abs(a*b) / fractions.gcd(a, b) if a and b else 0 15 | 16 | from models.models import create_model 17 | 18 | opt = TrainOptions().parse() 19 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 20 | param_path = os.path.join(opt.checkpoints_dir, opt.name, 'param.txt') 21 | 22 | # continue training or start from scratch 23 | if opt.continue_train: 24 | try: 25 | start_epoch, epoch_iter = np.loadtxt( 26 | iter_path, delimiter=',', dtype=int) 27 | except: 28 | start_epoch, epoch_iter = 1, 0 29 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 30 | else: 31 | start_epoch, epoch_iter = 1, 0 32 | 33 | opt.print_freq = lcm(opt.print_freq, opt.batchSize) 34 | 35 | 36 | ######################################## 37 | # load dataset 38 | ######################################## 39 | data_loader = CreateDataLoader(opt) 40 | dataset = data_loader.load_data() 41 | dataset_size = len(data_loader) 42 | print('# of training images = %d' % 
dataset_size) 43 | 44 | ######################################## 45 | # define model and optimizer 46 | ######################################## 47 | # define own model 48 | model, num_params_G, num_params_D = create_model(opt) 49 | optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D 50 | # output the model size (i.e. num_params_G and num_params_D to txt file) 51 | np.savetxt(param_path, (num_params_G, num_params_D), delimiter=',', fmt='%d') 52 | 53 | ######################################## 54 | # define visualizer 55 | ######################################## 56 | visualizer = Visualizer(opt) 57 | 58 | ######################################## 59 | # define train and/val loop 60 | ######################################## 61 | total_steps = (start_epoch - 1) * dataset_size + epoch_iter 62 | 63 | display_delta = total_steps % opt.display_freq 64 | print_delta = total_steps % opt.print_freq 65 | save_delta = total_steps % opt.save_latest_freq 66 | 67 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 68 | model.train() 69 | # by default, start_epoch, epoch_iter = 1, 0 70 | epoch_start_time = time.time() 71 | if epoch != start_epoch: 72 | epoch_iter = epoch_iter % dataset_size 73 | for i, data in enumerate(dataset, start=epoch_iter): 74 | if total_steps % opt.print_freq == print_delta: 75 | iter_start_time = time.time() 76 | total_steps += opt.batchSize 77 | epoch_iter += opt.batchSize 78 | 79 | # whether to collect output images 80 | save_fake = total_steps % opt.display_freq == display_delta 81 | 82 | ############### Forward pass ############### 83 | # get pred and calculate loss 84 | B, C, H, W = data['masked_image'].shape 85 | msk_img = data['masked_image'].view(-1, 3, H, W) 86 | msk = data['mask'].view(-1, 1, H, W) 87 | real_img = data['real_image'].view(-1, 3, H, W) 88 | 89 | losses, compltd_img, reconst_img, lr_x = model(Variable(msk_img), Variable(msk), 90 | Variable(real_img), infer=save_fake) 91 | 92 | # sum per device losses 93 | losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses] 94 | loss_dict = dict(zip(model.module.loss_names, losses)) 95 | 96 | # calculate final loss scalar 97 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 98 | loss_G = loss_dict['G_GAN'] + loss_dict['G_COARSE_L1'] + loss_dict['G_OUT_L1'] + loss_dict['G_TV'] + loss_dict['G_STYLE'] + loss_dict.get('G_VGG', 0) 99 | 100 | ############### Backward pass ############### 101 | # update Generator parameters 102 | optimizer_G.zero_grad() 103 | loss_G.backward() 104 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3, norm_type=2) 105 | optimizer_G.step() 106 | 107 | # update Discriminator parameters 108 | optimizer_D.zero_grad() 109 | loss_D.backward() 110 | optimizer_D.step() 111 | 112 | ############### Display results and losses ############### 113 | # print out losses 114 | if total_steps % opt.print_freq == print_delta: 115 | errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()} 116 | t = (time.time() - iter_start_time) / opt.print_freq 117 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 118 | visualizer.plot_current_errors(errors, total_steps) 119 | 120 | # display completed images 121 | if save_fake: 122 | visuals = OrderedDict([('real_image', util.tensor2im(real_img[0], normalize=False)), 123 | ('masked_image_1', util.tensor2im(msk_img[0], normalize=False)), 124 | ('coarse_reconst_1', util.tensor2im(lr_x.data[0], normalize=False)), 125 | ('output_image_1', 
util.tensor2im(reconst_img.data[0], normalize=False)), 126 | ('completed_image_1', util.tensor2im(compltd_img.data[0], normalize=False)), 127 | ('masked_image_2', util.tensor2im(msk_img[1], normalize=False)), 128 | ('coarse_reconst_2', util.tensor2im(lr_x.data[1], normalize=False)), 129 | ('output_image_2', util.tensor2im(reconst_img.data[1], normalize=False)), 130 | ('completed_image_2', util.tensor2im(compltd_img.data[1], normalize=False)), 131 | ('masked_image_3', util.tensor2im(msk_img[2], normalize=False)), 132 | ('coarse_reconst_3', util.tensor2im(lr_x.data[2], normalize=False)), 133 | ('output_image_3', util.tensor2im(reconst_img.data[2], normalize=False)), 134 | ('completed_image_3', util.tensor2im(compltd_img.data[2], normalize=False))]) 135 | visualizer.display_current_results(visuals, epoch, total_steps) 136 | 137 | # save the latest model 138 | if total_steps % opt.save_latest_freq == save_delta: 139 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 140 | model.module.save('latest') 141 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 142 | 143 | if epoch_iter >= dataset_size: 144 | break 145 | 146 | # end of epoch 147 | iter_end_time=time.time() 148 | print('End of epoch %d / %d \t Time Taken: %d sec' % 149 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 150 | 151 | # save model for this epoch 152 | if epoch % opt.save_epoch_freq == 0: 153 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 154 | model.module.save('latest') 155 | model.module.save(epoch) 156 | np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d') 157 | 158 | # instead of only training the local enhancer, train the entire network after certain iterations 159 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 160 | model.module.update_fixed_params() 161 | 162 | # linearly decay learning rate after certain iters 163 | if epoch > opt.niter: 164 | print('update learning rate') 165 | model.module.update_learning_rate() 166 | -------------------------------------------------------------------------------- /util/__pycache__/html.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/html.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/html.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/html.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/visualizer.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlct1/gin/cfcef5dd13f6431e296c862cf4a8a4efa827b6c2/util/__pycache__/visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | 16 | self.doc = dominate.document(title=title) 17 | if refresh > 0: 18 | with self.doc.head: 19 | meta(http_equiv="refresh", content=str(refresh)) 20 | 21 | def get_image_dir(self): 22 | return self.img_dir 23 | 24 | def add_header(self, str): 25 | with self.doc: 26 | h3(str) 27 | 28 | def add_table(self, border=1): 29 | self.t = table(border=border, style="table-layout: fixed;") 30 | self.doc.add(self.t) 31 | 32 | def add_images(self, ims, txts, links, width=512): 33 | self.add_table() 34 | with self.t: 35 | with tr(): 36 | for im, txt, link in zip(ims, txts, links): 37 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 38 | with p(): 39 | with a(href=os.path.join('images', link)): 40 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 41 | br() 42 | p(txt) 43 | 44 | def save(self): 45 | html_file = '%s/index.html' % self.web_dir 46 | f = open(html_file, 'wt') 47 | f.write(self.doc.render()) 48 | f.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | html = HTML('web/', 'test_html') 53 | html.add_header('hello world') 54 | 55 | ims = [] 56 | txts = [] 57 | links = [] 58 | for n in range(4): 59 | ims.append('image_%d.jpg' % n) 60 | txts.append('text_%d.jpg' % n) 61 | links.append('image_%d.jpg' % n) 62 | html.add_images(ims, txts, links) 63 | html.save() 64 | 65 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | # no in use at the stage 6 | class ImagePool(): 7 | def __init__(self, pool_size): 8 | self.pool_size = pool_size 9 | if self.pool_size > 0: 10 | self.num_imgs = 0 11 | self.images = [] 12 | 13 | def query(self, images): 14 | if self.pool_size == 0: 15 | return images 16 | return_images = [] 17 | for image in images.data: 18 | image = torch.unsqueeze(image, 0) 19 | if self.num_imgs < self.pool_size: 20 | self.num_imgs = self.num_imgs + 1 21 | self.images.append(image) 22 | return_images.append(image) 23 | else: 24 | p = random.uniform(0, 1) 25 | if p > 0.5: 26 | random_id = random.randint(0, self.pool_size - 1) 27 | tmp = self.images[random_id].clone() 28 | self.images[random_id] = image 29 | return_images.append(tmp) 30 | else: 31 | return_images.append(image) 32 | return_images = Variable(torch.cat(return_images, 0)) 33 | return 
return_images 34 | 35 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import numpy as np 5 | import os 6 | 7 | def mkdir(path): 8 | if not os.path.exists(path): 9 | os.makedirs(path) 10 | 11 | def mkdirs(paths): 12 | if isinstance(paths, list) and not isinstance(paths, str): 13 | for path in paths: 14 | mkdir(path) 15 | else: 16 | mkdir(paths) 17 | 18 | def save_image(image_numpy, image_path): 19 | image_pil = Image.fromarray(image_numpy) 20 | image_pil.save(image_path) 21 | 22 | 23 | # converts a tensor into a numpy array 24 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 25 | if isinstance(image_tensor, list): 26 | image_numpy = [] 27 | for i in range(len(image_tensor)): 28 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 29 | return image_numpy 30 | image_numpy = image_tensor.cpu().float().numpy() 31 | if normalize: 32 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 33 | else: 34 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 35 | image_numpy = np.clip(image_numpy, 0, 255) 36 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 37 | image_numpy = image_numpy[:, :, 0] 38 | return image_numpy.astype(imtype) 39 | 40 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | import scipy.misc # scipy==1.1.0 5 | from . import html 6 | from . import util 7 | import ntpath 8 | 9 | try: 10 | from StringIO import StringIO # Python 2.7 11 | except ImportError: 12 | from io import BytesIO # Python 3.x 13 | 14 | class Visualizer(): 15 | def __init__(self, opt): 16 | self.opt = opt 17 | # tf_log use tensorboard logging 18 | self.tf_log = opt.tf_log 19 | # intermediate training results to web 20 | self.use_html = opt.isTrain and not opt.no_html 21 | self.win_size = opt.display_winsize 22 | self.name = opt.name 23 | 24 | if self.tf_log: 25 | import tensorflow as tf 26 | self.tf = tf 27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 28 | self.writer = tf.summary.FileWriter(self.log_dir) 29 | 30 | if self.use_html: 31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 32 | self.img_dir = os.path.join(self.web_dir, 'images') 33 | print('Create web directory %s ...' 
% self.web_dir) 34 | util.mkdirs([self.web_dir, self.img_dir]) 35 | 36 | # a txt file to record the losses 37 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 38 | with open(self.log_name, "a") as log_file: 39 | now = time.strftime("%c") 40 | log_file.write('==================== Training loss (%s) ====================\n' % now) 41 | 42 | 43 | # dictionary of images to display or save 44 | def display_current_results(self, visuals, epoch, step): 45 | if self.tf_log: # show images in tensorboard output 46 | img_summaries = [] 47 | for label, image_numpy in visuals.items(): 48 | # Write the image to a string 49 | try: 50 | s = StringIO() 51 | except: 52 | s = BytesIO() 53 | scipy.misc.toimage(image_numpy).save(s, format="png") 54 | # Create an Image object 55 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 56 | # Create a Summary value 57 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 58 | 59 | # Create and write Summary 60 | summary = self.tf.Summary(value=img_summaries) 61 | self.writer.add_summary(summary, step) 62 | 63 | if self.use_html: # save images to a html file 64 | for label, image_numpy in visuals.items(): 65 | if isinstance(image_numpy, list): 66 | for i in range(len(image_numpy)): 67 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.png' % (epoch, label, i)) 68 | util.save_image(image_numpy[i], img_path) 69 | else: 70 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 71 | util.save_image(image_numpy, img_path) 72 | 73 | # update the website 74 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) 75 | for n in range(epoch, 0, -1): 76 | webpage.add_header('epoch [%d]' % n) 77 | ims = [] 78 | txts = [] 79 | links = [] 80 | 81 | for label, image_numpy in visuals.items(): 82 | if isinstance(image_numpy, list): 83 | for i in range(len(image_numpy)): 84 | img_path = 'epoch%.3d_%s_%d.png' % (n, label, i) 85 | ims.append(img_path) 86 | txts.append(label+str(i)) 87 | links.append(img_path) 88 | else: 89 | img_path = 'epoch%.3d_%s.png' % (n, label) 90 | ims.append(img_path) 91 | txts.append(label) 92 | links.append(img_path) 93 | 94 | if len(ims) < 10: 95 | webpage.add_images(ims, txts, links, width=self.win_size) 96 | else: 97 | num = int(round(len(ims)/2.0)) 98 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 99 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 100 | 101 | webpage.save() 102 | 103 | # dictionary of loss labels and values 104 | def plot_current_errors(self, errors, step): 105 | if self.tf_log: 106 | for tag, value in errors.items(): 107 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 108 | self.writer.add_summary(summary, step) 109 | 110 | # print loss labels and values 111 | def print_current_errors(self, epoch, i, errors, t): 112 | message = '(epoch: %d, iters: %d, time: %.4f) ' % (epoch, i, t) 113 | for k, v in errors.items(): 114 | if v != 0: 115 | message += '%s: %.4f ' % (k, v) 116 | 117 | print(message) 118 | with open(self.log_name, "a") as log_file: 119 | log_file.write('%s\n' % message) 120 | 121 | # save images to disk 122 | def save_images(self, webpage, visuals, image_path): 123 | image_dir = webpage.get_image_dir() 124 | short_path = ntpath.basename(image_path[0]) 125 | name = os.path.splitext(short_path)[0] 126 | 127 | webpage.add_header(name) 128 | ims = [] 129 | txts 
= [] 130 | links = [] 131 | 132 | for label, image_numpy in visuals.items(): 133 | image_name = '%s_%s.png' % (name, label) 134 | save_path = os.path.join(image_dir, image_name) 135 | util.save_image(image_numpy, save_path) 136 | 137 | ims.append(image_name) 138 | txts.append(label) 139 | links.append(image_name) 140 | 141 | webpage.add_images(ims, txts, links, width=self.win_size) 142 | 143 | --------------------------------------------------------------------------------
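Editor's note: train.py above alternates one generator step and one discriminator step per batch. The generator losses (GAN, coarse and output L1, TV, style, and an optional VGG term) are summed and back-propagated with gradient clipping, and the discriminator loss is half the sum of its real and fake terms. The sketch below shows only that optimizer choreography with toy single-layer networks and an LSGAN-style placeholder objective; netG, netD, and the criterion are illustrative stand-ins, not the repository's networks, and the repository computes its losses inside the model's forward pass rather than in the loop body.

import torch
import torch.nn as nn

# toy stand-ins for the generator and discriminator
netG = nn.Conv2d(3, 3, 3, padding=1)
netD = nn.Conv2d(3, 1, 3, padding=1)
optimizer_G = torch.optim.Adam(netG.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=2e-4)
criterion = nn.MSELoss()  # LSGAN-style placeholder objective

real = torch.rand(2, 3, 64, 64)  # dummy batch standing in for real training images

for step in range(3):
    fake = netG(real)

    # generator update: fool the discriminator (plus reconstruction terms in train.py)
    pred_fake = netD(fake)
    loss_G = criterion(pred_fake, torch.ones_like(pred_fake))
    optimizer_G.zero_grad()
    loss_G.backward()
    torch.nn.utils.clip_grad_norm_(netG.parameters(), max_norm=3, norm_type=2)
    optimizer_G.step()

    # discriminator update: 0.5 * (fake loss + real loss), with fakes detached
    pred_fake = netD(fake.detach())
    pred_real = netD(real)
    loss_D = 0.5 * (criterion(pred_fake, torch.zeros_like(pred_fake))
                    + criterion(pred_real, torch.ones_like(pred_real)))
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()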
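Editor's note: util.tensor2im and util.save_image above are the two helpers every script in this repository uses to turn a network output back into an image on disk. With normalize=False the tensor is taken to lie in [0, 1] and is transposed from CHW to HWC and scaled to [0, 255]; with normalize=True it is taken to lie in [-1, 1]. A minimal usage sketch, run from the repository root, with a random tensor standing in for a generator output and an arbitrary output path:

import torch
import numpy as np
import util.util as util

# stand-in for a completed image produced by the generator: 3 x H x W, values in [0, 1]
fake_output = torch.rand(3, 256, 256)

# CHW float tensor in [0, 1] -> HWC uint8 array in [0, 255]
np_img = util.tensor2im(fake_output, normalize=False)
assert np_img.shape == (256, 256, 3) and np_img.dtype == np.uint8

# write the array to disk as a PNG via PIL
util.mkdir('results')
util.save_image(np_img, 'results/example_output.png')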