├── .gitignore
├── LICENSE
├── README.md
├── datasets.py
├── datasets
│   ├── .gitignore
│   └── make_dataset.py
├── images
│   ├── test_0.jpg
│   ├── test_0_out.jpg
│   ├── test_1.jpg
│   ├── test_1_out.jpg
│   ├── test_2.jpg
│   ├── test_2_out.jpg
│   ├── test_3.jpg
│   ├── test_3_out.jpg
│   ├── test_4.jpg
│   └── test_4_out.jpg
├── layers.py
├── losses.py
├── models.py
├── predict.py
├── requirements.txt
├── results
│   └── .gitignore
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .python-version
3 | .vscode/
4 | .DS_Store
5 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Otenim
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GLCIC-PyTorch
2 | 
3 | This repository provides a PyTorch-based implementation of [GLCIC](http://hi.cs.waseda.ac.jp/~iizuka/projects/completion/data/completion_sig2017.pdf) introduced by Iizuka et al.
4 | 
5 | ![glcic](https://i.imgur.com/KY26J85.png)
6 | ![result_1](https://i.imgur.com/LTYCUup.jpg)
7 | ![result_2](https://i.imgur.com/RR7MhNS.jpg)
8 | ![result_3](https://i.imgur.com/xOrTR4n.jpg)
9 | 
10 | - [GLCIC-PyTorch](#glcic-pytorch)
11 | - [Requirements](#requirements)
12 | - [DEMO (Inference)](#demo-inference)
13 | - [1. Download pretrained generator model and training config file.](#1-download-pretrained-generator-model-and-training-config-file)
14 | - [2. Inference](#2-inference)
15 | - [DEMO (Training)](#demo-training)
16 | - [1. Download the dataset](#1-download-the-dataset)
17 | - [2. Training](#2-training)
18 | - [How to train with a custom dataset?](#how-to-train-with-a-custom-dataset)
19 | - [1. Prepare dataset](#1-prepare-dataset)
20 | - [2. Training](#2-training-1)
21 | - [How to perform inference with a custom dataset?](#how-to-perform-inference-with-a-custom-dataset)
22 | 
23 | ## Requirements
24 | 
25 | Our scripts were tested in the following environment.
26 | 
27 | * Python: 3.7.6
28 | * torch: 1.9.0 (cuda 11.1)
29 | * torchvision: 0.10.0 (cuda 11.1)
30 | * tqdm: 4.61.1
31 | * Pillow: 8.2.0
32 | * opencv-python: 4.5.2.54
33 | * numpy: 1.19.2
34 | * GPU: GeForce GTX 1080Ti (12GB RAM) x 4
35 | 
36 | You can install all the requirements by executing the command below.
37 | 
38 | ```sh
39 | # in <path_to_this_repo>
40 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
41 | ```
42 | 
43 | ## DEMO (Inference)
44 | 
45 | ### 1. Download pretrained generator model and training config file.
46 | * [Required] Pretrained generator model (Completion Network): [download (google drive)](https://drive.google.com/file/d/1hsi1Fy0ITiZYTsJ_De-nAVUciuJ8Bql9/view?usp=sharing)
47 | * [Optional] Pretrained discriminator model (Context Discriminator): [download (google drive)](https://drive.google.com/file/d/1_jRuqirwOuiJCg1H73LSq1HuwyG2GON4/view?usp=sharing)
48 | * [Required] Training config file: [download (google drive)](https://drive.google.com/file/d/1yGfQp8U5zcVRYOBxF3-VCZ8TnMAtWBsk/view?usp=sharing)
49 | 
50 | Both the generator and discriminator were trained on the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset.
51 | Note that you don't need the discriminator to perform image completion (the discriminator is needed only during training).
52 | 
53 | ### 2. Inference
54 | 
55 | ```bash
56 | # in <path_to_this_repo>
57 | python predict.py model_cn config.json images/test_2.jpg test_2_out.jpg
58 | ```
59 | 
60 | **Left**: raw input image
61 | **Center**: masked input image
62 | **Right**: inpainted output image
63 | 
64 | 
65 | 
66 | 
67 | 
68 | 
69 | 
70 | ## DEMO (Training)
71 | 
72 | This section introduces how to train a GLCIC model on the CelebA dataset.
73 | 
74 | ### 1. Download the dataset
75 | 
76 | Download the dataset from [this official link](https://drive.google.com/open?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM).
77 | 
78 | Then, execute the following commands.
79 | 
80 | ```bash
81 | # unzip dataset
82 | unzip img_align_celeba.zip
83 | # move dataset
84 | mv img_align_celeba/ <path_to_this_repo>/datasets/
85 | # move into datasets/ directory
86 | cd <path_to_this_repo>/datasets/
87 | # make dataset
88 | python make_dataset.py img_align_celeba/
89 | ```
90 | 
91 | The last command randomly splits the dataset into a training set (80%) and a test set (20%).
92 | 
93 | ### 2. Training
94 | 
95 | Run the following command.
96 | 
97 | ```bash
98 | # in <path_to_this_repo>
99 | python train.py datasets/img_align_celeba results/demo/
100 | ```
101 | 
102 | Training results (model snapshots & test inpainted outputs) are saved in ``results/demo/``.
103 | 
104 | The training procedure consists of the following three phases.
105 | * **Phase 1**: trains only the generator.
106 | * **Phase 2**: trains only the discriminator, while the generator is frozen.
107 | * **Phase 3**: the generator and the discriminator are jointly trained.
108 | 
109 | Default settings of ``train.py`` are based on the original paper **except for batch size**.
110 | If you need to reproduce the paper's results, add ``--data_parallel --bsize 96`` when launching training.
111 | 
112 | ## How to train with a custom dataset?
113 | 
114 | ### 1. Prepare dataset
115 | 
116 | You have to prepare a dataset in the following format.
117 | 
118 | ```
119 | dataset/ # any name is OK
120 | |____train/ # used for training
121 | | |____XXXX.jpg # .png format is also acceptable.
122 | | |____OOOO.jpg
123 | | |____....
124 | |____test/ # used for test
125 | |____oooo.jpg
126 | |____xxxx.jpg
127 | |____....
128 | ```
129 | 
130 | Both `dataset/train` and `dataset/test` are required.
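For reference, `train.py` builds its training and test sets directly from these two sub-directories. Below is a simplified sketch of the corresponding calls (the `datasets/dataset` path is just a placeholder for your own dataset directory):

```python
import os
import torchvision.transforms as transforms
from datasets import ImageDataset  # defined in datasets.py of this repository

data_dir = 'datasets/dataset'  # placeholder: pass the dataset directory you prepared above
trnsfm = transforms.Compose([
    transforms.Resize(160),               # corresponds to --cn_input_size
    transforms.RandomCrop((160, 160)),
    transforms.ToTensor(),
])
train_dset = ImageDataset(os.path.join(data_dir, 'train'), trnsfm)
test_dset = ImageDataset(os.path.join(data_dir, 'test'), trnsfm)
print(len(train_dset), len(test_dset))  # number of images found in each split
```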
131 | 
132 | ### 2. Training
133 | 
134 | ```bash
135 | # in <path_to_this_repo>
136 | # move dataset
137 | mv dataset/ datasets/
138 | # execute training
139 | python train.py datasets/dataset/ results/result/ [--data_parallel (store true)] [--cn_input_size (int)] [--ld_input_size (int)] [--init_model_cn (str)] [--init_model_cd (str)] [--steps_1 (int)] [--steps_2 (int)] [--steps_3 (int)] [--snaperiod_1 (int)] [--snaperiod_2 (int)] [--snaperiod_3 (int)] [--bsize (int)] [--bdivs (int)]
140 | ```
141 | 
142 | 
143 | **Arguments**
144 | * `<dataset_dir>` (required): path to the dataset directory.
145 | * `<result_dir>` (required): path to the result directory.
146 | * `[--data_parallel (store true)]`: when this flag is enabled, models are trained in a data-parallel fashion. If *N* GPUs are available, all *N* GPUs are used during training (default: disabled).
147 | * `[--cn_input_size (int)]`: input size of the generator (completion network). All input images are rescaled so that the minimum side is equal to `cn_input_size`, then randomly cropped to `cn_input_size` x `cn_input_size` (default: 160).
148 | * `[--ld_input_size (int)]`: input size of the local discriminator (default: 96). The input size of the global discriminator is the same as `--cn_input_size`.
149 | * `[--init_model_cn (str)]`: path to a pretrained generator, used as its initial weights (default: None).
150 | * `[--init_model_cd (str)]`: path to a pretrained discriminator, used as its initial weights (default: None).
151 | * `[--steps_1 (int)]`: training steps during phase 1 (default: 90,000).
152 | * `[--steps_2 (int)]`: training steps during phase 2 (default: 10,000).
153 | * `[--steps_3 (int)]`: training steps during phase 3 (default: 400,000).
154 | * `[--snaperiod_1 (int)]`: snapshot period during phase 1 (default: 10,000).
155 | * `[--snaperiod_2 (int)]`: snapshot period during phase 2 (default: 2,000).
156 | * `[--snaperiod_3 (int)]`: snapshot period during phase 3 (default: 10,000).
157 | * `[--max_holes (int)]`: maximum number of holes randomly generated and applied to each input image (default: 1).
158 | * `[--hole_min_w (int)]`: minimum width of a hole (default: 48).
159 | * `[--hole_max_w (int)]`: maximum width of a hole (default: 96).
160 | * `[--hole_min_h (int)]`: minimum height of a hole (default: 48).
161 | * `[--hole_max_h (int)]`: maximum height of a hole (default: 96).
162 | * `[--bsize (int)]`: batch size (default: 16). **bsize >= 96 is strongly recommended**.
163 | * `[--bdivs (int)]`: divides a single training step of batch size *bsize* into *bdivs* steps of batch size *bsize*/*bdivs*. This produces the same training results as `bdivs` = 1 while using less GPU memory, at the cost of speed. This option can be used together with `data_parallel` (default: 1).
164 | 
165 | **Example**: To train a model with batch size 24, with the `data_parallel` option enabled and the other settings left at their defaults, run the following command.
166 | 
167 | ```bash
168 | # in <path_to_this_repo>
169 | python train.py datasets/dataset results/result --data_parallel --bsize 24
170 | ```
171 | 
172 | ## How to perform inference with a custom dataset?
173 | 
174 | Assume you've finished training and the result directory is `<path_to_this_repo>/results/result`.
175 | 
176 | ```bash
177 | # in <path_to_this_repo>
178 | python predict.py results/result/phase_3/model_cn_step<step_number> results/result/config.json <input_img> <output_img> [--max_holes (int)] [--img_size (int)] [--hole_min_w (int)] [--hole_max_w (int)] [--hole_min_h (int)] [--hole_max_h (int)]
179 | ```
180 | 
181 | **Arguments**
182 | * `<input_img>` (required): path to an input image.
183 | * `<output_img>` (required): path to an output image.
184 | * `[--img_size (int)]`: input size of the generator. Input images are rescaled so that the minimum side is equal to `img_size`, then randomly cropped to `img_size` x `img_size` (default: 160).
185 | * `[--max_holes (int)]`: maximum number of holes to be randomly generated (default: 5).
186 | * `[--hole_min_w (int)]`: minimum width of a hole (default: 24).
187 | * `[--hole_max_w (int)]`: maximum width of a hole (default: 48).
188 | * `[--hole_min_h (int)]`: minimum height of a hole (default: 24).
189 | * `[--hole_max_h (int)]`: maximum height of a hole (default: 48).
190 | 
191 | **Example**: To run inference on an input image `<path_to_this_repo>/input.jpg` and save the output image as `<path_to_this_repo>/output.jpg`, run the following command.
192 | 
193 | ```bash
194 | # in <path_to_this_repo>
195 | python predict.py results/result/phase_3/model_cn_step<step_number> results/result/config.json input.jpg output.jpg
196 | ```
197 | 
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import imghdr
4 | import torch.utils.data as data
5 | from PIL import Image
6 | 
7 | 
8 | class ImageDataset(data.Dataset):
9 |     def __init__(self, data_dir, transform=None, recursive_search=False):
10 |         super(ImageDataset, self).__init__()
11 |         self.data_dir = os.path.expanduser(data_dir)
12 |         self.transform = transform
13 |         self.imgpaths = self.__load_imgpaths_from_dir(self.data_dir, walk=recursive_search)
14 | 
15 |     def __len__(self):
16 |         return len(self.imgpaths)
17 | 
18 |     def __getitem__(self, index, color_format='RGB'):
19 |         img = Image.open(self.imgpaths[index])
20 |         img = img.convert(color_format)
21 |         if self.transform is not None:
22 |             img = self.transform(img)
23 |         return img
24 | 
25 |     def __is_imgfile(self, filepath):
26 |         filepath = os.path.expanduser(filepath)
27 |         if os.path.isfile(filepath) and imghdr.what(filepath):
28 |             return True
29 |         return False
30 | 
31 |     def __load_imgpaths_from_dir(self, dirpath, walk=False):
32 |         imgpaths = []
33 |         dirpath = os.path.expanduser(dirpath)
34 |         if walk:
35 |             for (root, _, files) in os.walk(dirpath):
36 |                 for file in files:
37 |                     file = os.path.join(root, file)
38 |                     if self.__is_imgfile(file):
39 |                         imgpaths.append(file)
40 |         else:
41 |             for path in os.listdir(dirpath):
42 |                 path = os.path.join(dirpath, path)
43 |                 if not self.__is_imgfile(path):
44 |                     continue
45 |                 imgpaths.append(path)
46 |         return imgpaths
47 | 
--------------------------------------------------------------------------------
/datasets/.gitignore:
--------------------------------------------------------------------------------
1 | img_align_celeba/
2 | places2/
3 | test/
4 | *.tar.gz
5 | 
--------------------------------------------------------------------------------
/datasets/make_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import imghdr
4 | import random
5 | import shutil
6 | import tqdm
7 | 
8 | 
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('data_dir')
11 | parser.add_argument('--split', type=float, default=0.8)
12 | 
13 | 
14 | def main(args):
15 |     args.data_dir = os.path.expanduser(args.data_dir)
16 | 
17 |     print('loading dataset...')
18 |     src_paths = []
19 |     for file in os.listdir(args.data_dir):
20 |         path = os.path.join(args.data_dir, file)
21 |         if imghdr.what(path) is None:
22 |             continue
23 |         src_paths.append(path)
24 |     random.shuffle(src_paths)
25 | 
26 |     # separate the paths
27 |     border = int(args.split *
len(src_paths)) 28 | train_paths = src_paths[:border] 29 | test_paths = src_paths[border:] 30 | print('train images: %d images.' % len(train_paths)) 31 | print('test images: %d images.' % len(test_paths)) 32 | 33 | # create dst directories 34 | train_dir = os.path.join(args.data_dir, 'train') 35 | test_dir = os.path.join(args.data_dir, 'test') 36 | if os.path.exists(train_dir) == False: 37 | os.makedirs(train_dir) 38 | if os.path.exists(test_dir) == False: 39 | os.makedirs(test_dir) 40 | 41 | # move the image files 42 | pbar = tqdm.tqdm(total=len(src_paths)) 43 | for dset_paths, dset_dir in zip([train_paths, test_paths], [train_dir, test_dir]): 44 | for src_path in dset_paths: 45 | dst_path = os.path.join(dset_dir, os.path.basename(src_path)) 46 | shutil.move(src_path, dst_path) 47 | pbar.update() 48 | pbar.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | args = parser.parse_args() 53 | main(args) 54 | -------------------------------------------------------------------------------- /images/test_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_0.jpg -------------------------------------------------------------------------------- /images/test_0_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_0_out.jpg -------------------------------------------------------------------------------- /images/test_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_1.jpg -------------------------------------------------------------------------------- /images/test_1_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_1_out.jpg -------------------------------------------------------------------------------- /images/test_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_2.jpg -------------------------------------------------------------------------------- /images/test_2_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_2_out.jpg -------------------------------------------------------------------------------- /images/test_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_3.jpg -------------------------------------------------------------------------------- /images/test_3_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_3_out.jpg -------------------------------------------------------------------------------- /images/test_4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_4.jpg -------------------------------------------------------------------------------- /images/test_4_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otenim/GLCIC-PyTorch/3f9de13e88b76171f4ab0077827805fe5a55d572/images/test_4_out.jpg -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Flatten(nn.Module): 7 | def __init__(self): 8 | super(Flatten, self).__init__() 9 | 10 | def forward(self, x): 11 | return x.view(x.shape[0], -1) 12 | 13 | 14 | class Concatenate(nn.Module): 15 | def __init__(self, dim=-1): 16 | super(Concatenate, self).__init__() 17 | self.dim = dim 18 | 19 | def forward(self, x): 20 | return torch.cat(x, dim=self.dim) 21 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import mse_loss 2 | 3 | 4 | def completion_network_loss(input, output, mask): 5 | return mse_loss(output * mask, input * mask) 6 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers import Flatten, Concatenate 5 | 6 | 7 | class CompletionNetwork(nn.Module): 8 | def __init__(self): 9 | super(CompletionNetwork, self).__init__() 10 | # input_shape: (None, 4, img_h, img_w) 11 | self.conv1 = nn.Conv2d(4, 64, kernel_size=5, stride=1, padding=2) 12 | self.bn1 = nn.BatchNorm2d(64) 13 | self.act1 = nn.ReLU() 14 | # input_shape: (None, 64, img_h, img_w) 15 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 16 | self.bn2 = nn.BatchNorm2d(128) 17 | self.act2 = nn.ReLU() 18 | # input_shape: (None, 128, img_h//2, img_w//2) 19 | self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 20 | self.bn3 = nn.BatchNorm2d(128) 21 | self.act3 = nn.ReLU() 22 | # input_shape: (None, 128, img_h//2, img_w//2) 23 | self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 24 | self.bn4 = nn.BatchNorm2d(256) 25 | self.act4 = nn.ReLU() 26 | # input_shape: (None, 256, img_h//4, img_w//4) 27 | self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 28 | self.bn5 = nn.BatchNorm2d(256) 29 | self.act5 = nn.ReLU() 30 | # input_shape: (None, 256, img_h//4, img_w//4) 31 | self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 32 | self.bn6 = nn.BatchNorm2d(256) 33 | self.act6 = nn.ReLU() 34 | # input_shape: (None, 256, img_h//4, img_w//4) 35 | self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=2, padding=2) 36 | self.bn7 = nn.BatchNorm2d(256) 37 | self.act7 = nn.ReLU() 38 | # input_shape: (None, 256, img_h//4, img_w//4) 39 | self.conv8 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=4, padding=4) 40 | self.bn8 = nn.BatchNorm2d(256) 41 | self.act8 = nn.ReLU() 42 | # input_shape: (None, 256, img_h//4, img_w//4) 43 | self.conv9 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=8, padding=8) 44 | self.bn9 = nn.BatchNorm2d(256) 45 | self.act9 = nn.ReLU() 46 | # input_shape: (None, 256, 
img_h//4, img_w//4) 47 | self.conv10 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=16, padding=16) 48 | self.bn10 = nn.BatchNorm2d(256) 49 | self.act10 = nn.ReLU() 50 | # input_shape: (None, 256, img_h//4, img_w//4) 51 | self.conv11 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 52 | self.bn11 = nn.BatchNorm2d(256) 53 | self.act11 = nn.ReLU() 54 | # input_shape: (None, 256, img_h//4, img_w//4) 55 | self.conv12 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 56 | self.bn12 = nn.BatchNorm2d(256) 57 | self.act12 = nn.ReLU() 58 | # input_shape: (None, 256, img_h//4, img_w//4) 59 | self.deconv13 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1) 60 | self.bn13 = nn.BatchNorm2d(128) 61 | self.act13 = nn.ReLU() 62 | # input_shape: (None, 128, img_h//2, img_w//2) 63 | self.conv14 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 64 | self.bn14 = nn.BatchNorm2d(128) 65 | self.act14 = nn.ReLU() 66 | # input_shape: (None, 128, img_h//2, img_w//2) 67 | self.deconv15 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) 68 | self.bn15 = nn.BatchNorm2d(64) 69 | self.act15 = nn.ReLU() 70 | # input_shape: (None, 64, img_h, img_w) 71 | self.conv16 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1) 72 | self.bn16 = nn.BatchNorm2d(32) 73 | self.act16 = nn.ReLU() 74 | # input_shape: (None, 32, img_h, img_w) 75 | self.conv17 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1) 76 | self.act17 = nn.Sigmoid() 77 | # output_shape: (None, 3, img_h. img_w) 78 | 79 | def forward(self, x): 80 | x = self.bn1(self.act1(self.conv1(x))) 81 | x = self.bn2(self.act2(self.conv2(x))) 82 | x = self.bn3(self.act3(self.conv3(x))) 83 | x = self.bn4(self.act4(self.conv4(x))) 84 | x = self.bn5(self.act5(self.conv5(x))) 85 | x = self.bn6(self.act6(self.conv6(x))) 86 | x = self.bn7(self.act7(self.conv7(x))) 87 | x = self.bn8(self.act8(self.conv8(x))) 88 | x = self.bn9(self.act9(self.conv9(x))) 89 | x = self.bn10(self.act10(self.conv10(x))) 90 | x = self.bn11(self.act11(self.conv11(x))) 91 | x = self.bn12(self.act12(self.conv12(x))) 92 | x = self.bn13(self.act13(self.deconv13(x))) 93 | x = self.bn14(self.act14(self.conv14(x))) 94 | x = self.bn15(self.act15(self.deconv15(x))) 95 | x = self.bn16(self.act16(self.conv16(x))) 96 | x = self.act17(self.conv17(x)) 97 | return x 98 | 99 | 100 | class LocalDiscriminator(nn.Module): 101 | def __init__(self, input_shape): 102 | super(LocalDiscriminator, self).__init__() 103 | self.input_shape = input_shape 104 | self.output_shape = (1024,) 105 | self.img_c = input_shape[0] 106 | self.img_h = input_shape[1] 107 | self.img_w = input_shape[2] 108 | # input_shape: (None, img_c, img_h, img_w) 109 | self.conv1 = nn.Conv2d(self.img_c, 64, kernel_size=5, stride=2, padding=2) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.act1 = nn.ReLU() 112 | # input_shape: (None, 64, img_h//2, img_w//2) 113 | self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2) 114 | self.bn2 = nn.BatchNorm2d(128) 115 | self.act2 = nn.ReLU() 116 | # input_shape: (None, 128, img_h//4, img_w//4) 117 | self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2) 118 | self.bn3 = nn.BatchNorm2d(256) 119 | self.act3 = nn.ReLU() 120 | # input_shape: (None, 256, img_h//8, img_w//8) 121 | self.conv4 = nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2) 122 | self.bn4 = nn.BatchNorm2d(512) 123 | self.act4 = nn.ReLU() 124 | # input_shape: (None, 512, img_h//16, img_w//16) 125 | self.conv5 = nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2) 
126 | self.bn5 = nn.BatchNorm2d(512) 127 | self.act5 = nn.ReLU() 128 | # input_shape: (None, 512, img_h//32, img_w//32) 129 | in_features = 512 * (self.img_h//32) * (self.img_w//32) 130 | self.flatten6 = Flatten() 131 | # input_shape: (None, 512 * img_h//32 * img_w//32) 132 | self.linear6 = nn.Linear(in_features, 1024) 133 | self.act6 = nn.ReLU() 134 | # output_shape: (None, 1024) 135 | 136 | def forward(self, x): 137 | x = self.bn1(self.act1(self.conv1(x))) 138 | x = self.bn2(self.act2(self.conv2(x))) 139 | x = self.bn3(self.act3(self.conv3(x))) 140 | x = self.bn4(self.act4(self.conv4(x))) 141 | x = self.bn5(self.act5(self.conv5(x))) 142 | x = self.act6(self.linear6(self.flatten6(x))) 143 | return x 144 | 145 | 146 | class GlobalDiscriminator(nn.Module): 147 | def __init__(self, input_shape, arc='celeba'): 148 | super(GlobalDiscriminator, self).__init__() 149 | self.arc = arc 150 | self.input_shape = input_shape 151 | self.output_shape = (1024,) 152 | self.img_c = input_shape[0] 153 | self.img_h = input_shape[1] 154 | self.img_w = input_shape[2] 155 | 156 | # input_shape: (None, img_c, img_h, img_w) 157 | self.conv1 = nn.Conv2d(self.img_c, 64, kernel_size=5, stride=2, padding=2) 158 | self.bn1 = nn.BatchNorm2d(64) 159 | self.act1 = nn.ReLU() 160 | # input_shape: (None, 64, img_h//2, img_w//2) 161 | self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2) 162 | self.bn2 = nn.BatchNorm2d(128) 163 | self.act2 = nn.ReLU() 164 | # input_shape: (None, 128, img_h//4, img_w//4) 165 | self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2) 166 | self.bn3 = nn.BatchNorm2d(256) 167 | self.act3 = nn.ReLU() 168 | # input_shape: (None, 256, img_h//8, img_w//8) 169 | self.conv4 = nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2) 170 | self.bn4 = nn.BatchNorm2d(512) 171 | self.act4 = nn.ReLU() 172 | # input_shape: (None, 512, img_h//16, img_w//16) 173 | self.conv5 = nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2) 174 | self.bn5 = nn.BatchNorm2d(512) 175 | self.act5 = nn.ReLU() 176 | # input_shape: (None, 512, img_h//32, img_w//32) 177 | if arc == 'celeba': 178 | in_features = 512 * (self.img_h//32) * (self.img_w//32) 179 | self.flatten6 = Flatten() 180 | self.linear6 = nn.Linear(in_features, 1024) 181 | self.act6 = nn.ReLU() 182 | elif arc == 'places2': 183 | self.conv6 = nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2) 184 | self.bn6 = nn.BatchNorm2d(512) 185 | self.act6 = nn.ReLU() 186 | # input_shape (None, 512, img_h//64, img_w//64) 187 | in_features = 512 * (self.img_h//64) * (self.img_w//64) 188 | self.flatten7 = Flatten() 189 | self.linear7 = nn.Linear(in_features, 1024) 190 | self.act7 = nn.ReLU() 191 | else: 192 | raise ValueError('Unsupported architecture \'%s\'.' 
% self.arc) 193 | # output_shape: (None, 1024) 194 | 195 | def forward(self, x): 196 | x = self.bn1(self.act1(self.conv1(x))) 197 | x = self.bn2(self.act2(self.conv2(x))) 198 | x = self.bn3(self.act3(self.conv3(x))) 199 | x = self.bn4(self.act4(self.conv4(x))) 200 | x = self.bn5(self.act5(self.conv5(x))) 201 | if self.arc == 'celeba': 202 | x = self.act6(self.linear6(self.flatten6(x))) 203 | elif self.arc == 'places2': 204 | x = self.bn6(self.act6(self.conv6(x))) 205 | x = self.act7(self.linear7(self.flatten7(x))) 206 | return x 207 | 208 | 209 | class ContextDiscriminator(nn.Module): 210 | def __init__(self, local_input_shape, global_input_shape, arc='celeba'): 211 | super(ContextDiscriminator, self).__init__() 212 | self.arc = arc 213 | self.input_shape = [local_input_shape, global_input_shape] 214 | self.output_shape = (1,) 215 | self.model_ld = LocalDiscriminator(local_input_shape) 216 | self.model_gd = GlobalDiscriminator(global_input_shape, arc=arc) 217 | # input_shape: [(None, 1024), (None, 1024)] 218 | in_features = self.model_ld.output_shape[-1] + self.model_gd.output_shape[-1] 219 | self.concat1 = Concatenate(dim=-1) 220 | # input_shape: (None, 2048) 221 | self.linear1 = nn.Linear(in_features, 1) 222 | self.act1 = nn.Sigmoid() 223 | # output_shape: (None, 1) 224 | 225 | def forward(self, x): 226 | x_ld, x_gd = x 227 | x_ld = self.model_ld(x_ld) 228 | x_gd = self.model_gd(x_gd) 229 | out = self.act1(self.linear1(self.concat1([x_ld, x_gd]))) 230 | return out 231 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import torch 5 | import torchvision.transforms as transforms 6 | from torchvision.utils import save_image 7 | import numpy as np 8 | from PIL import Image 9 | from models import CompletionNetwork 10 | from utils import poisson_blend, gen_input_mask 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('model') 15 | parser.add_argument('config') 16 | parser.add_argument('input_img') 17 | parser.add_argument('output_img') 18 | parser.add_argument('--max_holes', type=int, default=5) 19 | parser.add_argument('--img_size', type=int, default=160) 20 | parser.add_argument('--hole_min_w', type=int, default=24) 21 | parser.add_argument('--hole_max_w', type=int, default=48) 22 | parser.add_argument('--hole_min_h', type=int, default=24) 23 | parser.add_argument('--hole_max_h', type=int, default=48) 24 | 25 | 26 | def main(args): 27 | 28 | args.model = os.path.expanduser(args.model) 29 | args.config = os.path.expanduser(args.config) 30 | args.input_img = os.path.expanduser(args.input_img) 31 | args.output_img = os.path.expanduser(args.output_img) 32 | 33 | # ============================================= 34 | # Load model 35 | # ============================================= 36 | with open(args.config, 'r') as f: 37 | config = json.load(f) 38 | mpv = torch.tensor(config['mpv']).view(1, 3, 1, 1) 39 | model = CompletionNetwork() 40 | model.load_state_dict(torch.load(args.model, map_location='cpu')) 41 | 42 | # ============================================= 43 | # Predict 44 | # ============================================= 45 | # convert img to tensor 46 | img = Image.open(args.input_img) 47 | img = transforms.Resize(args.img_size)(img) 48 | img = transforms.RandomCrop((args.img_size, args.img_size))(img) 49 | x = transforms.ToTensor()(img) 50 | x = torch.unsqueeze(x, dim=0) 51 | 52 | # create 
mask 53 | mask = gen_input_mask( 54 | shape=(1, 1, x.shape[2], x.shape[3]), 55 | hole_size=( 56 | (args.hole_min_w, args.hole_max_w), 57 | (args.hole_min_h, args.hole_max_h), 58 | ), 59 | max_holes=args.max_holes, 60 | ) 61 | 62 | # inpaint 63 | model.eval() 64 | with torch.no_grad(): 65 | x_mask = x - x * mask + mpv * mask 66 | input = torch.cat((x_mask, mask), dim=1) 67 | output = model(input) 68 | inpainted = poisson_blend(x_mask, output, mask) 69 | imgs = torch.cat((x, x_mask, inpainted), dim=0) 70 | save_image(imgs, args.output_img, nrow=3) 71 | print('output img was saved as %s.' % args.output_img) 72 | 73 | 74 | if __name__ == '__main__': 75 | args = parser.parse_args() 76 | main(args) 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | opencv-python==4.5.2.54 3 | Pillow==8.2.0 4 | torch==1.9.0+cu111 5 | torchvision==0.10.0+cu111 6 | tqdm==4.61.1 7 | 8 | -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | test/ 2 | test_*/ 3 | celeba/ 4 | places2/ 5 | celeba_*/ 6 | places2_*/ 7 | demo/ 8 | release/ 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from torch.utils.data import DataLoader 5 | from torch.optim import Adadelta, Adam 6 | from torch.nn import BCELoss, DataParallel 7 | from torchvision.utils import save_image 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | import torch 11 | import numpy as np 12 | from tqdm import tqdm 13 | from models import CompletionNetwork, ContextDiscriminator 14 | from datasets import ImageDataset 15 | from losses import completion_network_loss 16 | from utils import ( 17 | gen_input_mask, 18 | gen_hole_area, 19 | crop, 20 | sample_random_batch, 21 | poisson_blend, 22 | ) 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('data_dir') 26 | parser.add_argument('result_dir') 27 | parser.add_argument('--data_parallel', action='store_true') 28 | parser.add_argument('--recursive_search', action='store_true', default=False) 29 | parser.add_argument('--init_model_cn', type=str, default=None) 30 | parser.add_argument('--init_model_cd', type=str, default=None) 31 | parser.add_argument('--steps_1', type=int, default=90000) 32 | parser.add_argument('--steps_2', type=int, default=10000) 33 | parser.add_argument('--steps_3', type=int, default=400000) 34 | parser.add_argument('--snaperiod_1', type=int, default=10000) 35 | parser.add_argument('--snaperiod_2', type=int, default=2000) 36 | parser.add_argument('--snaperiod_3', type=int, default=10000) 37 | parser.add_argument('--max_holes', type=int, default=1) 38 | parser.add_argument('--hole_min_w', type=int, default=48) 39 | parser.add_argument('--hole_max_w', type=int, default=96) 40 | parser.add_argument('--hole_min_h', type=int, default=48) 41 | parser.add_argument('--hole_max_h', type=int, default=96) 42 | parser.add_argument('--cn_input_size', type=int, default=160) 43 | parser.add_argument('--ld_input_size', type=int, default=96) 44 | parser.add_argument('--bsize', type=int, default=16) 45 | parser.add_argument('--bdivs', type=int, default=1) 46 | parser.add_argument('--num_test_completions', 
type=int, default=16) 47 | parser.add_argument('--mpv', nargs=3, type=float, default=None) 48 | parser.add_argument('--alpha', type=float, default=4e-4) 49 | parser.add_argument('--arc', type=str, choices=['celeba', 'places2'], default='celeba') 50 | 51 | 52 | def main(args): 53 | # ================================================ 54 | # Preparation 55 | # ================================================ 56 | if not torch.cuda.is_available(): 57 | raise Exception('At least one gpu must be available.') 58 | gpu = torch.device('cuda:0') 59 | 60 | # create result directory (if necessary) 61 | if not os.path.exists(args.result_dir): 62 | os.makedirs(args.result_dir) 63 | for phase in ['phase_1', 'phase_2', 'phase_3']: 64 | if not os.path.exists(os.path.join(args.result_dir, phase)): 65 | os.makedirs(os.path.join(args.result_dir, phase)) 66 | 67 | # load dataset 68 | trnsfm = transforms.Compose([ 69 | transforms.Resize(args.cn_input_size), 70 | transforms.RandomCrop((args.cn_input_size, args.cn_input_size)), 71 | transforms.ToTensor(), 72 | ]) 73 | print('loading dataset... (it may take a few minutes)') 74 | train_dset = ImageDataset( 75 | os.path.join(args.data_dir, 'train'), 76 | trnsfm, 77 | recursive_search=args.recursive_search) 78 | test_dset = ImageDataset( 79 | os.path.join(args.data_dir, 'test'), 80 | trnsfm, 81 | recursive_search=args.recursive_search) 82 | train_loader = DataLoader( 83 | train_dset, 84 | batch_size=(args.bsize // args.bdivs), 85 | shuffle=True) 86 | 87 | # compute mpv (mean pixel value) of training dataset 88 | if args.mpv is None: 89 | mpv = np.zeros(shape=(3,)) 90 | pbar = tqdm( 91 | total=len(train_dset.imgpaths), 92 | desc='computing mean pixel value of training dataset...') 93 | for imgpath in train_dset.imgpaths: 94 | img = Image.open(imgpath) 95 | x = np.array(img) / 255. 
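# note: dividing by 255 keeps pixel values in [0, 1], so the resulting mpv (mean pixel value)
# matches the range of the ToTensor-normalized images that are fed to the network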
96 | mpv += x.mean(axis=(0, 1)) 97 | pbar.update() 98 | mpv /= len(train_dset.imgpaths) 99 | pbar.close() 100 | else: 101 | mpv = np.array(args.mpv) 102 | 103 | # save training config 104 | mpv_json = [] 105 | for i in range(3): 106 | mpv_json.append(float(mpv[i])) 107 | args_dict = vars(args) 108 | args_dict['mpv'] = mpv_json 109 | with open(os.path.join( 110 | args.result_dir, 'config.json'), 111 | mode='w') as f: 112 | json.dump(args_dict, f) 113 | 114 | # make mpv & alpha tensors 115 | mpv = torch.tensor( 116 | mpv.reshape(1, 3, 1, 1), 117 | dtype=torch.float32).to(gpu) 118 | alpha = torch.tensor( 119 | args.alpha, 120 | dtype=torch.float32).to(gpu) 121 | 122 | # ================================================ 123 | # Training Phase 1 124 | # ================================================ 125 | # load completion network 126 | model_cn = CompletionNetwork() 127 | if args.init_model_cn is not None: 128 | model_cn.load_state_dict(torch.load( 129 | args.init_model_cn, 130 | map_location='cpu')) 131 | if args.data_parallel: 132 | model_cn = DataParallel(model_cn) 133 | model_cn = model_cn.to(gpu) 134 | opt_cn = Adadelta(model_cn.parameters()) 135 | 136 | # training 137 | cnt_bdivs = 0 138 | pbar = tqdm(total=args.steps_1) 139 | while pbar.n < args.steps_1: 140 | for x in train_loader: 141 | # forward 142 | x = x.to(gpu) 143 | mask = gen_input_mask( 144 | shape=(x.shape[0], 1, x.shape[2], x.shape[3]), 145 | hole_size=( 146 | (args.hole_min_w, args.hole_max_w), 147 | (args.hole_min_h, args.hole_max_h)), 148 | hole_area=gen_hole_area( 149 | (args.ld_input_size, args.ld_input_size), 150 | (x.shape[3], x.shape[2])), 151 | max_holes=args.max_holes, 152 | ).to(gpu) 153 | x_mask = x - x * mask + mpv * mask 154 | input = torch.cat((x_mask, mask), dim=1) 155 | output = model_cn(input) 156 | loss = completion_network_loss(x, output, mask) 157 | 158 | # backward 159 | loss.backward() 160 | cnt_bdivs += 1 161 | if cnt_bdivs >= args.bdivs: 162 | cnt_bdivs = 0 163 | 164 | # optimize 165 | opt_cn.step() 166 | opt_cn.zero_grad() 167 | pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu()) 168 | pbar.update() 169 | 170 | # test 171 | if pbar.n % args.snaperiod_1 == 0: 172 | model_cn.eval() 173 | with torch.no_grad(): 174 | x = sample_random_batch( 175 | test_dset, 176 | batch_size=args.num_test_completions).to(gpu) 177 | mask = gen_input_mask( 178 | shape=(x.shape[0], 1, x.shape[2], x.shape[3]), 179 | hole_size=( 180 | (args.hole_min_w, args.hole_max_w), 181 | (args.hole_min_h, args.hole_max_h)), 182 | hole_area=gen_hole_area( 183 | (args.ld_input_size, args.ld_input_size), 184 | (x.shape[3], x.shape[2])), 185 | max_holes=args.max_holes).to(gpu) 186 | x_mask = x - x * mask + mpv * mask 187 | input = torch.cat((x_mask, mask), dim=1) 188 | output = model_cn(input) 189 | completed = poisson_blend(x_mask, output, mask) 190 | imgs = torch.cat(( 191 | x.cpu(), 192 | x_mask.cpu(), 193 | completed.cpu()), dim=0) 194 | imgpath = os.path.join( 195 | args.result_dir, 196 | 'phase_1', 197 | 'step%d.png' % pbar.n) 198 | model_cn_path = os.path.join( 199 | args.result_dir, 200 | 'phase_1', 201 | 'model_cn_step%d' % pbar.n) 202 | save_image(imgs, imgpath, nrow=len(x)) 203 | if args.data_parallel: 204 | torch.save( 205 | model_cn.module.state_dict(), 206 | model_cn_path) 207 | else: 208 | torch.save( 209 | model_cn.state_dict(), 210 | model_cn_path) 211 | model_cn.train() 212 | if pbar.n >= args.steps_1: 213 | break 214 | pbar.close() 215 | 216 | # ================================================ 217 | # 
Training Phase 2 218 | # ================================================ 219 | # load context discriminator 220 | model_cd = ContextDiscriminator( 221 | local_input_shape=(3, args.ld_input_size, args.ld_input_size), 222 | global_input_shape=(3, args.cn_input_size, args.cn_input_size), 223 | arc=args.arc) 224 | if args.init_model_cd is not None: 225 | model_cd.load_state_dict(torch.load( 226 | args.init_model_cd, 227 | map_location='cpu')) 228 | if args.data_parallel: 229 | model_cd = DataParallel(model_cd) 230 | model_cd = model_cd.to(gpu) 231 | opt_cd = Adadelta(model_cd.parameters()) 232 | bceloss = BCELoss() 233 | 234 | # training 235 | cnt_bdivs = 0 236 | pbar = tqdm(total=args.steps_2) 237 | while pbar.n < args.steps_2: 238 | for x in train_loader: 239 | # fake forward 240 | x = x.to(gpu) 241 | hole_area_fake = gen_hole_area( 242 | (args.ld_input_size, args.ld_input_size), 243 | (x.shape[3], x.shape[2])) 244 | mask = gen_input_mask( 245 | shape=(x.shape[0], 1, x.shape[2], x.shape[3]), 246 | hole_size=( 247 | (args.hole_min_w, args.hole_max_w), 248 | (args.hole_min_h, args.hole_max_h)), 249 | hole_area=hole_area_fake, 250 | max_holes=args.max_holes).to(gpu) 251 | fake = torch.zeros((len(x), 1)).to(gpu) 252 | x_mask = x - x * mask + mpv * mask 253 | input_cn = torch.cat((x_mask, mask), dim=1) 254 | output_cn = model_cn(input_cn) 255 | input_gd_fake = output_cn.detach() 256 | input_ld_fake = crop(input_gd_fake, hole_area_fake) 257 | output_fake = model_cd(( 258 | input_ld_fake.to(gpu), 259 | input_gd_fake.to(gpu))) 260 | loss_fake = bceloss(output_fake, fake) 261 | 262 | # real forward 263 | hole_area_real = gen_hole_area( 264 | (args.ld_input_size, args.ld_input_size), 265 | (x.shape[3], x.shape[2])) 266 | real = torch.ones((len(x), 1)).to(gpu) 267 | input_gd_real = x 268 | input_ld_real = crop(input_gd_real, hole_area_real) 269 | output_real = model_cd((input_ld_real, input_gd_real)) 270 | loss_real = bceloss(output_real, real) 271 | 272 | # reduce 273 | loss = (loss_fake + loss_real) / 2. 
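# phase 2 updates only the discriminator: it is pushed towards 1 for real images and 0 for
# completions (output_cn is detached above, so no gradient reaches the generator);
# the real and fake BCE terms are simply averaged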
274 | 275 | # backward 276 | loss.backward() 277 | cnt_bdivs += 1 278 | if cnt_bdivs >= args.bdivs: 279 | cnt_bdivs = 0 280 | 281 | # optimize 282 | opt_cd.step() 283 | opt_cd.zero_grad() 284 | pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu()) 285 | pbar.update() 286 | 287 | # test 288 | if pbar.n % args.snaperiod_2 == 0: 289 | model_cn.eval() 290 | with torch.no_grad(): 291 | x = sample_random_batch( 292 | test_dset, 293 | batch_size=args.num_test_completions).to(gpu) 294 | mask = gen_input_mask( 295 | shape=(x.shape[0], 1, x.shape[2], x.shape[3]), 296 | hole_size=( 297 | (args.hole_min_w, args.hole_max_w), 298 | (args.hole_min_h, args.hole_max_h)), 299 | hole_area=gen_hole_area( 300 | (args.ld_input_size, args.ld_input_size), 301 | (x.shape[3], x.shape[2])), 302 | max_holes=args.max_holes).to(gpu) 303 | x_mask = x - x * mask + mpv * mask 304 | input = torch.cat((x_mask, mask), dim=1) 305 | output = model_cn(input) 306 | completed = poisson_blend(x_mask, output, mask) 307 | imgs = torch.cat(( 308 | x.cpu(), 309 | x_mask.cpu(), 310 | completed.cpu()), dim=0) 311 | imgpath = os.path.join( 312 | args.result_dir, 313 | 'phase_2', 314 | 'step%d.png' % pbar.n) 315 | model_cd_path = os.path.join( 316 | args.result_dir, 317 | 'phase_2', 318 | 'model_cd_step%d' % pbar.n) 319 | save_image(imgs, imgpath, nrow=len(x)) 320 | if args.data_parallel: 321 | torch.save( 322 | model_cd.module.state_dict(), 323 | model_cd_path) 324 | else: 325 | torch.save( 326 | model_cd.state_dict(), 327 | model_cd_path) 328 | model_cn.train() 329 | if pbar.n >= args.steps_2: 330 | break 331 | pbar.close() 332 | 333 | # ================================================ 334 | # Training Phase 3 335 | # ================================================ 336 | cnt_bdivs = 0 337 | pbar = tqdm(total=args.steps_3) 338 | while pbar.n < args.steps_3: 339 | for x in train_loader: 340 | # forward model_cd 341 | x = x.to(gpu) 342 | hole_area_fake = gen_hole_area( 343 | (args.ld_input_size, args.ld_input_size), 344 | (x.shape[3], x.shape[2])) 345 | mask = gen_input_mask( 346 | shape=(x.shape[0], 1, x.shape[2], x.shape[3]), 347 | hole_size=( 348 | (args.hole_min_w, args.hole_max_w), 349 | (args.hole_min_h, args.hole_max_h)), 350 | hole_area=hole_area_fake, 351 | max_holes=args.max_holes).to(gpu) 352 | 353 | # fake forward 354 | fake = torch.zeros((len(x), 1)).to(gpu) 355 | x_mask = x - x * mask + mpv * mask 356 | input_cn = torch.cat((x_mask, mask), dim=1) 357 | output_cn = model_cn(input_cn) 358 | input_gd_fake = output_cn.detach() 359 | input_ld_fake = crop(input_gd_fake, hole_area_fake) 360 | output_fake = model_cd((input_ld_fake, input_gd_fake)) 361 | loss_cd_fake = bceloss(output_fake, fake) 362 | 363 | # real forward 364 | hole_area_real = gen_hole_area( 365 | (args.ld_input_size, args.ld_input_size), 366 | (x.shape[3], x.shape[2])) 367 | real = torch.ones((len(x), 1)).to(gpu) 368 | input_gd_real = x 369 | input_ld_real = crop(input_gd_real, hole_area_real) 370 | output_real = model_cd((input_ld_real, input_gd_real)) 371 | loss_cd_real = bceloss(output_real, real) 372 | 373 | # reduce 374 | loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2. 
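# alpha (default 4e-4, see --alpha) scales the adversarial terms in phase 3: it weights the
# discriminator loss here and the generator's adversarial term in loss_cn below, so the
# reconstruction (MSE) term dominates the joint objective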
375 | 376 | # backward model_cd 377 | loss_cd.backward() 378 | cnt_bdivs += 1 379 | if cnt_bdivs >= args.bdivs: 380 | # optimize 381 | opt_cd.step() 382 | opt_cd.zero_grad() 383 | 384 | # forward model_cn 385 | loss_cn_1 = completion_network_loss(x, output_cn, mask) 386 | input_gd_fake = output_cn 387 | input_ld_fake = crop(input_gd_fake, hole_area_fake) 388 | output_fake = model_cd((input_ld_fake, (input_gd_fake))) 389 | loss_cn_2 = bceloss(output_fake, real) 390 | 391 | # reduce 392 | loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2. 393 | 394 | # backward model_cn 395 | loss_cn.backward() 396 | if cnt_bdivs >= args.bdivs: 397 | cnt_bdivs = 0 398 | 399 | # optimize 400 | opt_cn.step() 401 | opt_cn.zero_grad() 402 | pbar.set_description( 403 | 'phase 3 | train loss (cd): %.5f (cn): %.5f' % ( 404 | loss_cd.cpu(), 405 | loss_cn.cpu())) 406 | pbar.update() 407 | 408 | # test 409 | if pbar.n % args.snaperiod_3 == 0: 410 | model_cn.eval() 411 | with torch.no_grad(): 412 | x = sample_random_batch( 413 | test_dset, 414 | batch_size=args.num_test_completions).to(gpu) 415 | mask = gen_input_mask( 416 | shape=(x.shape[0], 1, x.shape[2], x.shape[3]), 417 | hole_size=( 418 | (args.hole_min_w, args.hole_max_w), 419 | (args.hole_min_h, args.hole_max_h)), 420 | hole_area=gen_hole_area( 421 | (args.ld_input_size, args.ld_input_size), 422 | (x.shape[3], x.shape[2])), 423 | max_holes=args.max_holes).to(gpu) 424 | x_mask = x - x * mask + mpv * mask 425 | input = torch.cat((x_mask, mask), dim=1) 426 | output = model_cn(input) 427 | completed = poisson_blend(x_mask, output, mask) 428 | imgs = torch.cat(( 429 | x.cpu(), 430 | x_mask.cpu(), 431 | completed.cpu()), dim=0) 432 | imgpath = os.path.join( 433 | args.result_dir, 434 | 'phase_3', 435 | 'step%d.png' % pbar.n) 436 | model_cn_path = os.path.join( 437 | args.result_dir, 438 | 'phase_3', 439 | 'model_cn_step%d' % pbar.n) 440 | model_cd_path = os.path.join( 441 | args.result_dir, 442 | 'phase_3', 443 | 'model_cd_step%d' % pbar.n) 444 | save_image(imgs, imgpath, nrow=len(x)) 445 | if args.data_parallel: 446 | torch.save( 447 | model_cn.module.state_dict(), 448 | model_cn_path) 449 | torch.save( 450 | model_cd.module.state_dict(), 451 | model_cd_path) 452 | else: 453 | torch.save( 454 | model_cn.state_dict(), 455 | model_cn_path) 456 | torch.save( 457 | model_cd.state_dict(), 458 | model_cd_path) 459 | model_cn.train() 460 | if pbar.n >= args.steps_3: 461 | break 462 | pbar.close() 463 | 464 | 465 | if __name__ == '__main__': 466 | args = parser.parse_args() 467 | args.data_dir = os.path.expanduser(args.data_dir) 468 | args.result_dir = os.path.expanduser(args.result_dir) 469 | if args.init_model_cn is not None: 470 | args.init_model_cn = os.path.expanduser(args.init_model_cn) 471 | if args.init_model_cd is not None: 472 | args.init_model_cd = os.path.expanduser(args.init_model_cd) 473 | main(args) 474 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | 7 | 8 | def gen_input_mask( 9 | shape, hole_size, hole_area=None, max_holes=1): 10 | """ 11 | * inputs: 12 | - shape (sequence, required): 13 | Shape of a mask tensor to be generated. 14 | A sequence of length 4 (N, C, H, W) is assumed. 15 | - hole_size (sequence or int, required): 16 | Size of holes created in a mask. 
17 | If a sequence of length 4 is provided, 18 | holes of size (W, H) = ( 19 | hole_size[0][0] <= hole_size[0][1], 20 | hole_size[1][0] <= hole_size[1][1], 21 | ) are generated. 22 | All the pixel values within holes are filled with 1.0. 23 | - hole_area (sequence, optional): 24 | This argument constraints the area where holes are generated. 25 | hole_area[0] is the left corner (X, Y) of the area, 26 | while hole_area[1] is its width and height (W, H). 27 | This area is used as the input region of Local discriminator. 28 | The default value is None. 29 | - max_holes (int, optional): 30 | This argument specifies how many holes are generated. 31 | The number of holes is randomly chosen from [1, max_holes]. 32 | The default value is 1. 33 | * returns: 34 | A mask tensor of shape [N, C, H, W] with holes. 35 | All the pixel values within holes are filled with 1.0, 36 | while the other pixel values are zeros. 37 | """ 38 | mask = torch.zeros(shape) 39 | bsize, _, mask_h, mask_w = mask.shape 40 | for i in range(bsize): 41 | n_holes = random.choice(list(range(1, max_holes+1))) 42 | for _ in range(n_holes): 43 | # choose patch width 44 | if isinstance(hole_size[0], tuple) and len(hole_size[0]) == 2: 45 | hole_w = random.randint(hole_size[0][0], hole_size[0][1]) 46 | else: 47 | hole_w = hole_size[0] 48 | 49 | # choose patch height 50 | if isinstance(hole_size[1], tuple) and len(hole_size[1]) == 2: 51 | hole_h = random.randint(hole_size[1][0], hole_size[1][1]) 52 | else: 53 | hole_h = hole_size[1] 54 | 55 | # choose offset upper-left coordinate 56 | if hole_area: 57 | harea_xmin, harea_ymin = hole_area[0] 58 | harea_w, harea_h = hole_area[1] 59 | offset_x = random.randint(harea_xmin, harea_xmin + harea_w - hole_w) 60 | offset_y = random.randint(harea_ymin, harea_ymin + harea_h - hole_h) 61 | else: 62 | offset_x = random.randint(0, mask_w - hole_w) 63 | offset_y = random.randint(0, mask_h - hole_h) 64 | mask[i, :, offset_y: offset_y + hole_h, offset_x: offset_x + hole_w] = 1.0 65 | return mask 66 | 67 | 68 | def gen_hole_area(size, mask_size): 69 | """ 70 | * inputs: 71 | - size (sequence, required) 72 | A sequence of length 2 (W, H) is assumed. 73 | (W, H) is the size of hole area. 74 | - mask_size (sequence, required) 75 | A sequence of length 2 (W, H) is assumed. 76 | (W, H) is the size of input mask. 77 | * returns: 78 | A sequence used for the input argument 'hole_area' for function 'gen_input_mask'. 79 | """ 80 | mask_w, mask_h = mask_size 81 | harea_w, harea_h = size 82 | offset_x = random.randint(0, mask_w - harea_w) 83 | offset_y = random.randint(0, mask_h - harea_h) 84 | return ((offset_x, offset_y), (harea_w, harea_h)) 85 | 86 | 87 | def crop(x, area): 88 | """ 89 | * inputs: 90 | - x (torch.Tensor, required) 91 | A torch tensor of shape (N, C, H, W) is assumed. 92 | - area (sequence, required) 93 | A sequence of length 2 ((X, Y), (W, H)) is assumed. 94 | sequence[0] (X, Y) is the left corner of an area to be cropped. 95 | sequence[1] (W, H) is its width and height. 96 | * returns: 97 | A torch tensor of shape (N, C, H, W) cropped in the specified area. 98 | """ 99 | xmin, ymin = area[0] 100 | w, h = area[1] 101 | return x[:, :, ymin: ymin + h, xmin: xmin + w] 102 | 103 | 104 | def sample_random_batch(dataset, batch_size=32): 105 | """ 106 | * inputs: 107 | - dataset (torch.utils.data.Dataset, required) 108 | An instance of torch.utils.data.Dataset. 109 | - batch_size (int, optional) 110 | Batch size. 111 | * returns: 112 | A mini-batch randomly sampled from the input dataset. 
113 | """ 114 | num_samples = len(dataset) 115 | batch = [] 116 | for _ in range(min(batch_size, num_samples)): 117 | index = random.choice(range(0, num_samples)) 118 | x = torch.unsqueeze(dataset[index], dim=0) 119 | batch.append(x) 120 | return torch.cat(batch, dim=0) 121 | 122 | 123 | def poisson_blend(input, output, mask): 124 | """ 125 | * inputs: 126 | - input (torch.Tensor, required) 127 | Input tensor of Completion Network, whose shape = (N, 3, H, W). 128 | - output (torch.Tensor, required) 129 | Output tensor of Completion Network, whose shape = (N, 3, H, W). 130 | - mask (torch.Tensor, required) 131 | Input mask tensor of Completion Network, whose shape = (N, 1, H, W). 132 | * returns: 133 | Output image tensor of shape (N, 3, H, W) inpainted with poisson image editing method. 134 | """ 135 | input = input.clone().cpu() 136 | output = output.clone().cpu() 137 | mask = mask.clone().cpu() 138 | mask = torch.cat((mask, mask, mask), dim=1) # convert to 3-channel format 139 | num_samples = input.shape[0] 140 | ret = [] 141 | for i in range(num_samples): 142 | dstimg = transforms.functional.to_pil_image(input[i]) 143 | dstimg = np.array(dstimg)[:, :, [2, 1, 0]] 144 | srcimg = transforms.functional.to_pil_image(output[i]) 145 | srcimg = np.array(srcimg)[:, :, [2, 1, 0]] 146 | msk = transforms.functional.to_pil_image(mask[i]) 147 | msk = np.array(msk)[:, :, [2, 1, 0]] 148 | # compute mask's center 149 | xs, ys = [], [] 150 | for j in range(msk.shape[0]): 151 | for k in range(msk.shape[1]): 152 | if msk[j, k, 0] == 255: 153 | ys.append(j) 154 | xs.append(k) 155 | xmin, xmax = min(xs), max(xs) 156 | ymin, ymax = min(ys), max(ys) 157 | center = ((xmax + xmin) // 2, (ymax + ymin) // 2) 158 | dstimg = cv2.inpaint(dstimg, msk[:, :, 0], 1, cv2.INPAINT_TELEA) 159 | out = cv2.seamlessClone(srcimg, dstimg, msk, center, cv2.NORMAL_CLONE) 160 | out = out[:, :, [2, 1, 0]] 161 | out = transforms.functional.to_tensor(out) 162 | out = torch.unsqueeze(out, dim=0) 163 | ret.append(out) 164 | ret = torch.cat(ret, dim=0) 165 | return ret 166 | --------------------------------------------------------------------------------