├── .gitignore
├── README.md
├── img
│   ├── diagram.png
│   ├── example1.png
│   ├── example2.png
│   └── test_imgs
│       ├── .DS_Store
│       ├── bikes.jpg
│       ├── park.jpeg
│       └── test1.jpeg
├── object_remove.pdf
├── src
│   ├── main.py
│   ├── models
│   │   └── deepFill.py
│   └── objRemove.py
└── test_imgs
    ├── .DS_Store
    ├── bikes.jpg
    ├── park.jpeg
    └── test1.jpeg

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # object-remove
2 | 
3 | A system for removing objects from images using deep learning segmentation and inpainting techniques.
4 | 
5 | ## Contents
6 | 1. [Overview](#overview)
7 | 2. [Source Code](src/)
8 | 3. [Report](object_remove.pdf)
9 | 4. [Results](#results)
10 | 5. [Dependencies](#dependencies)
11 | 
12 | ## Overview
13 | Removing an object from an image involves two separate tasks: object detection and image inpainting.
14 | 
15 | The first task is handled by the user drawing a bounding box around the object of interest to be removed. We could then remove all pixels inside the bounding box, but that could discard valuable information from pixels in the box that are not part of the object. Instead, Mask R-CNN, a state-of-the-art instance segmentation model, is used to obtain the exact mask of the object.
16 | 
17 | Filling in the image is done using DeepFillv2, an image inpainting generative adversarial network that employs gated convolutions.
18 | 
19 | The result is a complete image with the object removed.
20 | 
21 | 
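To make the mask-selection step concrete: Mask R-CNN returns a bounding box and a soft mask for every detected instance, and the instance whose predicted box overlaps the user's box the most is kept. A minimal sketch of that selection logic, simplified from the version in src/objRemove.py (the function names here are illustrative):

```
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-5)

def pick_object_mask(rcnn_output, user_box, threshold=0.1):
    """Keep the instance whose box best matches the user's box; return its binary mask."""
    boxes = rcnn_output['boxes'].detach().numpy()   # (N, 4) predicted boxes
    masks = rcnn_output['masks']                    # (N, 1, H, W) soft masks in [0, 1]
    best = max(range(len(boxes)), key=lambda i: box_iou(user_box, boxes[i]))
    return (masks[best] > threshold).float()        # 1 on the object, 0 elsewhere
```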

22 | <img src="img/diagram.png">
23 | 
24 | 
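For reference, src/main.py wires the two models together roughly as follows. This is a condensed sketch: it assumes it is run from the src/ directory, and the weights filename is a placeholder for whatever .pth file you place in src/models/ (see Usage below):

```
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

from models.deepFill import Generator
from objRemove import ObjectRemove

weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
rcnn = maskrcnn_resnet50_fpn(weights=weights, progress=False).eval()
deepfill = Generator(checkpoint='models/deepfill_weights.pth', return_flow=True)  # placeholder filename

remover = ObjectRemove(segmentModel=rcnn,
                       rcnn_transforms=weights.transforms(),
                       inpaintModel=deepfill,
                       image_path='../test_imgs/bikes.jpg')
result = remover.run()  # draw a box in the OpenCV window, press 'c', and get the inpainted image
```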

25 | 
26 | ## Usage
27 | 
28 | The DeepFillv2 model needs pretrained weights from [here](https://drive.google.com/u/0/uc?id=1L63oBNVgz7xSb_3hGbUdkYW1IuRgMkCa&export=download), provided by [this](https://github.com/nipponjo/deepfillv2-pytorch) repository, which is a reimplementation of DeepFillv2 in PyTorch. The DeepFillv2 model code was borrowed and slightly modified from there.
29 | 
30 | 
31 | 
32 | Make sure to put the weights .pth file in [src/models/](/src/models/). If you only need the inpainting step, a minimal sketch of calling the DeepFillv2 model directly is included below the Results section.
33 | 
34 | To run on an example image:
35 | ```
36 | ./src/main.py [path of image]
37 | ```
38 | When drawing a bounding box, press 'r' to clear the box and reset the image. Once the box is drawn, press 'c' to continue.
39 | 
40 | *Drawing bounding boxes is sometimes slow.
41 | 
42 | 
43 | ## Results
44 | The following are some results of the system. The user-selected bounding box is shown along with the masked image and the inpainted final result.
45 | 
46 | 

47 | <img src="img/example1.png">
48 | 
49 | 

50 |

51 | <img src="img/example2.png">
52 | 
53 | 
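As noted in the Usage section, the inpainting step can also be invoked on its own through the Generator in src/models/deepFill.py. A minimal sketch, assuming a float image tensor in [0, 1], a binary mask marking the pixels to fill, and a placeholder weights filename; run from the src/ directory:

```
import torch
from torchvision.io import read_image

from models.deepFill import Generator

deepfill = Generator(checkpoint='models/deepfill_weights.pth', return_flow=False)  # placeholder filename

image = read_image('../test_imgs/park.jpeg').to(torch.float32) / 255.0  # (3, H, W) in [0, 1]
mask = torch.zeros(1, *image.shape[1:])
mask[:, 100:300, 200:400] = 1.0      # example region; 1 marks the pixels to be filled in
inpainted = deepfill.infer(image, mask, return_vals=['inpainted'])[0]   # H x W x 3 uint8 array
```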

54 | 
55 | ## Dependencies
56 | - python3
57 | - torch
58 | - torchvision
59 | - cv2
60 | - matplotlib
61 | - numpy
62 | 
63 | 
64 | 
--------------------------------------------------------------------------------
/img/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/diagram.png
--------------------------------------------------------------------------------
/img/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/example1.png
--------------------------------------------------------------------------------
/img/example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/example2.png
--------------------------------------------------------------------------------
/img/test_imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/.DS_Store
--------------------------------------------------------------------------------
/img/test_imgs/bikes.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/bikes.jpg
--------------------------------------------------------------------------------
/img/test_imgs/park.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/park.jpeg
--------------------------------------------------------------------------------
/img/test_imgs/test1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/test1.jpeg
--------------------------------------------------------------------------------
/object_remove.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/object_remove.pdf
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | import os
6 | from objRemove import ObjectRemove
7 | from models.deepFill import Generator
8 | from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
9 | 
10 | ##################################
11 | #get image path from command line#
12 | ##################################
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("image")
15 | args = parser.parse_args()
16 | image_path = args.image
17 | 
18 | ######################################################
19 | #creating Mask-RCNN model and load pretrained weights#
20 | ######################################################
21 | for f in os.listdir('src/models'):
22 |     if f.endswith('.pth'):
23 |         deepfill_weights_path = os.path.join('src/models', f)
24 | print("Creating rcnn model")
25 | weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
26 | transforms = weights.transforms()
27 | rcnn = maskrcnn_resnet50_fpn(weights=weights, progress=False)
28 | rcnn = rcnn.eval()
29 | 
30 | #########################
31 | #create inpainting model#
32 | #########################
33 | print('Creating deepfill model')
34 | deepfill = Generator(checkpoint=deepfill_weights_path, return_flow=True)
35 | ######################
36 | #create ObjectRemoval#
37 | ######################
38 | model = ObjectRemove(segmentModel=rcnn,
39 |                      rcnn_transforms=transforms,
40 |                      inpaintModel=deepfill,
41 |                      image_path=image_path )
42 | #####
43 | #run#
44 | #####
45 | output = model.run()
46 | 
47 | #################
48 | #display results#
49 | #################
50 | img = cv2.cvtColor(model.image_orig[0].permute(1,2,0).numpy(),cv2.COLOR_RGB2BGR)
51 | boxed = cv2.rectangle(img, (model.box[0], model.box[1]),(model.box[2], model.box[3]), (0,255,0),2)
52 | boxed = cv2.cvtColor(boxed,cv2.COLOR_BGR2RGB)
53 | 
54 | fig,axs = plt.subplots(1,3,layout='constrained')
55 | axs[0].imshow(boxed)
56 | axs[0].set_title('Original Image Bounding Box')
57 | axs[1].imshow(model.image_masked.permute(1,2,0).detach().numpy())
58 | axs[1].set_title('Masked Image')
59 | axs[2].imshow(output)
60 | axs[2].set_title('Inpainted Image')
61 | plt.show()
62 | 
63 | 
--------------------------------------------------------------------------------
/src/models/deepFill.py:
--------------------------------------------------------------------------------
1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.parametrizations import spectral_norm 6 | 7 | 8 | #code from https://github.com/nipponjo/deepfillv2-pytorch/blob/master/model/networks.py 9 | #with slight modifications 10 | 11 | # ---------------------------------------------------------------------------- 12 | 13 | def _init_conv_layer(conv, activation, mode='fan_out'): 14 | if isinstance(activation, nn.LeakyReLU): 15 | torch.nn.init.kaiming_uniform_(conv.weight, 16 | a=activation.negative_slope, 17 | nonlinearity='leaky_relu', 18 | mode=mode) 19 | elif isinstance(activation, (nn.ReLU, nn.ELU)): 20 | torch.nn.init.kaiming_uniform_(conv.weight, 21 | nonlinearity='relu', 22 | mode=mode) 23 | else: 24 | pass 25 | if conv.bias != None: 26 | torch.nn.init.zeros_(conv.bias) 27 | 28 | 29 | def output_to_image(out): 30 | out = (out[0].cpu().permute(1, 2, 0) + 1.) 
* 127.5 31 | out = out.to(torch.uint8).numpy() 32 | return out 33 | 34 | # ---------------------------------------------------------------------------- 35 | 36 | ################################# 37 | ########### GENERATOR ########### 38 | ################################# 39 | 40 | class GConv(nn.Module): 41 | """Implements the gated 2D convolution introduced in 42 | `Free-Form Image Inpainting with Gated Convolution`(Yu et al., 2019) 43 | """ 44 | 45 | def __init__(self, cnum_in, 46 | cnum_out, 47 | ksize, 48 | stride=1, 49 | padding='auto', 50 | rate=1, 51 | activation=nn.ELU(), 52 | bias=True 53 | ): 54 | 55 | super().__init__() 56 | 57 | padding = rate*(ksize-1)//2 if padding == 'auto' else padding 58 | self.activation = activation 59 | self.cnum_out = cnum_out 60 | num_conv_out = cnum_out if self.cnum_out == 3 or self.activation is None else 2*cnum_out 61 | self.conv = nn.Conv2d(cnum_in, 62 | num_conv_out, 63 | kernel_size=ksize, 64 | stride=stride, 65 | padding=padding, 66 | dilation=rate, 67 | bias=bias) 68 | 69 | _init_conv_layer(self.conv, activation=self.activation) 70 | 71 | self.ksize = ksize 72 | self.stride = stride 73 | self.rate = rate 74 | self.padding = padding 75 | 76 | def forward(self, x): 77 | x = self.conv(x) 78 | if self.cnum_out == 3 or self.activation is None: 79 | return x 80 | x, y = torch.split(x, self.cnum_out, dim=1) 81 | x = self.activation(x) 82 | y = torch.sigmoid(y) 83 | x = x * y 84 | return x 85 | 86 | # ---------------------------------------------------------------------------- 87 | 88 | class GDeConv(nn.Module): 89 | """Upsampling followed by convolution""" 90 | 91 | def __init__(self, cnum_in, 92 | cnum_out, 93 | padding=1): 94 | super().__init__() 95 | self.conv = GConv(cnum_in, cnum_out, 3, 1, 96 | padding=padding) 97 | 98 | def forward(self, x): 99 | x = F.interpolate(x, scale_factor=2, mode='nearest', 100 | recompute_scale_factor=False) 101 | x = self.conv(x) 102 | return x 103 | 104 | # ---------------------------------------------------------------------------- 105 | 106 | class GDownsamplingBlock(nn.Module): 107 | def __init__(self, cnum_in, 108 | cnum_out, 109 | cnum_hidden=None 110 | ): 111 | super().__init__() 112 | cnum_hidden = cnum_out if cnum_hidden == None else cnum_hidden 113 | self.conv1_downsample = GConv(cnum_in, cnum_hidden, 3, 2) 114 | self.conv2 = GConv(cnum_hidden, cnum_out, 3, 1) 115 | 116 | def forward(self, x): 117 | x = self.conv1_downsample(x) 118 | x = self.conv2(x) 119 | return x 120 | 121 | # ---------------------------------------------------------------------------- 122 | 123 | class GUpsamplingBlock(nn.Module): 124 | def __init__(self, cnum_in, 125 | cnum_out, 126 | cnum_hidden=None 127 | ): 128 | super().__init__() 129 | cnum_hidden = cnum_out if cnum_hidden == None else cnum_hidden 130 | self.conv1_upsample = GDeConv(cnum_in, cnum_hidden) 131 | self.conv2 = GConv(cnum_hidden, cnum_out, 3, 1) 132 | 133 | def forward(self, x): 134 | x = self.conv1_upsample(x) 135 | x = self.conv2(x) 136 | return x 137 | 138 | # ---------------------------------------------------------------------------- 139 | 140 | 141 | class CoarseGenerator(nn.Module): 142 | def __init__(self, cnum_in, cnum): 143 | super().__init__() 144 | self.conv1 = GConv(cnum_in, cnum//2, 5, 1, padding=2) 145 | 146 | # downsampling 147 | self.down_block1 = GDownsamplingBlock(cnum//2, cnum) 148 | self.down_block2 = GDownsamplingBlock(cnum, 2*cnum) 149 | 150 | # bottleneck 151 | self.conv_bn1 = GConv(2*cnum, 2*cnum, 3, 1) 152 | self.conv_bn2 = GConv(2*cnum, 
2*cnum, 3, rate=2, padding=2) 153 | self.conv_bn3 = GConv(2*cnum, 2*cnum, 3, rate=4, padding=4) 154 | self.conv_bn4 = GConv(2*cnum, 2*cnum, 3, rate=8, padding=8) 155 | self.conv_bn5 = GConv(2*cnum, 2*cnum, 3, rate=16, padding=16) 156 | self.conv_bn6 = GConv(2*cnum, 2*cnum, 3, 1) 157 | self.conv_bn7 = GConv(2*cnum, 2*cnum, 3, 1) 158 | 159 | # upsampling 160 | self.up_block1 = GUpsamplingBlock(2*cnum, cnum) 161 | self.up_block2 = GUpsamplingBlock(cnum, cnum//4, cnum_hidden=cnum//2) 162 | 163 | # to RGB 164 | self.conv_to_rgb = GConv(cnum//4, 3, 3, 1, activation=None) 165 | self.tanh = nn.Tanh() 166 | 167 | def forward(self, x): 168 | x = self.conv1(x) 169 | 170 | # downsampling 171 | x = self.down_block1(x) 172 | x = self.down_block2(x) 173 | 174 | # bottleneck 175 | x = self.conv_bn1(x) 176 | x = self.conv_bn2(x) 177 | x = self.conv_bn3(x) 178 | x = self.conv_bn4(x) 179 | x = self.conv_bn5(x) 180 | x = self.conv_bn6(x) 181 | x = self.conv_bn7(x) 182 | 183 | # upsampling 184 | x = self.up_block1(x) 185 | x = self.up_block2(x) 186 | 187 | # to RGB 188 | x = self.conv_to_rgb(x) 189 | x = self.tanh(x) 190 | return x 191 | 192 | # ---------------------------------------------------------------------------- 193 | 194 | class FineGenerator(nn.Module): 195 | def __init__(self, cnum, return_flow=False): 196 | super().__init__() 197 | 198 | ### CONV BRANCH (B1) ### 199 | self.conv_conv1 = GConv(3, cnum//2, 5, 1, padding=2) 200 | 201 | # downsampling 202 | self.conv_down_block1 = GDownsamplingBlock( 203 | cnum//2, cnum, cnum_hidden=cnum//2) 204 | self.conv_down_block2 = GDownsamplingBlock( 205 | cnum, 2*cnum, cnum_hidden=cnum) 206 | 207 | # bottleneck 208 | self.conv_conv_bn1 = GConv(2*cnum, 2*cnum, 3, 1) 209 | self.conv_conv_bn2 = GConv(2*cnum, 2*cnum, 3, rate=2, padding=2) 210 | self.conv_conv_bn3 = GConv(2*cnum, 2*cnum, 3, rate=4, padding=4) 211 | self.conv_conv_bn4 = GConv(2*cnum, 2*cnum, 3, rate=8, padding=8) 212 | self.conv_conv_bn5 = GConv(2*cnum, 2*cnum, 3, rate=16, padding=16) 213 | 214 | ### ATTENTION BRANCH (B2) ### 215 | self.ca_conv1 = GConv(3, cnum//2, 5, 1, padding=2) 216 | 217 | # downsampling 218 | self.ca_down_block1 = GDownsamplingBlock( 219 | cnum//2, cnum, cnum_hidden=cnum//2) 220 | self.ca_down_block2 = GDownsamplingBlock(cnum, 2*cnum) 221 | 222 | # bottleneck 223 | self.ca_conv_bn1 = GConv(2*cnum, 2*cnum, 3, 1, activation=nn.ReLU()) 224 | self.contextual_attention = ContextualAttention(ksize=3, 225 | stride=1, 226 | rate=2, 227 | fuse_k=3, 228 | softmax_scale=10, 229 | fuse=True, 230 | device_ids=None, 231 | return_flow=return_flow, 232 | n_down=2) 233 | self.ca_conv_bn4 = GConv(2*cnum, 2*cnum, 3, 1) 234 | self.ca_conv_bn5 = GConv(2*cnum, 2*cnum, 3, 1) 235 | 236 | ### UNITED BRANCHES ### 237 | self.conv_bn6 = GConv(4*cnum, 2*cnum, 3, 1) 238 | self.conv_bn7 = GConv(2*cnum, 2*cnum, 3, 1) 239 | 240 | # upsampling 241 | self.up_block1 = GUpsamplingBlock(2*cnum, cnum) 242 | self.up_block2 = GUpsamplingBlock(cnum, cnum//4, cnum_hidden=cnum//2) 243 | 244 | # to RGB 245 | self.conv_to_rgb = GConv(cnum//4, 3, 3, 1, activation=None) 246 | self.tanh = nn.Tanh() 247 | 248 | def forward(self, x, mask): 249 | xnow = x 250 | 251 | ### CONV BRANCH ### 252 | x = self.conv_conv1(xnow) 253 | # downsampling 254 | x = self.conv_down_block1(x) 255 | x = self.conv_down_block2(x) 256 | 257 | # bottleneck 258 | x = self.conv_conv_bn1(x) 259 | x = self.conv_conv_bn2(x) 260 | x = self.conv_conv_bn3(x) 261 | x = self.conv_conv_bn4(x) 262 | x = self.conv_conv_bn5(x) 263 | x_hallu = x 264 | 265 | ### 
ATTENTION BRANCH ### 266 | x = self.ca_conv1(xnow) 267 | # downsampling 268 | x = self.ca_down_block1(x) 269 | x = self.ca_down_block2(x) 270 | 271 | # bottleneck 272 | x = self.ca_conv_bn1(x) 273 | x, offset_flow = self.contextual_attention(x, x, mask) 274 | x = self.ca_conv_bn4(x) 275 | x = self.ca_conv_bn5(x) 276 | pm = x 277 | 278 | # concatenate outputs from both branches 279 | x = torch.cat([x_hallu, pm], dim=1) 280 | 281 | ### UNITED BRANCHES ### 282 | x = self.conv_bn6(x) 283 | x = self.conv_bn7(x) 284 | 285 | # upsampling 286 | x = self.up_block1(x) 287 | x = self.up_block2(x) 288 | 289 | # to RGB 290 | x = self.conv_to_rgb(x) 291 | x = self.tanh(x) 292 | 293 | return x, offset_flow 294 | 295 | # ---------------------------------------------------------------------------- 296 | 297 | class Generator(nn.Module): 298 | def __init__(self, cnum_in=5, cnum=48, return_flow=False, checkpoint=None): 299 | super().__init__() 300 | self.stage1 = CoarseGenerator(cnum_in, cnum) 301 | self.stage2 = FineGenerator(cnum, return_flow) 302 | self.return_flow = return_flow 303 | 304 | if checkpoint is not None: 305 | generator_state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['G'] 306 | self.load_state_dict(generator_state_dict, strict=True) 307 | 308 | self.eval() 309 | 310 | def forward(self, x, mask): 311 | xin = x 312 | # get coarse result 313 | x_stage1 = self.stage1(x) 314 | # inpaint input with coarse result 315 | x = x_stage1*mask + xin[:, 0:3, :, :]*(1.-mask) 316 | # get refined result 317 | x_stage2, offset_flow = self.stage2(x, mask) 318 | 319 | if self.return_flow: 320 | return x_stage1, x_stage2, offset_flow 321 | 322 | return x_stage1, x_stage2 323 | 324 | @torch.inference_mode() 325 | def infer(self, 326 | image, 327 | mask, 328 | return_vals=['inpainted', 'stage1'], 329 | device='cuda'): 330 | """ 331 | Args: 332 | image: 333 | mask: 334 | return_vals: inpainted, stage1, stage2, flow 335 | Returns: 336 | """ 337 | 338 | _, h, w = image.shape 339 | grid = 8 340 | 341 | image = image[:3, :h//grid*grid, :w//grid*grid].unsqueeze(0) 342 | mask = mask[0:1, :h//grid*grid, :w//grid*grid].unsqueeze(0) 343 | 344 | image = (image*2 - 1.) 
# map image values to [-1, 1] range 345 | # 1.: masked 0.: unmasked 346 | mask = (mask > 0.).to(dtype=torch.float32) 347 | 348 | image_masked = image * (1.-mask) # mask image 349 | 350 | ones_x = torch.ones_like(image_masked)[:, 0:1, :, :] # sketch channel 351 | x = torch.cat([image_masked, ones_x, ones_x*mask], 352 | dim=1) # concatenate channels 353 | 354 | if self.return_flow: 355 | x_stage1, x_stage2, offset_flow = self.forward(x, mask) 356 | else: 357 | x_stage1, x_stage2 = self.forward(x, mask) 358 | 359 | image_compl = image * (1.-mask) + x_stage2 * mask 360 | 361 | output = [] 362 | for return_val in return_vals: 363 | if return_val.lower() == 'stage1': 364 | output.append(output_to_image(x_stage1)) 365 | elif return_val.lower() == 'stage2': 366 | output.append(output_to_image(x_stage2)) 367 | elif return_val.lower() == 'inpainted': 368 | output.append(output_to_image(image_compl)) 369 | elif return_val.lower() == 'flow' and self.return_flow: 370 | output.append(offset_flow) 371 | else: 372 | print(f'Invalid return value: {return_val}') 373 | 374 | return output 375 | 376 | # ---------------------------------------------------------------------------- 377 | 378 | #################################### 379 | ####### CONTEXTUAL ATTENTION ####### 380 | #################################### 381 | 382 | """ 383 | adapted from: https://github.com/daa233/generative-inpainting-pytorch/blob/master/model/networks.py 384 | """ 385 | 386 | class ContextualAttention(nn.Module): 387 | """ Contextual attention layer implementation. \\ 388 | Contextual attention is first introduced in publication: \\ 389 | `Generative Image Inpainting with Contextual Attention`, Yu et al \\ 390 | Args: 391 | ksize: Kernel size for contextual attention 392 | stride: Stride for extracting patches from b 393 | rate: Dilation for matching 394 | softmax_scale: Scaled softmax for attention 395 | """ 396 | 397 | def __init__(self, 398 | ksize=3, 399 | stride=1, 400 | rate=1, 401 | fuse_k=3, 402 | softmax_scale=10., 403 | n_down=2, 404 | fuse=False, 405 | return_flow=False, 406 | device_ids=None): 407 | super(ContextualAttention, self).__init__() 408 | self.ksize = ksize 409 | self.stride = stride 410 | self.rate = rate 411 | self.fuse_k = fuse_k 412 | self.softmax_scale = softmax_scale 413 | self.fuse = fuse 414 | self.device_ids = device_ids 415 | self.n_down = n_down 416 | self.return_flow = return_flow 417 | self.register_buffer('fuse_weight', torch.eye( 418 | fuse_k).view(1, 1, fuse_k, fuse_k)) 419 | 420 | def forward(self, f, b, mask=None): 421 | """ 422 | Args: 423 | f: Input feature to match (foreground). 424 | b: Input feature for match (background). 425 | mask: Input mask for b, indicating patches not available. 426 | """ 427 | device = f.device 428 | # get shapes 429 | raw_int_fs, raw_int_bs = list(f.size()), list(b.size()) # b*c*h*w 430 | 431 | # extract patches from background with stride and rate 432 | kernel = 2 * self.rate 433 | # raw_w is extracted for reconstruction 434 | raw_w = extract_image_patches(b, ksize=kernel, 435 | stride=self.rate*self.stride, 436 | rate=1, padding='auto') # [N, C*k*k, L] 437 | # raw_shape: [N, C, k, k, L] 438 | raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1) 439 | raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k] 440 | raw_w_groups = torch.split(raw_w, 1, dim=0) 441 | 442 | # downscaling foreground option: downscaling both foreground and 443 | # background for matching and use original background for reconstruction. 
444 | f = F.interpolate(f, scale_factor=1./self.rate, 445 | mode='nearest', recompute_scale_factor=False) 446 | b = F.interpolate(b, scale_factor=1./self.rate, 447 | mode='nearest', recompute_scale_factor=False) 448 | int_fs, int_bs = list(f.size()), list(b.size()) # b*c*h*w 449 | # split tensors along the batch dimension 450 | f_groups = torch.split(f, 1, dim=0) 451 | # w shape: [N, C*k*k, L] 452 | w = extract_image_patches(b, ksize=self.ksize, 453 | stride=self.stride, 454 | rate=1, padding='auto') 455 | # w shape: [N, C, k, k, L] 456 | w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1) 457 | w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k] 458 | w_groups = torch.split(w, 1, dim=0) 459 | 460 | # process mask 461 | if mask is None: 462 | mask = torch.zeros( 463 | [int_bs[0], 1, int_bs[2], int_bs[3]], device=device) 464 | else: 465 | mask = F.interpolate( 466 | mask, scale_factor=1./((2**self.n_down)*self.rate), mode='nearest', recompute_scale_factor=False) 467 | int_ms = list(mask.size()) 468 | # m shape: [N, C*k*k, L] 469 | m = extract_image_patches(mask, ksize=self.ksize, 470 | stride=self.stride, 471 | rate=1, padding='auto') 472 | # m shape: [N, C, k, k, L] 473 | m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1) 474 | m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k] 475 | m = m[0] # m shape: [L, C, k, k] 476 | # mm shape: [L, 1, 1, 1] 477 | 478 | mm = (torch.mean(m, dim=[1, 2, 3], keepdim=True) == 0.).to( 479 | torch.float32) 480 | mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1] 481 | 482 | y = [] 483 | offsets = [] 484 | scale = self.softmax_scale # to fit the PyTorch tensor image value range 485 | 486 | for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups): 487 | ''' 488 | O => output channel as a conv filter 489 | I => input channel as a conv filter 490 | xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32) 491 | wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3) 492 | raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4) 493 | ''' 494 | # conv for compare 495 | wi = wi[0] # [L, C, k, k] 496 | max_wi = torch.sqrt(torch.sum(torch.square(wi), dim=[ 497 | 1, 2, 3], keepdim=True)).clamp_min(1e-4) 498 | wi_normed = wi / max_wi 499 | # xi shape: [1, C, H, W], yi shape: [1, L, H, W] 500 | yi = F.conv2d(xi, wi_normed, stride=1, padding=( 501 | self.ksize-1)//2) # [1, L, H, W] 502 | # conv implementation for fuse scores to encourage large patches 503 | if self.fuse: 504 | # make all of depth to spatial resolution 505 | # (B=1, I=1, H=32*32, W=32*32) 506 | yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3]) 507 | # (B=1, C=1, H=32*32, W=32*32) 508 | yi = F.conv2d(yi, self.fuse_weight, stride=1, 509 | padding=(self.fuse_k-1)//2) 510 | # (B=1, 32, 32, 32, 32) 511 | yi = yi.contiguous().view( 512 | 1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]) 513 | yi = yi.permute(0, 2, 1, 4, 3) 514 | 515 | yi = yi.contiguous().view( 516 | 1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3]) 517 | yi = F.conv2d(yi, self.fuse_weight, stride=1, 518 | padding=(self.fuse_k-1)//2) 519 | yi = yi.contiguous().view( 520 | 1, int_bs[3], int_bs[2], int_fs[3], int_fs[2]) 521 | yi = yi.permute(0, 2, 1, 4, 3).contiguous() 522 | 523 | # (B=1, C=32*32, H=32, W=32) 524 | yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]) 525 | # softmax to match 526 | yi = yi * mm 527 | yi = F.softmax(yi*scale, dim=1) 528 | yi = yi * mm # [1, L, H, W] 529 | 530 | if 
self.return_flow: 531 | offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W 532 | 533 | if int_bs != int_fs: 534 | # Normalize the offset value to match foreground dimension 535 | times = (int_fs[2]*int_fs[3])/(int_bs[2]*int_bs[3]) 536 | offset = ((offset + 1).float() * times - 1).to(torch.int64) 537 | offset = torch.cat([torch.div(offset, int_fs[3], rounding_mode='trunc'), 538 | offset % int_fs[3]], dim=1) # 1*2*H*W 539 | offsets.append(offset) 540 | 541 | # deconv for patch pasting 542 | wi_center = raw_wi[0] 543 | yi = F.conv_transpose2d( 544 | yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64) 545 | y.append(yi) 546 | 547 | y = torch.cat(y, dim=0) # back to the mini-batch 548 | y = y.contiguous().view(raw_int_fs) 549 | 550 | if not self.return_flow: 551 | return y, None 552 | 553 | offsets = torch.cat(offsets, dim=0) 554 | offsets = offsets.view(int_fs[0], 2, *int_fs[2:]) 555 | 556 | # case1: visualize optical flow: minus current position 557 | h_add = torch.arange(int_fs[2], device=device).view( 558 | [1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3]) 559 | w_add = torch.arange(int_fs[3], device=device).view( 560 | [1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1) 561 | offsets = offsets - torch.cat([h_add, w_add], dim=1) 562 | # to flow image 563 | flow = torch.from_numpy(flow_to_image( 564 | offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255. 565 | flow = flow.permute(0, 3, 1, 2) 566 | # case2: visualize which pixels are attended 567 | # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy())) 568 | 569 | if self.rate != 1: 570 | flow = F.interpolate(flow, scale_factor=self.rate, 571 | mode='bilinear', align_corners=True) 572 | 573 | return y, flow 574 | 575 | # ---------------------------------------------------------------------------- 576 | 577 | def flow_to_image(flow): 578 | """Transfer flow map to image. 579 | Part of code forked from flownet. 580 | """ 581 | out = [] 582 | maxu = -999. 583 | maxv = -999. 584 | minu = 999. 585 | minv = 999. 
586 | maxrad = -1 587 | for i in range(flow.shape[0]): 588 | u = flow[i, :, :, 0] 589 | v = flow[i, :, :, 1] 590 | idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7) 591 | u[idxunknow] = 0 592 | v[idxunknow] = 0 593 | maxu = max(maxu, np.max(u)) 594 | minu = min(minu, np.min(u)) 595 | maxv = max(maxv, np.max(v)) 596 | minv = min(minv, np.min(v)) 597 | rad = np.sqrt(u ** 2 + v ** 2) 598 | maxrad = max(maxrad, np.max(rad)) 599 | u = u / (maxrad + np.finfo(float).eps) 600 | v = v / (maxrad + np.finfo(float).eps) 601 | img = compute_color(u, v) 602 | out.append(img) 603 | return np.float32(np.uint8(out)) 604 | 605 | # ---------------------------------------------------------------------------- 606 | 607 | def compute_color(u, v): 608 | h, w = u.shape 609 | img = np.zeros([h, w, 3]) 610 | nanIdx = np.isnan(u) | np.isnan(v) 611 | u[nanIdx] = 0 612 | v[nanIdx] = 0 613 | # colorwheel = COLORWHEEL 614 | colorwheel = make_color_wheel() 615 | ncols = np.size(colorwheel, 0) 616 | rad = np.sqrt(u ** 2 + v ** 2) 617 | a = np.arctan2(-v, -u) / np.pi 618 | fk = (a + 1) / 2 * (ncols - 1) + 1 619 | k0 = np.floor(fk).astype(int) 620 | k1 = k0 + 1 621 | k1[k1 == ncols + 1] = 1 622 | f = fk - k0 623 | for i in range(np.size(colorwheel, 1)): 624 | tmp = colorwheel[:, i] 625 | col0 = tmp[k0 - 1] / 255 626 | col1 = tmp[k1 - 1] / 255 627 | col = (1 - f) * col0 + f * col1 628 | idx = rad <= 1 629 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 630 | notidx = np.logical_not(idx) 631 | col[notidx] *= 0.75 632 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) 633 | return img 634 | 635 | # ---------------------------------------------------------------------------- 636 | 637 | def make_color_wheel(): 638 | RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6) 639 | ncols = RY + YG + GC + CB + BM + MR 640 | colorwheel = np.zeros([ncols, 3]) 641 | col = 0 642 | # RY 643 | colorwheel[0:RY, 0] = 255 644 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) 645 | col += RY 646 | # YG 647 | colorwheel[col:col + YG, 0] = 255 - \ 648 | np.transpose(np.floor(255 * np.arange(0, YG) / YG)) 649 | colorwheel[col:col + YG, 1] = 255 650 | col += YG 651 | # GC 652 | colorwheel[col:col + GC, 1] = 255 653 | colorwheel[col:col + GC, 654 | 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) 655 | col += GC 656 | # CB 657 | colorwheel[col:col + CB, 1] = 255 - \ 658 | np.transpose(np.floor(255 * np.arange(0, CB) / CB)) 659 | colorwheel[col:col + CB, 2] = 255 660 | col += CB 661 | # BM 662 | colorwheel[col:col + BM, 2] = 255 663 | colorwheel[col:col + BM, 664 | 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) 665 | col += + BM 666 | # MR 667 | colorwheel[col:col + MR, 2] = 255 - \ 668 | np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 669 | colorwheel[col:col + MR, 0] = 255 670 | return colorwheel 671 | 672 | # ---------------------------------------------------------------------------- 673 | 674 | 675 | def extract_image_patches(images, ksize, stride, rate, padding='auto'): 676 | """ 677 | Extracts sliding local blocks \\ 678 | see also: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html 679 | """ 680 | 681 | padding = rate*(ksize-1)//2 if padding == 'auto' else padding 682 | 683 | unfold = torch.nn.Unfold(kernel_size=ksize, 684 | dilation=rate, 685 | padding=padding, 686 | stride=stride) 687 | patches = unfold(images) 688 | return patches # [N, C*k*k, L], L is the total number of such blocks 689 | 690 | # ---------------------------------------------------------------------------- 691 | 692 | 
################################# 693 | ######### DISCRIMINATOR ######### 694 | ################################# 695 | 696 | class Conv2DSpectralNorm(nn.Conv2d): 697 | """Convolution layer that applies Spectral Normalization before every call.""" 698 | 699 | def __init__(self, cnum_in, 700 | cnum_out, kernel_size, stride, padding=0, n_iter=1, eps=1e-12, bias=True): 701 | super().__init__(cnum_in, 702 | cnum_out, kernel_size=kernel_size, 703 | stride=stride, padding=padding, bias=bias) 704 | self.register_buffer("weight_u", torch.empty(self.weight.size(0), 1)) 705 | nn.init.trunc_normal_(self.weight_u) 706 | self.n_iter = n_iter 707 | self.eps = eps 708 | 709 | def l2_norm(self, x): 710 | return F.normalize(x, p=2, dim=0, eps=self.eps) 711 | 712 | def forward(self, x): 713 | 714 | weight_orig = self.weight.flatten(1).detach() 715 | 716 | for _ in range(self.n_iter): 717 | v = self.l2_norm(weight_orig.t() @ self.weight_u) 718 | self.weight_u = self.l2_norm(weight_orig @ v) 719 | 720 | sigma = self.weight_u.t() @ weight_orig @ v 721 | self.weight.data.div_(sigma) 722 | 723 | x = super().forward(x) 724 | 725 | return x 726 | 727 | # ---------------------------------------------------------------------------- 728 | 729 | class DConv(nn.Module): 730 | def __init__(self, cnum_in, 731 | cnum_out, ksize=5, stride=2, padding='auto'): 732 | super().__init__() 733 | padding = (ksize-1)//2 if padding == 'auto' else padding 734 | self.conv_sn = Conv2DSpectralNorm( 735 | cnum_in, cnum_out, ksize, stride, padding) 736 | #self.conv_sn = spectral_norm(nn.Conv2d(cnum_in, cnum_out, ksize, stride, padding)) 737 | self.leaky = nn.LeakyReLU(negative_slope=0.2) 738 | 739 | def forward(self, x): 740 | x = self.conv_sn(x) 741 | x = self.leaky(x) 742 | return x 743 | 744 | # ---------------------------------------------------------------------------- 745 | 746 | class Discriminator(nn.Module): 747 | def __init__(self, cnum_in, cnum): 748 | super().__init__() 749 | self.conv1 = DConv(cnum_in, cnum) 750 | self.conv2 = DConv(cnum, 2*cnum) 751 | self.conv3 = DConv(2*cnum, 4*cnum) 752 | self.conv4 = DConv(4*cnum, 4*cnum) 753 | self.conv5 = DConv(4*cnum, 4*cnum) 754 | self.conv6 = DConv(4*cnum, 4*cnum) 755 | 756 | def forward(self, x): 757 | x = self.conv1(x) 758 | x = self.conv2(x) 759 | x = self.conv3(x) 760 | x = self.conv4(x) 761 | x = self.conv5(x) 762 | x = self.conv6(x) 763 | x = nn.Flatten()(x) 764 | 765 | return x 766 | -------------------------------------------------------------------------------- /src/objRemove.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import cv2 4 | import numpy as np 5 | import torchvision.transforms as T 6 | from torchvision.io import read_image 7 | 8 | class ObjectRemove(): 9 | 10 | def __init__(self, segmentModel = None, rcnn_transforms = None, inpaintModel= None, image_path = '') -> None: 11 | self.segmentModel = segmentModel 12 | self.inpaintModel = inpaintModel 13 | self.rcnn_transforms = rcnn_transforms 14 | self.image_path = image_path 15 | self.highest_prob_mask = None 16 | self.image_orig = None 17 | self.image_masked = None 18 | self.box = None 19 | 20 | def run(self): 21 | ''' 22 | Main run program 23 | ''' 24 | #read in image and transform 25 | print('Reading in image') 26 | images = self.preprocess_image() 27 | self.image_orig = images 28 | 29 | print("segmentation") 30 | #segmentation 31 | output = self.segment(images) 32 | out = output[0] 33 | 34 | print('user click') 35 | #user click 36 | ref_points = 
self.user_click() 37 | self.box = ref_points 38 | self.highest_prob_mask = self.find_mask(out, ref_points) 39 | 40 | self.highest_prob_mask[self.highest_prob_mask > 0.1] = 1 41 | self.highest_prob_mask[self.highest_prob_mask <0.1] = 0 42 | self.image_masked = (images[0]*(1-self.highest_prob_mask)) 43 | print('inpaint') 44 | #inpaint 45 | output = self.inpaint() 46 | 47 | #return final inpainted image 48 | return output 49 | 50 | def percent_within(self,nonzeros, rectangle): 51 | ''' 52 | Calculates percent of mask inside rectangle 53 | ''' 54 | rect_ul, rect_br = rectangle 55 | inside_count = 0 56 | for _,y,x in nonzeros: 57 | if x >= rect_ul[0] and x<= rect_br[0] and y <= rect_br[1] and y>= rect_ul[1]: 58 | inside_count+=1 59 | return inside_count / len(nonzeros) 60 | 61 | def iou(self, boxes_a, boxes_b): 62 | ''' 63 | Calculates IOU between all pairs of boxes 64 | 65 | boxes_a and boxes_b are matrices with each row representing the 4 coords of a box 66 | ''' 67 | 68 | x1 = np.array([boxes_a[:,0], boxes_b[:,0]]).max(axis=0) 69 | y1 = np.array([boxes_a[:,1], boxes_b[:,1]]).max(axis=0) 70 | x2 = np.array([boxes_a[:,2], boxes_b[:,2]]).min(axis=0) 71 | y2 = np.array([boxes_a[:,3], boxes_b[:,3]]).min(axis=0) 72 | 73 | w = x2-x1 74 | h = y2-y1 75 | w[w<0] = 0 76 | h[h<0] = 0 77 | 78 | intersect = w* h 79 | 80 | area_a = (boxes_a[:,2] - boxes_a[:,0]) * (boxes_a[:,3] - boxes_a[:,1]) 81 | area_b = (boxes_b[:,2] - boxes_b[:,0]) * (boxes_b[:,3] - boxes_b[:,1]) 82 | 83 | union = area_a + area_b - intersect 84 | 85 | return intersect / (union + 0.00001) 86 | 87 | def find_mask(self, rcnn_output, rectangle): 88 | ''' 89 | Finds the mask with highest probability in the rectangle given 90 | 91 | ''' 92 | bounding_boxes= rcnn_output['boxes'].detach().numpy() 93 | masks = rcnn_output['masks'] 94 | 95 | ref_boxes = np.array([rectangle], dtype=object) 96 | ref_boxes = np.repeat(ref_boxes, bounding_boxes.shape[0], axis=0) 97 | 98 | ious= self.iou(ref_boxes, bounding_boxes) 99 | 100 | best_ind = np.argmax(ious) 101 | 102 | return masks[best_ind] 103 | 104 | 105 | #compare masks pixelwise 106 | ''' 107 | masks = rcnn_output['masks'] 108 | #go through each nonzero point in the mask and count how many points are within the rectangles 109 | highest_prob_mask = None 110 | percent_within,min_diff = 0,float('inf') 111 | #print('masks lenght:', len(masks)) 112 | 113 | 114 | for m in range(len(masks)): 115 | #masks[m][masks[m] > 0.5] = 255.0 116 | #masks[m][masks[m] < 0.5] = 0.0 117 | nonzeros = np.nonzero(masks[m]) 118 | #diff = rect_area - len(nonzeros) 119 | p = self.percent_within(nonzeros, rectangle) 120 | if p > percent_within: 121 | highest_prob_mask = masks[m] 122 | percent_within = p 123 | print(p) 124 | return highest_prob_mask 125 | ''' 126 | 127 | def preprocess_image(self): 128 | ''' 129 | Read in image and prepare for segmentation 130 | ''' 131 | img= [read_image(self.image_path)] 132 | _,h,w = img[0].shape 133 | size = min(h,w) 134 | if size > 512: 135 | img[0] = T.Resize(512, max_size=680, antialias=True)(img[0]) 136 | 137 | images_transformed = [self.rcnn_transforms(d) for d in img] 138 | return images_transformed 139 | 140 | 141 | def segment(self,images): 142 | out = self.segmentModel(images) 143 | return out 144 | 145 | def user_click(self): 146 | ''' 147 | Get user input for object to remove 148 | 149 | Returns the rectangle bounding box give by user as two points 150 | ''' 151 | ref_point = [] 152 | cache=None 153 | draw = False 154 | 155 | 156 | def click(event, x, y, flags, param): 157 | 
nonlocal ref_point,cache,img, draw 158 | if event == cv2.EVENT_LBUTTONDOWN: 159 | draw = True 160 | ref_point = [x, y] 161 | cache = copy.deepcopy(img) 162 | 163 | elif event == cv2.EVENT_MOUSEMOVE: 164 | if draw: 165 | img = copy.deepcopy(cache) 166 | cv2.rectangle(img, (ref_point[0], ref_point[1]), (x,y), (0, 255, 0), 2) 167 | cv2.imshow('image',img) 168 | 169 | 170 | elif event == cv2.EVENT_LBUTTONUP: 171 | draw = False 172 | ref_point += [x,y] 173 | ref_point.append((x, y)) 174 | cv2.rectangle(img, (ref_point[0], ref_point[1]), (ref_point[2], ref_point[3]), (0, 255, 0), 2) 175 | cv2.imshow("image", img) 176 | 177 | 178 | img = self.image_orig[0].permute(1,2,0).numpy() 179 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 180 | clone = img.copy() 181 | 182 | cv2.namedWindow("image") 183 | 184 | cv2.setMouseCallback('image', click) 185 | 186 | while True: 187 | cv2.imshow("image", img) 188 | key = cv2.waitKey(1) & 0xFF 189 | 190 | if key == ord("r"): 191 | img = clone.copy() 192 | 193 | elif key == ord("c"): 194 | break 195 | cv2.destroyAllWindows() 196 | 197 | return ref_point 198 | 199 | def inpaint(self): 200 | output = self.inpaintModel.infer(self.image_orig[0], self.highest_prob_mask, return_vals=['inpainted']) 201 | return output[0] 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /test_imgs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/.DS_Store -------------------------------------------------------------------------------- /test_imgs/bikes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/bikes.jpg -------------------------------------------------------------------------------- /test_imgs/park.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/park.jpeg -------------------------------------------------------------------------------- /test_imgs/test1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/test1.jpeg --------------------------------------------------------------------------------