├── .gitignore
├── src
│   ├── config.py
│   ├── loss.py
│   ├── utils.py
│   ├── evaluate.py
│   ├── dataset.py
│   └── generator.py
├── README.md
├── LICENSE
├── eval.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
runs
snapshots
results
src/__pycache__

--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------

# generator loss term weights
VALID_LOSS = 1.0
HOLE_LOSS = 6.0
TOTAL_VARIATION_LOSS = 0.1
PERCEPTUAL_LOSS = 0.05
STYLE_LOSS = 120.0

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PConv

PyTorch implementation of "Image Inpainting for Irregular Holes Using Partial Convolutions" (ECCV 2018) [[Paper]](https://arxiv.org/abs/1804.07723).

**Authors**: _Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao, Bryan Catanzaro_

## Prerequisites

* Python 3
* PyTorch 1.0
* NVIDIA GPU + CUDA cuDNN

## Installation

* Clone this repo:

```
git clone https://github.com/Xiefan-Guo/PConv.git
cd PConv
```

## Usage

### Training

To train the PUNet model:

```
python train.py \
    --image_root [path to input image directory] \
    --mask_root [path to masks directory]
```

### Evaluating

To evaluate the model:

```
python eval.py \
    --pre_trained [path to checkpoints] \
    --image_root [path to input image directory] \
    --mask_root [path to masks directory]
```

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Xie-Fan Guo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
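Usage sketch (not part of the repo) for `generator_loss` in src/loss.py below, on dummy tensors, using the repo's mask convention of 1 = valid pixel, 0 = hole; `VGG16FeatureExtractor` downloads ImageNet weights on first use.

```
import torch
from src.loss import generator_loss
from src.utils import VGG16FeatureExtractor

extractor = VGG16FeatureExtractor()                      # downloads VGG16 weights on first use
ground_truths = torch.rand(2, 3, 256, 256)
masks = (torch.rand(2, 3, 256, 256) > 0.25).float()      # 1 = valid pixel, 0 = hole
inputs = ground_truths * masks                           # damaged input, as built by the dataset
outputs = torch.rand(2, 3, 256, 256)                     # stand-in for generator output
total, terms = generator_loss(inputs, masks, outputs, ground_truths, extractor)
# terms = [hole, valid, perceptual, style, tv]; total applies the src/config.py weights
```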
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import src.config as config
from src.utils import gram_matrix


# -------
# tv loss
# -------
def total_variation_loss(image):
    # ---------------------------------------------------------------
    # shift one pixel and get difference (for both x and y direction)
    # ---------------------------------------------------------------
    loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
        torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))

    return loss


def generator_loss(inputs, masks, outputs, ground_truths, extractor):

    l1 = nn.L1Loss()

    # composite image: valid pixels from the input, hole pixels from the output
    comp = masks * inputs + (1 - masks) * outputs

    # ---------
    # hole loss
    # ---------
    loss_hole = l1((1 - masks) * outputs, (1 - masks) * ground_truths)

    # ----------
    # valid loss
    # ----------
    loss_valid = l1(masks * outputs, masks * ground_truths)

    # VGG16 expects 3 channels, so grayscale images are replicated
    if outputs.size(1) == 3:
        feat_comp = extractor(comp)
        feat_output = extractor(outputs)
        feat_gt = extractor(ground_truths)
    elif outputs.size(1) == 1:
        feat_comp = extractor(torch.cat([comp] * 3, dim=1))
        feat_output = extractor(torch.cat([outputs] * 3, dim=1))
        feat_gt = extractor(torch.cat([ground_truths] * 3, dim=1))
    else:
        raise ValueError('only grayscale and RGB images are supported')

    # ---------------
    # perceptual loss
    # ---------------
    loss_perceptual = 0.0
    for i in range(3):
        loss_perceptual += l1(feat_output[i], feat_gt[i])
        loss_perceptual += l1(feat_comp[i], feat_gt[i])

    # ----------
    # style loss
    # ----------
    loss_style = 0.0
    for i in range(3):
        loss_style += l1(gram_matrix(feat_output[i]), gram_matrix(feat_gt[i]))
        loss_style += l1(gram_matrix(feat_comp[i]), gram_matrix(feat_gt[i]))

    # -------
    # tv loss
    # -------
    loss_tv = total_variation_loss((1 - masks) * outputs)

    total_loss = loss_hole * config.HOLE_LOSS + loss_valid * config.VALID_LOSS + \
        loss_perceptual * config.PERCEPTUAL_LOSS + loss_style * config.STYLE_LOSS + \
        loss_tv * config.TOTAL_VARIATION_LOSS

    return total_loss, [loss_hole, loss_valid, loss_perceptual, loss_style, loss_tv]
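Quick sanity check (a sketch, not repo code) for `total_variation_loss` above: a constant image has zero TV, and noise has positive TV.

```
import torch
from src.loss import total_variation_loss

print(total_variation_loss(torch.ones(1, 3, 16, 16)))   # tensor(0.)
print(total_variation_loss(torch.rand(1, 3, 16, 16)))   # > 0
```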
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision


# ---------------------
# VGG16 feature extract
# ---------------------
class VGG16FeatureExtractor(nn.Module):

    def __init__(self):
        super(VGG16FeatureExtractor, self).__init__()

        vgg16 = torchvision.models.vgg16(pretrained=True)

        # slices of VGG16 ending at pool1, pool2 and pool3
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, images):
        results = [images]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]


# --------------
# weight initial
# --------------
def weights_init(init_type='gaussian', gain=0.02):

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):

            # 'gaussian' is treated as an alias of 'normal'; previously it
            # matched no branch, so the default left conv weights untouched
            if init_type in ('gaussian', 'normal'):
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=gain)

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, gain)
            nn.init.constant_(m.bias.data, 0.0)

    return init_func


# ---------------------------------------------------------------------------------------
# gram matrix
# https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
# ---------------------------------------------------------------------------------------
def gram_matrix(feat):

    (b, ch, h, w) = feat.size()
    feat = feat.view(b, ch, h * w)
    feat_t = feat.transpose(1, 2)
    gram = torch.bmm(feat, feat_t) / (ch * h * w)

    return gram
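Shape sketch (illustration, not repo code) for `VGG16FeatureExtractor` above: the three slices end at VGG16's pool1/pool2/pool3, so a 256×256 input yields feature maps at 1/2, 1/4 and 1/8 resolution.

```
import torch
from src.utils import VGG16FeatureExtractor

feats = VGG16FeatureExtractor()(torch.rand(1, 3, 256, 256))
for f in feats:
    print(f.shape)
# torch.Size([1, 64, 128, 128])
# torch.Size([1, 128, 64, 64])
# torch.Size([1, 256, 32, 32])
```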
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
import math
from math import exp

import torch
import torch.nn.functional as F


def compute_psnr(ground_truths, outputs_comp):

    # images are expected in [0, 1], so the peak signal value is 1.0
    batch_mse = ((ground_truths - outputs_comp) ** 2).mean()
    psnr = 10.0 * math.log10(1.0 / batch_mse)

    return psnr


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # stabilizers from the SSIM paper (K1 = 0.01, K2 = 0.03)
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # rebuild the window if the channel count or dtype changed
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def compute_ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
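Sanity check sketch (not repo code) for the metrics above: SSIM of an image with itself is 1, and PSNR is finite whenever the two images differ.

```
import torch
from src.evaluate import compute_psnr, compute_ssim

a = torch.rand(1, 3, 64, 64)
b = (a + 0.05 * torch.rand_like(a)).clamp(0, 1)
print(compute_psnr(a, b))          # finite, since b != a (mse > 0)
print(compute_ssim(a, a).item())   # 1.0
```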
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
import os
import random
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def check_image_file(filename):
    # ------------------------------------------------------
    # https://www.runoob.com/python/python-func-any.html
    # https://www.runoob.com/python/att-string-endswith.html
    # ------------------------------------------------------
    return any(filename.endswith(extension) for extension in
               ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP'])


def image_transforms(load_size, crop_size):
    # --------------------------------------------------------------
    # https://blog.csdn.net/weixin_38533896/article/details/86028509
    # --------------------------------------------------------------
    return transforms.Compose([
        transforms.Resize(size=load_size, interpolation=Image.BICUBIC),
        transforms.RandomCrop(size=crop_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor()
    ])


def mask_transforms(crop_size):

    return transforms.Compose([
        transforms.Resize(size=crop_size, interpolation=Image.NEAREST),
        transforms.ToTensor()
    ])


class ImageDataset(Dataset):

    def __init__(self, image_root, mask_root, load_size, crop_size):
        super(ImageDataset, self).__init__()

        self.image_files = [os.path.join(root, file) for root, dirs, files in os.walk(image_root)
                            for file in files if check_image_file(file)]
        self.mask_files = [os.path.join(root, file) for root, dirs, files in os.walk(mask_root)
                           for file in files if check_image_file(file)]

        self.number_image = len(self.image_files)
        self.number_mask = len(self.mask_files)
        self.load_size = load_size
        self.crop_size = crop_size
        self.image_files_transforms = image_transforms(load_size, crop_size)
        self.mask_files_transforms = mask_transforms(crop_size)

    def __getitem__(self, index):

        # images and masks are paired randomly, not by index
        image = Image.open(self.image_files[index % self.number_image])
        mask = Image.open(self.mask_files[random.randint(0, self.number_mask - 1)])

        ground_truth = self.image_files_transforms(image.convert('RGB'))
        mask = self.mask_files_transforms(mask.convert('RGB'))

        # binarize the mask
        threshold = 0.5
        ones = mask >= threshold
        zeros = mask < threshold

        mask.masked_fill_(ones, 1.0)
        mask.masked_fill_(zeros, 0.0)

        # -----------------------------------------------------------
        # in the mask files, white pixels (ones) mark the area to be
        # inpainted and dark pixels (zeros) mark the area to keep, so
        # the mask is inverted: after this, 1 = valid pixel, 0 = hole
        # -----------------------------------------------------------
        mask = 1 - mask
        input_image = ground_truth * mask

        return input_image, ground_truth, mask

    def __len__(self):
        return self.number_image
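Usage sketch (not repo code; paths are placeholders) for `ImageDataset` above, showing what one sample contains.

```
from src.dataset import ImageDataset

dataset = ImageDataset('data/images', 'data/masks', (350, 350), (256, 256))
input_image, ground_truth, mask = dataset[0]
# mask: 1 = kept pixel, 0 = hole, and input_image == ground_truth * mask
```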
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import os
import argparse
import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from src.dataset import ImageDataset
from src.generator import PUNet
from src.evaluate import compute_psnr, compute_ssim

parser = argparse.ArgumentParser()
parser.add_argument('--num_workers', type=int, default=0, help='workers for dataloader')
parser.add_argument('--pre_trained', type=str, default='', help='path to the checkpoint to evaluate')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--load_size', type=int, default=350, help='image loading size')
parser.add_argument('--crop_size', type=int, default=256, help='image cropping size')
parser.add_argument('--image_root', type=str, default='')
parser.add_argument('--mask_root', type=str, default='')
parser.add_argument('--result_root', type=str, default='results/eval/default', help='eval results directory')
parser.add_argument('--number_eval', type=int, default=10, help='number of batches to evaluate')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
is_cuda = torch.cuda.is_available()
if is_cuda:
    print('Cuda is available!')
    cudnn.benchmark = True

if not os.path.exists(args.result_root):
    os.makedirs(args.result_root)

image_dataset = ImageDataset(
    args.image_root, args.mask_root, (args.load_size, args.load_size), (args.crop_size, args.crop_size)
)
data_loader = DataLoader(
    image_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    drop_last=False
)
print('number of batches: ', len(data_loader))

generator = PUNet(3, 3)

if args.pre_trained != '':
    generator.load_state_dict(torch.load(args.pre_trained)['generator'])
else:
    print('Please provide a pre-trained model!')

for param in generator.parameters():
    param.requires_grad = False

if is_cuda:
    generator = generator.cuda()

print('start eval...')

sum_psnr = 0.0
sum_ssim = 0.0
l1_loss = 0.0
sum_time = 0.0
count = 0

# loop over the dataset until number_eval batches have been evaluated
while True:

    if count >= args.number_eval:
        break

    generator.eval()
    for _, (input_images, ground_truths, masks) in enumerate(data_loader):

        if count >= args.number_eval:
            break

        count = count + 1
        if is_cuda:
            input_images, ground_truths, masks = \
                input_images.cuda(), ground_truths.cuda(), masks.cuda()

        start_time = time.time()
        outputs = generator(input_images, masks)
        end_time = time.time()
        sum_time = sum_time + (end_time - start_time)
        print('time: ', end_time - start_time)

        outputs = outputs.data.cpu()
        ground_truths = ground_truths.data.cpu()
        masks = masks.data.cpu()

        # show holes as white in the damaged image; keep valid pixels as-is
        damaged_images = ground_truths * masks + (1 - masks)
        # composite: valid pixels from the ground truth, hole pixels from the output
        outputs_comps = ground_truths * masks + outputs * (1 - masks)

        psnr = compute_psnr(ground_truths, outputs_comps)
        sum_psnr += psnr
        print(count, ' psnr: ', psnr)

        ssim = compute_ssim(ground_truths * 255, outputs_comps * 255).item()
        sum_ssim += ssim
        print(count, ' ssim: ', ssim)

        l1 = nn.L1Loss()(ground_truths, outputs_comps).item()
        l1_loss += l1
        print(count, ' l1_loss: ', l1)

        # ---------------------------------------------------------------------
        # save a 5-column grid: mask, damaged, output, composite, ground truth
        # ---------------------------------------------------------------------
        sizes = ground_truths.size()
        save_images_5 = torch.Tensor(sizes[0] * 5, sizes[1], sizes[2], sizes[3])
        for i in range(sizes[0]):
            save_images_5[5 * i] = 1 - masks[i]
            save_images_5[5 * i + 1] = damaged_images[i]
            save_images_5[5 * i + 2] = outputs[i]
            save_images_5[5 * i + 3] = outputs_comps[i]
            save_images_5[5 * i + 4] = ground_truths[i]

        save_image(save_images_5, os.path.join(args.result_root, '{:05d}.png'.format(count)), nrow=5)


print('count: ', count)
print('average l1 loss: ', l1_loss / count)
print('average psnr: ', sum_psnr / count)
print('average ssim: ', sum_ssim / count)
print('average time cost: ', sum_time / count)

print('complete the eval.')
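For reference, both eval.py above and train.py below exchange checkpoints as a dict with `'n_iter'` (resume epoch) and `'generator'` (state dict). A loading sketch, with a placeholder path:

```
import torch
from src.generator import PUNet

ckpt = torch.load('snapshots/PairsStreetView/ckpt/model_500.pth', map_location='cpu')
net = PUNet(3, 3)
net.load_state_dict(ckpt['generator'])
print(ckpt['n_iter'])
```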
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import argparse
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

from src.dataset import ImageDataset
from src.generator import PUNet
from src.utils import VGG16FeatureExtractor
from src.loss import generator_loss

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

parser = argparse.ArgumentParser()
parser.add_argument('--num_workers', type=int, default=0, help='workers for dataloader')
parser.add_argument('--image_root', type=str, default='')
parser.add_argument('--mask_root', type=str, default='')
parser.add_argument('--save_dir', type=str, default='snapshots/PairsStreetView', help='path for saving models')
parser.add_argument('--log_dir', type=str, default='runs/PairsStreetView', help='log with tensorboardX')
parser.add_argument('--pre_trained', type=str, default='', help='the path of checkpoint')
parser.add_argument('--save_interval', type=int, default=5, help='interval (in epochs) between model saves')
parser.add_argument('--gen_lr', type=float, default=0.0002)
parser.add_argument('--D2G_lr', type=float, default=0.1)
parser.add_argument('--lr_finetune', type=float, default=0.00005)
parser.add_argument('--finetune', type=int, default=0)
parser.add_argument('--b1', type=float, default=0.0)
parser.add_argument('--b2', type=float, default=0.9)
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--load_size', type=int, default=350, help='image loading size')
parser.add_argument('--crop_size', type=int, default=256, help='image training size')
parser.add_argument('--start_iter', type=int, default=1, help='starting epoch')
parser.add_argument('--train_epochs', type=int, default=500, help='training epochs')
args = parser.parse_args()


# -----------
# GPU setting
# -----------
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
is_cuda = torch.cuda.is_available()
if is_cuda:
    print('Cuda is available')
    cudnn.enabled = True
    cudnn.benchmark = True

# -----------------------
# samples and checkpoints
# -----------------------
if not os.path.exists(args.save_dir):
    os.makedirs('{:s}/images'.format(args.save_dir))
    os.makedirs('{:s}/ckpt'.format(args.save_dir))

# ----------
# Dataloader
# ----------
load_size = (args.load_size, args.load_size)
crop_size = (args.crop_size, args.crop_size)
image_dataset = ImageDataset(args.image_root, args.mask_root, load_size, crop_size)
data_loader = DataLoader(
    image_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True
)

# -----
# model
# -----
generator = PUNet(in_channels=3, out_channels=3)
extractor = VGG16FeatureExtractor()

# ----------
# load model
# ----------
start_iter = args.start_iter
if args.pre_trained != '':

    ckpt_dict_load = torch.load(args.pre_trained)
    start_iter = ckpt_dict_load['n_iter']
    generator.load_state_dict(ckpt_dict_load['generator'])

print('Starting from iter ', start_iter)

# -----------------------------
# DataParallel and model cuda()
# -----------------------------
num_GPUS = 0
if is_cuda:

    generator = generator.cuda()
    extractor = extractor.cuda()

    num_GPUS = torch.cuda.device_count()
    print('The number of GPU is ', num_GPUS)

    if num_GPUS > 1:
        generator = nn.DataParallel(generator, device_ids=range(num_GPUS))

# ---------
# fine tune
# ---------
if args.finetune != 0:
    print('Fine tune...')
    lr = args.lr_finetune
    # set the flag on the wrapped module, otherwise DataParallel hides it
    (generator.module if num_GPUS > 1 else generator).freeze_ec_bn = True
else:
    lr = args.gen_lr

# ---------
# optimizer
# ---------
G_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(args.b1, args.b2))

writer = SummaryWriter(args.log_dir)

# --------
# Training
# --------
print('Start train...')
count = 0
for epoch in range(start_iter, args.train_epochs + 1):

    generator.train()
    for batch_idx, (input_images, ground_truths, masks) in enumerate(data_loader):

        count += 1
        start_time = time.time()
        if is_cuda:
            input_images, ground_truths, masks = \
                input_images.cuda(), ground_truths.cuda(), masks.cuda()

        # ---------
        # Generator
        # ---------
        G_optimizer.zero_grad()

        outputs = generator(input_images, masks)
        G_loss, G_loss_list = generator_loss(input_images, masks, outputs, ground_truths, extractor)

        writer.add_scalar('G hole loss', G_loss_list[0].item(), count)
        writer.add_scalar('G valid loss', G_loss_list[1].item(), count)
        writer.add_scalar('G perceptual loss', G_loss_list[2].item(), count)
        writer.add_scalar('G style loss', G_loss_list[3].item(), count)
        writer.add_scalar('G tv loss', G_loss_list[4].item(), count)

        writer.add_scalar('G total loss', G_loss.item(), count)

        G_loss.backward()
        G_optimizer.step()

        end_time = time.time()
        sum_time = end_time - start_time

        print('[Epoch %d/%d] [Batch %d/%d] [count %d] [G_loss %f] [time %f]' %
              (epoch, args.train_epochs, batch_idx + 1, len(data_loader), count, G_loss.item(), sum_time))

    # ----------
    # save model
    # ----------
    if epoch % args.save_interval == 0:

        ckpt_dict_save = {'n_iter': epoch + 1}
        if num_GPUS > 1:
            ckpt_dict_save['generator'] = generator.module.state_dict()
        else:
            ckpt_dict_save['generator'] = generator.state_dict()

        torch.save(ckpt_dict_save, '{:s}/ckpt/model_{:d}.pth'.format(args.save_dir, epoch))


writer.close()
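To resume from a checkpoint and fine-tune (lowered learning rate of 5e-5 and frozen encoder BatchNorm, as in the paper's second training stage), something like the following should work; the checkpoint path is a placeholder:

```
python train.py \
    --image_root [path to input image directory] \
    --mask_root [path to masks directory] \
    --finetune 1 \
    --pre_trained snapshots/PairsStreetView/ckpt/model_100.pth
```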
--------------------------------------------------------------------------------
/src/generator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils import weights_init


# ------------
# Partial Conv
# ------------
class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        self.return_mask = True

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                 self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        # number of elements in one sliding window, used for re-normalization
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
            self.weight_maskUpdater.shape[3]

        self.last_size = (None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask=None):

        # recompute the mask ratio only when a mask is given or the input size changed
        if mask is not None or self.last_size != (input.data.shape[2], input.data.shape[3]):
            self.last_size = (input.data.shape[2], input.data.shape[3])

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                          input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)

                # count the valid input pixels under each kernel window
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask) / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
            # .to() is not in-place, so the results must be reassigned
            self.update_mask = self.update_mask.to(input)
            self.mask_ratio = self.mask_ratio.to(input)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask is not None else input)

        if self.bias is not None:
            # re-normalize the output, then re-apply the bias on valid positions only
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output
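# --------------------------------------------------------------------------
# usage sketch (illustration only, not used by the module): a partial conv
# re-normalizes each output position by slide_winsize / (valid pixels in the
# window) and returns the updated mask, so holes shrink layer by layer:
#
#     pconv = PartialConv2d(3, 8, kernel_size=3, stride=1, padding=1, multi_channel=True)
#     images = torch.rand(1, 3, 8, 8)
#     masks = torch.ones(1, 3, 8, 8)
#     masks[:, :, 2:6, 2:6] = 0                      # 4x4 hole
#     out, new_mask = pconv(images * masks, masks)   # hole in new_mask shrinks to 2x2
# --------------------------------------------------------------------------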
# --------------------------
# PConv-BatchNorm-Activation
# --------------------------
class PConvBNActiv(nn.Module):

    def __init__(self, in_channels, out_channels, bn=True, sample='none-3', activ='relu', bias=False):
        super(PConvBNActiv, self).__init__()

        if sample == 'down-7':
            self.conv = PartialConv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=bias, multi_channel=True)
        elif sample == 'down-5':
            self.conv = PartialConv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2, bias=bias, multi_channel=True)
        elif sample == 'down-3':
            self.conv = PartialConv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=bias, multi_channel=True)
        else:
            self.conv = PartialConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias, multi_channel=True)

        if bn:
            self.bn = nn.BatchNorm2d(out_channels)

        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, images, masks):

        images, masks = self.conv(images, masks)
        if hasattr(self, 'bn'):
            images = self.bn(images)
        if hasattr(self, 'activation'):
            images = self.activation(images)

        return images, masks
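# --------------------------------------------------------------------------
# shape sketch (illustration only): 'down-N' blocks halve the spatial size
# with an N x N kernel, e.g.
#
#     block = PConvBNActiv(64, 128, sample='down-5')
#     images, masks = block(torch.rand(1, 64, 128, 128), torch.ones(1, 64, 128, 128))
#     # images and masks are both (1, 128, 64, 64)
# --------------------------------------------------------------------------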
# ------------------------------------------
# U-Net generator with partial convolutions
# ------------------------------------------
class PUNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=3, up_sampling_node='nearest', init_weights=True):
        super(PUNet, self).__init__()

        self.freeze_ec_bn = False
        self.up_sampling_node = up_sampling_node

        self.ec_images_1 = PConvBNActiv(in_channels, 64, bn=False, sample='down-7')
        self.ec_images_2 = PConvBNActiv(64, 128, sample='down-5')
        self.ec_images_3 = PConvBNActiv(128, 256, sample='down-5')
        self.ec_images_4 = PConvBNActiv(256, 512, sample='down-3')
        self.ec_images_5 = PConvBNActiv(512, 512, sample='down-3')
        self.ec_images_6 = PConvBNActiv(512, 512, sample='down-3')
        self.ec_images_7 = PConvBNActiv(512, 512, sample='down-3')

        self.dc_images_7 = PConvBNActiv(512 + 512, 512, activ='leaky')
        self.dc_images_6 = PConvBNActiv(512 + 512, 512, activ='leaky')
        self.dc_images_5 = PConvBNActiv(512 + 512, 512, activ='leaky')
        self.dc_images_4 = PConvBNActiv(512 + 256, 256, activ='leaky')
        self.dc_images_3 = PConvBNActiv(256 + 128, 128, activ='leaky')
        self.dc_images_2 = PConvBNActiv(128 + 64, 64, activ='leaky')
        self.dc_images_1 = PConvBNActiv(64 + out_channels, out_channels, bn=False, sample='none-3', activ=None, bias=True)

        self.tanh = nn.Tanh()

        if init_weights:
            self.apply(weights_init())

    def forward(self, input_images, input_masks):

        ec_images = {}

        # --------------
        # images encoder
        # --------------
        ec_images['ec_images_0'], ec_images['ec_images_masks_0'] = input_images, input_masks
        ec_images['ec_images_1'], ec_images['ec_images_masks_1'] = self.ec_images_1(input_images, input_masks)
        ec_images['ec_images_2'], ec_images['ec_images_masks_2'] = self.ec_images_2(ec_images['ec_images_1'], ec_images['ec_images_masks_1'])
        ec_images['ec_images_3'], ec_images['ec_images_masks_3'] = self.ec_images_3(ec_images['ec_images_2'], ec_images['ec_images_masks_2'])
        ec_images['ec_images_4'], ec_images['ec_images_masks_4'] = self.ec_images_4(ec_images['ec_images_3'], ec_images['ec_images_masks_3'])
        ec_images['ec_images_5'], ec_images['ec_images_masks_5'] = self.ec_images_5(ec_images['ec_images_4'], ec_images['ec_images_masks_4'])
        ec_images['ec_images_6'], ec_images['ec_images_masks_6'] = self.ec_images_6(ec_images['ec_images_5'], ec_images['ec_images_masks_5'])
        ec_images['ec_images_7'], ec_images['ec_images_masks_7'] = self.ec_images_7(ec_images['ec_images_6'], ec_images['ec_images_masks_6'])

        # --------------
        # images decoder
        # --------------
        dc_images, dc_images_masks = ec_images['ec_images_7'], ec_images['ec_images_masks_7']
        for i in range(7, 0, -1):

            ec_images_skip = 'ec_images_{:d}'.format(i - 1)
            ec_images_masks = 'ec_images_masks_{:d}'.format(i - 1)
            dc_conv = 'dc_images_{:d}'.format(i)

            # up-sample, then concatenate the encoder skip connection
            dc_images = F.interpolate(dc_images, scale_factor=2, mode=self.up_sampling_node)
            dc_images_masks = F.interpolate(dc_images_masks, scale_factor=2, mode=self.up_sampling_node)

            dc_images = torch.cat((dc_images, ec_images[ec_images_skip]), dim=1)
            dc_images_masks = torch.cat((dc_images_masks, ec_images[ec_images_masks]), dim=1)

            dc_images, dc_images_masks = getattr(self, dc_conv)(dc_images, dc_images_masks)

        outputs = self.tanh(dc_images)

        return outputs

    def train(self, mode=True):

        super().train(mode)

        # keep encoder BatchNorm layers in eval mode during fine-tuning
        if self.freeze_ec_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'ec' in name:
                    module.eval()

--------------------------------------------------------------------------------
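A quick smoke test sketch (not repo code) for PUNet: seven down/up-sampling stages, so the spatial size should be divisible by 2^7 = 128.

```
import torch
from src.generator import PUNet

net = PUNet(in_channels=3, out_channels=3).eval()
images = torch.rand(1, 3, 256, 256)
masks = torch.ones(1, 3, 256, 256)
masks[:, :, 64:192, 64:192] = 0                 # central square hole
with torch.no_grad():
    outputs = net(images * masks, masks)
print(outputs.shape)                            # torch.Size([1, 3, 256, 256])
```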