├── .gitignore
├── LICENSE
├── README.md
├── data.py
├── loss.py
├── model.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


# Custom
runs/
snapshots/

output.*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Yukari

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CSA_pytorch
PyTorch implementation of Coherent Semantic Attention for Image Inpainting.

    @misc{1905.12384,
        Author = {Hongyu Liu and Bin Jiang and Yi Xiao and Chao Yang},
        Title = {Coherent Semantic Attention for Image Inpainting},
        Year = {2019},
        Eprint = {arXiv:1905.12384},
    }

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils import data


class InfiniteSampler(data.sampler.Sampler):
    """Yields dataset indices forever, reshuffling after each full pass."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return iter(self.loop())

    def __len__(self):
        return 2 ** 31

    def loop(self):
        i = 0
        order = np.random.permutation(self.num_samples)
        while True:
            yield order[i]
            i += 1
            if i >= self.num_samples:
                np.random.seed()
                order = np.random.permutation(self.num_samples)
                i = 0


class DS(data.Dataset):
    def __init__(self, root, transform=None):
        self.samples = []
        for curr_root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(curr_root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        sample = Image.open(sample_path).convert('RGB')

        if self.transform is not None:
            sample = self.transform(sample)

        mask = self.random_mask()
        mask = torch.from_numpy(mask)

        return sample, mask

    @staticmethod
    def random_mask(height=256, width=256, pad=50,
                    min_stroke=2, max_stroke=5,
                    min_vertex=2, max_vertex=12,
                    min_brush_width=7, max_brush_width=20,
                    min_length=10, max_length=50):
        """Draws random free-form brush strokes; value 1 marks the hole region."""
        mask = np.zeros((height, width))

        max_angle = 2 * np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke + 1)

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex + 1)
            brush_width = np.random.randint(min_brush_width, max_brush_width + 1)
            start_x = np.random.randint(pad, height - pad)
            start_y = np.random.randint(pad, width - pad)

            for _ in range(num_vertex):
                # Pick a direction in [0, 2*pi); a single-argument
                # np.random.uniform(max_angle) would set low=max_angle instead.
                angle = np.random.uniform(0, max_angle)
                length = np.random.randint(min_length, max_length + 1)
                #length = np.random.randint(min_length, height // num_vertex)
                end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                end_y = (start_y + length * np.cos(angle)).astype(np.int32)
                end_x = max(0, min(end_x, height))
                end_y = max(0, min(end_y, width))

                cv2.line(mask, (start_x, start_y), (end_x, end_y), 1., brush_width)

                start_x, start_y = end_x, end_y

        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)

        return mask.reshape((1,) + mask.shape).astype(np.float32)

--------------------------------------------------------------------------------
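
A minimal sketch (not part of the repository) for eyeballing the masks that
`DS.random_mask` produces; `mask_preview.png` is a hypothetical output path:

import cv2
import numpy as np

from data import DS

# random_mask is a staticmethod returning a (1, H, W) float32 array,
# with 1.0 marking the hole region.
mask = DS.random_mask(height=256, width=256)
print(mask.shape, mask.dtype, mask.min(), mask.max())  # (1, 256, 256) float32 0.0 1.0

# Drop the channel axis and scale to 8-bit for viewing.
cv2.imwrite('mask_preview.png', (mask[0] * 255).astype(np.uint8))
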
/loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms


def denorm(x):
    out = (x + 1) / 2  # [-1,1] -> [0,1]
    return out.clamp_(0, 1)


class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()

        vgg16 = models.vgg16(pretrained=True)

        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])
        self.enc_4 = nn.Sequential(*vgg16.features[17:23])

        # fix the encoder
        for i in range(4):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]


class ConsistencyLoss(nn.Module):
    def __init__(self):
        super().__init__()

        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.vgg = VGG16FeatureExtractor()
        self.l2 = nn.MSELoss()

    def forward(self, csa, csa_d, target, mask):
        # https://pytorch.org/docs/stable/torchvision/models.html
        # The pre-trained VGG16 model expects input images normalized in the
        # same way as its training data: loaded into a range of [0, 1], then
        # normalized with mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
        t = denorm(target)        # [-1,1] -> [0,1]
        t = self.normalize(t[0])  # BxCxHxW -> CxHxW -> normalize (assumes batch size 1)
        t = t.unsqueeze(0)        # CxHxW -> BxCxHxW

        vgg_gt = self.vgg(t)
        vgg_gt = vgg_gt[-1]

        mask_r = F.interpolate(mask, size=csa.size()[2:])

        lossvalue = self.l2(csa * mask_r, vgg_gt * mask_r) + self.l2(csa_d * mask_r, vgg_gt * mask_r)
        return lossvalue


def calc_gan_loss(discriminator, output, target):
    # Relativistic average LSGAN-style generator/discriminator losses.
    y_pred_fake = discriminator(output, target)
    y_pred = discriminator(target, output)

    g_loss = (torch.mean((y_pred - torch.mean(y_pred_fake) + 1.) ** 2)
              + torch.mean((y_pred_fake - torch.mean(y_pred) - 1.) ** 2)) / 2
    d_loss = (torch.mean((y_pred - torch.mean(y_pred_fake) - 1.) ** 2)
              + torch.mean((y_pred_fake - torch.mean(y_pred) + 1.) ** 2)) / 2

    return g_loss, d_loss

--------------------------------------------------------------------------------
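
A minimal sketch (not part of the repository) of how `train.py` drives
`calc_gan_loss` — once per discriminator, with the refined result and the
real image; random tensors stand in for both here:

import torch

from loss import calc_gan_loss
from model import PatchDiscriminator

# "output" stands in for the generator result, "target" for the real image.
output = torch.randn(1, 3, 256, 256)
target = torch.randn(1, 3, 256, 256)

d = PatchDiscriminator()
g_loss, d_loss = calc_gan_loss(d, output, target)
print(g_loss.item(), d_loss.item())  # two scalars; g_loss trains G, d_loss trains D
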
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F


def get_norm(name, out_channels):
    if name == 'batch':
        norm = nn.BatchNorm2d(out_channels)
    elif name == 'instance':
        norm = nn.InstanceNorm2d(out_channels)
    else:
        norm = None
    return norm


def get_act(name):
    if name == 'relu':
        activation = nn.ReLU(inplace=True)
    elif name == 'elu':
        # was `activation == nn.ELU(...)`, which compared instead of assigning
        activation = nn.ELU(inplace=True)
    elif name == 'leaky_relu':
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    elif name == 'tanh':
        activation = nn.Tanh()
    elif name == 'sigmoid':
        activation = nn.Sigmoid()
    else:
        activation = None
    return activation


class CoarseEncodeBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)


class CoarseDecodeBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)


class CoarseNet(nn.Module):
    """U-Net style coarse network: 8 stride-2 encode stages, 8 decode stages
    with skip connections."""

    def __init__(self, c_img=3,
                 norm='instance', act_en='leaky_relu', act_de='relu'):
        super().__init__()

        cnum = 64

        self.en_1 = nn.Conv2d(c_img, cnum, 4, 2, padding=1)
        self.en_2 = CoarseEncodeBlock(cnum, cnum*2, 4, 2, normalization=norm, activation=act_en)
        self.en_3 = CoarseEncodeBlock(cnum*2, cnum*4, 4, 2, normalization=norm, activation=act_en)
        self.en_4 = CoarseEncodeBlock(cnum*4, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_5 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_6 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_7 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_8 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, activation=act_en)

        self.de_8 = CoarseDecodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_7 = CoarseDecodeBlock(cnum*8*2, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_6 = CoarseDecodeBlock(cnum*8*2, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_5 = CoarseDecodeBlock(cnum*8*2, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_4 = CoarseDecodeBlock(cnum*8*2, cnum*4, 4, 2, normalization=norm, activation=act_de)
        self.de_3 = CoarseDecodeBlock(cnum*4*2, cnum*2, 4, 2, normalization=norm, activation=act_de)
        self.de_2 = CoarseDecodeBlock(cnum*2*2, cnum, 4, 2, normalization=norm, activation=act_de)
        self.de_1 = nn.Sequential(
            get_act(act_de),
            nn.ConvTranspose2d(cnum*2, c_img, 4, 2, padding=1),
            get_act('tanh'))

    def forward(self, x):
        out_1 = self.en_1(x)
        out_2 = self.en_2(out_1)
        out_3 = self.en_3(out_2)
        out_4 = self.en_4(out_3)
        out_5 = self.en_5(out_4)
        out_6 = self.en_6(out_5)
        out_7 = self.en_7(out_6)
        out_8 = self.en_8(out_7)

        dout_8 = self.de_8(out_8)
        dout_8_out_7 = torch.cat([dout_8, out_7], 1)
        dout_7 = self.de_7(dout_8_out_7)
        dout_7_out_6 = torch.cat([dout_7, out_6], 1)
        dout_6 = self.de_6(dout_7_out_6)
        dout_6_out_5 = torch.cat([dout_6, out_5], 1)
        dout_5 = self.de_5(dout_6_out_5)
        dout_5_out_4 = torch.cat([dout_5, out_4], 1)
        dout_4 = self.de_4(dout_5_out_4)
        dout_4_out_3 = torch.cat([dout_4, out_3], 1)
        dout_3 = self.de_3(dout_4_out_3)
        dout_3_out_2 = torch.cat([dout_3, out_2], 1)
        dout_2 = self.de_2(dout_3_out_2)
        dout_2_out_1 = torch.cat([dout_2, out_1], 1)
        dout_1 = self.de_1(dout_2_out_1)

        return dout_1

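# Note on CoarseNet wiring (added commentary, not in the original file):
# a 256x256 input is halved by each of the eight stride-2 encode stages
# (256 -> 128 -> ... -> 2 -> 1); the decoder mirrors the path back up, and
# each decode block receives its input concatenated with the matching
# encoder feature (skip connection), which is why de_7 .. de_1 take doubled
# (*2) input channels.
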

class RefineEncodeBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.Conv2d(in_channels, in_channels, 4, 2, dilation=2, padding=3))
        if normalization:
            # the first conv keeps in_channels, so normalize over in_channels
            layers.append(get_norm(normalization, in_channels))

        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.Conv2d(in_channels, out_channels, 3, 1, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)


class RefineDecodeBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.ConvTranspose2d(in_channels, out_channels, 3, 1, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))

        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.ConvTranspose2d(out_channels, out_channels, 4, 2, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)


class RefineNet(nn.Module):
    """Encoder-decoder refinement network; its input is the coarse result
    concatenated with the original (masked) image."""

    def __init__(self, c_img=3,
                 norm='instance', act_en='leaky_relu', act_de='relu'):
        super().__init__()

        c_in = c_img + c_img
        cnum = 64

        self.en_1 = nn.Conv2d(c_in, cnum, 3, 1, padding=1)
        self.en_2 = RefineEncodeBlock(cnum, cnum*2, normalization=norm, activation=act_en)
        self.en_3 = RefineEncodeBlock(cnum*2, cnum*4, normalization=norm, activation=act_en)
        self.en_4 = RefineEncodeBlock(cnum*4, cnum*8, normalization=norm, activation=act_en)
        self.en_5 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_6 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_7 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_8 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_9 = nn.Sequential(
            get_act(act_en),
            nn.Conv2d(cnum*8, cnum*8, 4, 2, padding=1))

        self.de_9 = nn.Sequential(
            get_act(act_de),
            nn.ConvTranspose2d(cnum*8, cnum*8, 4, 2, padding=1),
            get_norm(norm, cnum*8))
        self.de_8 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_7 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_6 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_5 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_4 = RefineDecodeBlock(cnum*8*2, cnum*4, normalization=norm, activation=act_de)
        self.de_3 = RefineDecodeBlock(cnum*4*2, cnum*2, normalization=norm, activation=act_de)
        self.de_2 = RefineDecodeBlock(cnum*2*2, cnum, normalization=norm, activation=act_de)
        self.de_1 = nn.Sequential(
            get_act(act_de),
            nn.ConvTranspose2d(cnum*2, c_img, 3, 1, padding=1))

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], 1)
        out_1 = self.en_1(x)
        out_2 = self.en_2(out_1)
        out_3 = self.en_3(out_2)
        out_4 = self.en_4(out_3)
        out_5 = self.en_5(out_4)
        out_6 = self.en_6(out_5)
        out_7 = self.en_7(out_6)
        out_8 = self.en_8(out_7)
        out_9 = self.en_9(out_8)

        dout_9 = self.de_9(out_9)
        dout_9_out_8 = torch.cat([dout_9, out_8], 1)
        dout_8 = self.de_8(dout_9_out_8)
        dout_8_out_7 = torch.cat([dout_8, out_7], 1)
        dout_7 = self.de_7(dout_8_out_7)
        dout_7_out_6 = torch.cat([dout_7, out_6], 1)
        dout_6 = self.de_6(dout_7_out_6)
        dout_6_out_5 = torch.cat([dout_6, out_5], 1)
        dout_5 = self.de_5(dout_6_out_5)
        dout_5_out_4 = torch.cat([dout_5, out_4], 1)
        dout_4 = self.de_4(dout_5_out_4)
        dout_4_out_3 = torch.cat([dout_4, out_3], 1)
        dout_3 = self.de_3(dout_4_out_3)
        dout_3_out_2 = torch.cat([dout_3, out_2], 1)
        dout_2 = self.de_2(dout_3_out_2)
        dout_2_out_1 = torch.cat([dout_2, out_1], 1)
        dout_1 = self.de_1(dout_2_out_1)

        # out_4 / dout_5 are returned as the csa / csa_d feature pair
        # consumed by ConsistencyLoss.
        return dout_1, out_4, dout_5


class CSA(nn.Module):
    """Coherent semantic attention layer.

    NOTE: placeholder -- the attention operation itself is not implemented
    here; the layer currently passes its input through unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        return x


class InpaintNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.coarse = CoarseNet()
        self.refine = RefineNet()

    def forward(self, image, mask):
        # mask == 1 marks the hole; known pixels are kept from the input.
        out_c = self.coarse(image)
        out_c = image * (1. - mask) + out_c * mask

        out_r, csa, csa_d = self.refine(out_c, image)
        out_r = image * (1. - mask) + out_r * mask

        return out_c, out_r, csa, csa_d


class PatchDiscriminator(nn.Module):
    def __init__(self, c_img=3,
                 norm='instance', act='leaky_relu'):
        super().__init__()

        c_in = c_img + c_img
        cnum = 64
        self.discriminator = nn.Sequential(
            nn.Conv2d(c_in, cnum, 4, 2, 1),
            get_act(act),

            nn.Conv2d(cnum, cnum*2, 4, 2, 1),
            get_norm(norm, cnum*2),
            get_act(act),

            nn.Conv2d(cnum*2, cnum*4, 4, 2, 1),
            get_norm(norm, cnum*4),
            get_act(act),

            nn.Conv2d(cnum*4, cnum*8, 4, 1, 1),
            get_norm(norm, cnum*8),
            get_act(act),

            nn.Conv2d(cnum*8, 1, 4, 1, 1))

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], 1)
        return self.discriminator(x)


class FeaturePatchDiscriminator(nn.Module):
    def __init__(self, c_img=3,
                 norm='instance', act='leaky_relu'):
        super().__init__()

        c_in = c_img + c_img
        cnum = 64
        self.discriminator = nn.Sequential(
            # VGG-16 layout up to the 3rd pooling (weights trained from scratch)
            nn.Conv2d(c_in, cnum, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(cnum, cnum, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(cnum, cnum*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(cnum*2, cnum*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(cnum*2, cnum*4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(cnum*4, cnum*4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(cnum*4, cnum*4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=2, stride=2),

            # Discriminator head
            nn.Conv2d(cnum*4, cnum*8, 4, 2, 1),
            get_act(act),

            nn.Conv2d(cnum*8, cnum*8, 4, 1, 1),
            get_norm(norm, cnum*8),
            get_act(act),

            nn.Conv2d(cnum*8, cnum*8, 4, 1, 1))

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], 1)
        return self.discriminator(x)

--------------------------------------------------------------------------------
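
A quick shape check (not part of the repository) for the assembled generator,
mirroring the masking convention used in train.py (mask value 1 marks the hole):

import torch

from data import DS
from model import InpaintNet

net = InpaintNet()
img = torch.rand(1, 3, 256, 256) * 2. - 1.               # [-1,1], as in train.py
mask = torch.from_numpy(DS.random_mask()).unsqueeze(0)   # 1x1x256x256, 1 = hole

masked = img * (1. - mask)
out_c, out_r, csa, csa_d = net(masked, mask)
print(out_c.shape, out_r.shape)  # both 1x3x256x256
print(csa.shape, csa_d.shape)    # refinement features at 1/8 resolution: 1x512x32x32
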
/test.py:
--------------------------------------------------------------------------------
import argparse

from PIL import Image
import torch
from torchvision import transforms
from torchvision.utils import save_image

from model import InpaintNet


def norm(x):
    return 2. * x - 1.  # [0,1] -> [-1,1]


def denorm(x):
    out = (x + 1) / 2  # [-1,1] -> [0,1]
    return out.clamp_(0, 1)


parser = argparse.ArgumentParser()
parser.add_argument('--image', type=str,
                    help='The filename of the image to be completed.')
parser.add_argument('--mask', type=str,
                    help='The filename of the mask; value 255 indicates the hole.')
parser.add_argument('--output', default='output.png', type=str,
                    help='Where to write the output.')
parser.add_argument('--checkpoint', type=str,
                    help='The filename of the pickled checkpoint.')


if __name__ == "__main__":
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    g_model = InpaintNet().to(device)
    g_checkpoint = torch.load(args.checkpoint, map_location=device)
    g_model.load_state_dict(g_checkpoint)
    g_model.eval()

    to_tensor = transforms.ToTensor()

    img = Image.open(args.image).convert('RGB')
    mask = Image.open(args.mask).convert('RGB')
    img = to_tensor(img)
    mask = to_tensor(mask)
    _, h, w = img.shape
    grid = 256
    # Crop to a multiple of 256 so the eight stride-2 stages divide evenly.
    img = img[:, :h // grid * grid, :w // grid * grid]
    mask = mask[:, :h // grid * grid, :w // grid * grid]
    img = img.unsqueeze_(0)    # CHW -> BCHW
    mask = mask.unsqueeze_(0)  # CHW -> BCHW
    img = norm(img)            # [0,1] -> [-1,1]
    mask = mask[:, 0:1, :, :]  # Bx3xHxW -> Bx1xHxW
    img = img * (1. - mask)    # zero out the hole region
    img = img.to(device)
    mask = mask.to(device)
    print(img.shape)

    import time
    start_time = time.time()
    with torch.no_grad():
        _, result, _, _ = g_model(img, mask)
    print("Done in %.3f seconds!" % (time.time() - start_time))

    save_image(denorm(result), args.output)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os

import numpy as np
import PIL
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
from torch.utils import data
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from data import DS, InfiniteSampler
from loss import ConsistencyLoss, calc_gan_loss
from model import InpaintNet, FeaturePatchDiscriminator, PatchDiscriminator


parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='./root')
parser.add_argument('--save_dir', type=str, default='./snapshots')
parser.add_argument('--lr', type=float, default=2e-4, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument('--max_iter', type=int, default=1000000)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--n_threads', type=int, default=16)
parser.add_argument('--save_model_interval', type=int, default=10000)
parser.add_argument('--vis_interval', type=int, default=1000)
parser.add_argument('--log_interval', type=int, default=10)
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--resume', type=int)
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    torch.backends.cudnn.benchmark = True

if not os.path.exists(args.save_dir):
    os.makedirs('{:s}/ckpt'.format(args.save_dir))

writer = SummaryWriter()

size = (args.image_size, args.image_size)
train_tf = transforms.Compose([
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
])

train_set = DS(args.root, train_tf)
iterator_train = iter(data.DataLoader(
    train_set,
    batch_size=args.batch_size,
    sampler=InfiniteSampler(len(train_set)),
    num_workers=args.n_threads))
print(len(train_set))

g_model = InpaintNet().to(device)
fd_model = FeaturePatchDiscriminator().to(device)
pd_model = PatchDiscriminator().to(device)
l1 = nn.L1Loss().to(device)
cons = ConsistencyLoss().to(device)

start_iter = 0
g_optimizer = torch.optim.Adam(
    g_model.parameters(),
    args.lr, (args.b1, args.b2))
fd_optimizer = torch.optim.Adam(
    fd_model.parameters(),
    args.lr, (args.b1, args.b2))
pd_optimizer = torch.optim.Adam(
    pd_model.parameters(),
    args.lr, (args.b1, args.b2))

if args.resume:
    g_checkpoint = torch.load(f'{args.save_dir}/ckpt/G_{args.resume}.pth', map_location=device)
    g_model.load_state_dict(g_checkpoint)
    fd_checkpoint = torch.load(f'{args.save_dir}/ckpt/FD_{args.resume}.pth', map_location=device)
    fd_model.load_state_dict(fd_checkpoint)
    pd_checkpoint = torch.load(f'{args.save_dir}/ckpt/PD_{args.resume}.pth', map_location=device)
    pd_model.load_state_dict(pd_checkpoint)
    # continue counting from the restored iteration instead of restarting at 0
    start_iter = args.resume
    print('Models restored')


def denorm(x):
    out = (x + 1) / 2  # [-1,1] -> [0,1]
    return out.clamp_(0, 1)


for i in tqdm(range(start_iter, args.max_iter)):
    img, mask = [x.to(device) for x in next(iterator_train)]
    img = 2. * img - 1.  # [0,1] -> [-1,1]
    masked = img * (1. - mask)

    coarse_result, refine_result, csa, csa_d = g_model(masked, mask)

    fg_loss, fd_loss = calc_gan_loss(fd_model, refine_result, img)
    pg_loss, pd_loss = calc_gan_loss(pd_model, refine_result, img)

    recon_loss = l1(coarse_result, img) + l1(refine_result, img)
    gan_loss = fg_loss + pg_loss
    cons_loss = cons(csa, csa_d, img, mask)
    total_loss = 1*recon_loss + 0.01*cons_loss + 0.002*gan_loss

    g_optimizer.zero_grad()
    total_loss.backward(retain_graph=True)
    g_optimizer.step()

    fd_optimizer.zero_grad()
    fd_loss.backward(retain_graph=True)
    fd_optimizer.step()

    pd_optimizer.zero_grad()
    pd_loss.backward()
    pd_optimizer.step()

    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        # save under the current iteration number (the old hardcoded
        # *_10000.pth names overwrote the same files every interval)
        torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_{i + 1}.pth')
        torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_{i + 1}.pth')
        torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_{i + 1}.pth')

    if (i + 1) % args.log_interval == 0:
        writer.add_scalar('g_loss/recon_loss', recon_loss.item(), i + 1)
        writer.add_scalar('g_loss/cons_loss', cons_loss.item(), i + 1)
        writer.add_scalar('g_loss/gan_loss', gan_loss.item(), i + 1)
        writer.add_scalar('g_loss/total_loss', total_loss.item(), i + 1)
        writer.add_scalar('d_loss/fd_loss', fd_loss.item(), i + 1)
        writer.add_scalar('d_loss/pd_loss', pd_loss.item(), i + 1)

    if (i + 1) % args.vis_interval == 0:
        ims = torch.cat([img, masked, refine_result], dim=3)
        writer.add_images('raw_masked_refine', denorm(ims), i + 1)

writer.close()

--------------------------------------------------------------------------------
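
Typical invocations (added for reference; paths and iteration counts are
hypothetical — train.py writes checkpoints under `<save_dir>/ckpt/G_<iter>.pth`,
and ConsistencyLoss indexes `t[0]`, so `--batch_size 1` is the safe setting):

python train.py --root ./root --batch_size 1
python train.py --root ./root --resume 10000   # resume from iteration 10000
python test.py --image input.png --mask mask.png --checkpoint snapshots/ckpt/G_100000.pth --output output.png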