├── .gitignore
├── README.md
├── img
│   ├── diagram.png
│   ├── example1.png
│   ├── example2.png
│   └── test_imgs
│       ├── .DS_Store
│       ├── bikes.jpg
│       ├── park.jpeg
│       └── test1.jpeg
├── object_remove.pdf
├── src
│   ├── main.py
│   ├── models
│   │   └── deepFill.py
│   └── objRemove.py
└── test_imgs
    ├── .DS_Store
    ├── bikes.jpg
    ├── park.jpeg
    └── test1.jpeg
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # object-remove
2 |
3 | A system for removing objects from images using deep learning image segmentation and inpainting techniques.
4 |
5 | ## Contents
6 | 1. [Overview](#overview)
7 | 2. [Source Code](src/)
8 | 3. [Report](object_remove.pdf)
9 | 4. [Results](#results)
10 | 5. [Dependencies](#dependencies)
11 |
12 | ## Overview
13 | Removing an object from an image involves two separate tasks: locating the object (segmentation) and filling in the region it occupied (inpainting).
14 |
15 | The first task starts with the user drawing a bounding box around the object to be removed. Simply removing every pixel inside the box would discard valuable background pixels that are not part of the object, so instead Mask R-CNN, a state-of-the-art instance segmentation model, is used to obtain the exact mask of the object.
16 |
17 | Filling in the masked region is done with DeepFillv2, an image inpainting generative adversarial network built on gated convolutions.
18 |
19 | The result is a complete image with the object removed.
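
Under the hood, [src/main.py](src/main.py) wires the two models together roughly like this (a sketch only; the checkpoint filename is a placeholder for whatever .pth file you downloaded):

```
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from models.deepFill import Generator   # src/models/deepFill.py
from objRemove import ObjectRemove      # src/objRemove.py

# pretrained Mask R-CNN from torchvision handles segmentation
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
rcnn = maskrcnn_resnet50_fpn(weights=weights, progress=False).eval()

# DeepFillv2 generator loads the downloaded checkpoint (placeholder filename)
deepfill = Generator(checkpoint='src/models/deepfill_weights.pth', return_flow=True)

model = ObjectRemove(segmentModel=rcnn,
                     rcnn_transforms=weights.transforms(),
                     inpaintModel=deepfill,
                     image_path='test_imgs/bikes.jpg')
inpainted = model.run()   # draw a box in the pop-up window; returns the inpainted image
```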
20 |
21 |
22 |
23 |
24 |
25 |
26 | ## Usage
27 |
28 | The DeepFillv2 model needs pretrained weights, available [here](https://drive.google.com/u/0/uc?id=1L63oBNVgz7xSb_3hGbUdkYW1IuRgMkCa&export=download) and provided by [this](https://github.com/nipponjo/deepfillv2-pytorch) repository, a reimplementation of DeepFillv2 in PyTorch. The DeepFillv2 model code was borrowed from there and slightly modified.
29 |
30 |
31 |
32 | Make sure to put the weights .pth file in [src/models/](/src/models/).
33 |
34 | To run on an example image:
35 | ```
36 | ./src/main.py [path of image]
37 | ```
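
For example, with one of the bundled test images (run from the repository root so the script can find the weights under `src/models/`):

```
./src/main.py test_imgs/bikes.jpg
```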
38 | When drawing the bounding box, press 'r' to clear the box and reset the image. Once the box is drawn, press 'c' to continue.
39 |
40 | *Note: drawing bounding boxes is sometimes slow.
41 |
42 |
43 | ## Results
44 | The following are some results of the system. The user-selected bounding box is shown along with the masked image and the inpainted final result.
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | ## Dependencies
56 | - python3
57 | - torch
58 | - torchvision
59 | - cv2 (opencv-python)
60 | - matplotlib
61 | - numpy
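
One way to install the Python packages with pip (assuming the standard PyPI package names):

```
pip install torch torchvision opencv-python matplotlib numpy
```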
62 |
63 |
64 |
--------------------------------------------------------------------------------
/img/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/diagram.png
--------------------------------------------------------------------------------
/img/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/example1.png
--------------------------------------------------------------------------------
/img/example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/example2.png
--------------------------------------------------------------------------------
/img/test_imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/.DS_Store
--------------------------------------------------------------------------------
/img/test_imgs/bikes.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/bikes.jpg
--------------------------------------------------------------------------------
/img/test_imgs/park.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/park.jpeg
--------------------------------------------------------------------------------
/img/test_imgs/test1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/img/test_imgs/test1.jpeg
--------------------------------------------------------------------------------
/object_remove.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/object_remove.pdf
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | import os
6 | from objRemove import ObjectRemove
7 | from models.deepFill import Generator
8 | from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
9 |
10 | ##################################
11 | #get image path from command line#
12 | ##################################
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("image")
15 | args = parser.parse_args()
16 | image_path = args.image
17 |
18 | ######################################################
19 | #find DeepFillv2 weights and create Mask-RCNN model  #
20 | ######################################################
21 | for f in os.listdir('src/models'):
22 | if f.endswith('.pth'):
23 | deepfill_weights_path = os.path.join('src/models', f)
24 | print("Creating rcnn model")
25 | weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
26 | transforms = weights.transforms()
27 | rcnn = maskrcnn_resnet50_fpn(weights=weights, progress=False)
28 | rcnn = rcnn.eval()
29 |
30 | #########################
31 | #create inpainting model#
32 | #########################
33 | print('Creating deepfill model')
34 | deepfill = Generator(checkpoint=deepfill_weights_path, return_flow=True)
35 | ######################
36 | #create ObjectRemoval#
37 | ######################
38 | model = ObjectRemove(segmentModel=rcnn,
39 | rcnn_transforms=transforms,
40 | inpaintModel=deepfill,
41 | image_path=image_path )
42 | #####
43 | #run#
44 | #####
45 | output = model.run()
46 |
47 | #################
48 | #display results#
49 | #################
50 | img = cv2.cvtColor(model.image_orig[0].permute(1,2,0).numpy(),cv2.COLOR_RGB2BGR)
51 | boxed = cv2.rectangle(img, (model.box[0], model.box[1]),(model.box[2], model.box[3]), (0,255,0),2)
52 | boxed = cv2.cvtColor(boxed,cv2.COLOR_BGR2RGB)
53 |
54 | fig,axs = plt.subplots(1,3,layout='constrained')
55 | axs[0].imshow(boxed)
56 | axs[0].set_title('Original Image Bounding Box')
57 | axs[1].imshow(model.image_masked.permute(1,2,0).detach().numpy())
58 | axs[1].set_title('Masked Image')
59 | axs[2].imshow(output)
60 | axs[2].set_title('Inpainted Image')
61 | plt.show()
62 |
63 |
--------------------------------------------------------------------------------
/src/models/deepFill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.parametrizations import spectral_norm
6 |
7 |
8 | #code from https://github.com/nipponjo/deepfillv2-pytorch/blob/master/model/networks.py
9 | #with slight modifications
10 |
11 | # ----------------------------------------------------------------------------
12 |
13 | def _init_conv_layer(conv, activation, mode='fan_out'):
14 | if isinstance(activation, nn.LeakyReLU):
15 | torch.nn.init.kaiming_uniform_(conv.weight,
16 | a=activation.negative_slope,
17 | nonlinearity='leaky_relu',
18 | mode=mode)
19 | elif isinstance(activation, (nn.ReLU, nn.ELU)):
20 | torch.nn.init.kaiming_uniform_(conv.weight,
21 | nonlinearity='relu',
22 | mode=mode)
23 | else:
24 | pass
25 | if conv.bias is not None:
26 | torch.nn.init.zeros_(conv.bias)
27 |
28 |
29 | def output_to_image(out):
30 | out = (out[0].cpu().permute(1, 2, 0) + 1.) * 127.5
31 | out = out.to(torch.uint8).numpy()
32 | return out
33 |
34 | # ----------------------------------------------------------------------------
35 |
36 | #################################
37 | ########### GENERATOR ###########
38 | #################################
39 |
40 | class GConv(nn.Module):
41 | """Implements the gated 2D convolution introduced in
42 | `Free-Form Image Inpainting with Gated Convolution`(Yu et al., 2019)
43 | """
44 |
45 | def __init__(self, cnum_in,
46 | cnum_out,
47 | ksize,
48 | stride=1,
49 | padding='auto',
50 | rate=1,
51 | activation=nn.ELU(),
52 | bias=True
53 | ):
54 |
55 | super().__init__()
56 |
57 | padding = rate*(ksize-1)//2 if padding == 'auto' else padding
58 | self.activation = activation
59 | self.cnum_out = cnum_out
60 | num_conv_out = cnum_out if self.cnum_out == 3 or self.activation is None else 2*cnum_out
61 | self.conv = nn.Conv2d(cnum_in,
62 | num_conv_out,
63 | kernel_size=ksize,
64 | stride=stride,
65 | padding=padding,
66 | dilation=rate,
67 | bias=bias)
68 |
69 | _init_conv_layer(self.conv, activation=self.activation)
70 |
71 | self.ksize = ksize
72 | self.stride = stride
73 | self.rate = rate
74 | self.padding = padding
75 |
76 | def forward(self, x):
77 | x = self.conv(x)
78 | if self.cnum_out == 3 or self.activation is None:
79 | return x
80 | x, y = torch.split(x, self.cnum_out, dim=1)
81 | x = self.activation(x)
82 | y = torch.sigmoid(y)
83 | x = x * y
84 | return x
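# When gating is active, the conv produces 2*cnum_out channels that are split into
# features x and gate y, so the output is activation(x) * sigmoid(y), i.e. the gated
# convolution of Yu et al. Quick shape check with hypothetical sizes:
#   GConv(cnum_in=4, cnum_out=48, ksize=5)(torch.rand(1, 4, 64, 64)).shape
#   # -> torch.Size([1, 48, 64, 64])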
85 |
86 | # ----------------------------------------------------------------------------
87 |
88 | class GDeConv(nn.Module):
89 | """Upsampling followed by convolution"""
90 |
91 | def __init__(self, cnum_in,
92 | cnum_out,
93 | padding=1):
94 | super().__init__()
95 | self.conv = GConv(cnum_in, cnum_out, 3, 1,
96 | padding=padding)
97 |
98 | def forward(self, x):
99 | x = F.interpolate(x, scale_factor=2, mode='nearest',
100 | recompute_scale_factor=False)
101 | x = self.conv(x)
102 | return x
103 |
104 | # ----------------------------------------------------------------------------
105 |
106 | class GDownsamplingBlock(nn.Module):
107 | def __init__(self, cnum_in,
108 | cnum_out,
109 | cnum_hidden=None
110 | ):
111 | super().__init__()
112 | cnum_hidden = cnum_out if cnum_hidden is None else cnum_hidden
113 | self.conv1_downsample = GConv(cnum_in, cnum_hidden, 3, 2)
114 | self.conv2 = GConv(cnum_hidden, cnum_out, 3, 1)
115 |
116 | def forward(self, x):
117 | x = self.conv1_downsample(x)
118 | x = self.conv2(x)
119 | return x
120 |
121 | # ----------------------------------------------------------------------------
122 |
123 | class GUpsamplingBlock(nn.Module):
124 | def __init__(self, cnum_in,
125 | cnum_out,
126 | cnum_hidden=None
127 | ):
128 | super().__init__()
129 | cnum_hidden = cnum_out if cnum_hidden is None else cnum_hidden
130 | self.conv1_upsample = GDeConv(cnum_in, cnum_hidden)
131 | self.conv2 = GConv(cnum_hidden, cnum_out, 3, 1)
132 |
133 | def forward(self, x):
134 | x = self.conv1_upsample(x)
135 | x = self.conv2(x)
136 | return x
137 |
138 | # ----------------------------------------------------------------------------
139 |
140 |
141 | class CoarseGenerator(nn.Module):
142 | def __init__(self, cnum_in, cnum):
143 | super().__init__()
144 | self.conv1 = GConv(cnum_in, cnum//2, 5, 1, padding=2)
145 |
146 | # downsampling
147 | self.down_block1 = GDownsamplingBlock(cnum//2, cnum)
148 | self.down_block2 = GDownsamplingBlock(cnum, 2*cnum)
149 |
150 | # bottleneck
151 | self.conv_bn1 = GConv(2*cnum, 2*cnum, 3, 1)
152 | self.conv_bn2 = GConv(2*cnum, 2*cnum, 3, rate=2, padding=2)
153 | self.conv_bn3 = GConv(2*cnum, 2*cnum, 3, rate=4, padding=4)
154 | self.conv_bn4 = GConv(2*cnum, 2*cnum, 3, rate=8, padding=8)
155 | self.conv_bn5 = GConv(2*cnum, 2*cnum, 3, rate=16, padding=16)
156 | self.conv_bn6 = GConv(2*cnum, 2*cnum, 3, 1)
157 | self.conv_bn7 = GConv(2*cnum, 2*cnum, 3, 1)
158 |
159 | # upsampling
160 | self.up_block1 = GUpsamplingBlock(2*cnum, cnum)
161 | self.up_block2 = GUpsamplingBlock(cnum, cnum//4, cnum_hidden=cnum//2)
162 |
163 | # to RGB
164 | self.conv_to_rgb = GConv(cnum//4, 3, 3, 1, activation=None)
165 | self.tanh = nn.Tanh()
166 |
167 | def forward(self, x):
168 | x = self.conv1(x)
169 |
170 | # downsampling
171 | x = self.down_block1(x)
172 | x = self.down_block2(x)
173 |
174 | # bottleneck
175 | x = self.conv_bn1(x)
176 | x = self.conv_bn2(x)
177 | x = self.conv_bn3(x)
178 | x = self.conv_bn4(x)
179 | x = self.conv_bn5(x)
180 | x = self.conv_bn6(x)
181 | x = self.conv_bn7(x)
182 |
183 | # upsampling
184 | x = self.up_block1(x)
185 | x = self.up_block2(x)
186 |
187 | # to RGB
188 | x = self.conv_to_rgb(x)
189 | x = self.tanh(x)
190 | return x
191 |
192 | # ----------------------------------------------------------------------------
193 |
194 | class FineGenerator(nn.Module):
195 | def __init__(self, cnum, return_flow=False):
196 | super().__init__()
197 |
198 | ### CONV BRANCH (B1) ###
199 | self.conv_conv1 = GConv(3, cnum//2, 5, 1, padding=2)
200 |
201 | # downsampling
202 | self.conv_down_block1 = GDownsamplingBlock(
203 | cnum//2, cnum, cnum_hidden=cnum//2)
204 | self.conv_down_block2 = GDownsamplingBlock(
205 | cnum, 2*cnum, cnum_hidden=cnum)
206 |
207 | # bottleneck
208 | self.conv_conv_bn1 = GConv(2*cnum, 2*cnum, 3, 1)
209 | self.conv_conv_bn2 = GConv(2*cnum, 2*cnum, 3, rate=2, padding=2)
210 | self.conv_conv_bn3 = GConv(2*cnum, 2*cnum, 3, rate=4, padding=4)
211 | self.conv_conv_bn4 = GConv(2*cnum, 2*cnum, 3, rate=8, padding=8)
212 | self.conv_conv_bn5 = GConv(2*cnum, 2*cnum, 3, rate=16, padding=16)
213 |
214 | ### ATTENTION BRANCH (B2) ###
215 | self.ca_conv1 = GConv(3, cnum//2, 5, 1, padding=2)
216 |
217 | # downsampling
218 | self.ca_down_block1 = GDownsamplingBlock(
219 | cnum//2, cnum, cnum_hidden=cnum//2)
220 | self.ca_down_block2 = GDownsamplingBlock(cnum, 2*cnum)
221 |
222 | # bottleneck
223 | self.ca_conv_bn1 = GConv(2*cnum, 2*cnum, 3, 1, activation=nn.ReLU())
224 | self.contextual_attention = ContextualAttention(ksize=3,
225 | stride=1,
226 | rate=2,
227 | fuse_k=3,
228 | softmax_scale=10,
229 | fuse=True,
230 | device_ids=None,
231 | return_flow=return_flow,
232 | n_down=2)
233 | self.ca_conv_bn4 = GConv(2*cnum, 2*cnum, 3, 1)
234 | self.ca_conv_bn5 = GConv(2*cnum, 2*cnum, 3, 1)
235 |
236 | ### UNITED BRANCHES ###
237 | self.conv_bn6 = GConv(4*cnum, 2*cnum, 3, 1)
238 | self.conv_bn7 = GConv(2*cnum, 2*cnum, 3, 1)
239 |
240 | # upsampling
241 | self.up_block1 = GUpsamplingBlock(2*cnum, cnum)
242 | self.up_block2 = GUpsamplingBlock(cnum, cnum//4, cnum_hidden=cnum//2)
243 |
244 | # to RGB
245 | self.conv_to_rgb = GConv(cnum//4, 3, 3, 1, activation=None)
246 | self.tanh = nn.Tanh()
247 |
248 | def forward(self, x, mask):
249 | xnow = x
250 |
251 | ### CONV BRANCH ###
252 | x = self.conv_conv1(xnow)
253 | # downsampling
254 | x = self.conv_down_block1(x)
255 | x = self.conv_down_block2(x)
256 |
257 | # bottleneck
258 | x = self.conv_conv_bn1(x)
259 | x = self.conv_conv_bn2(x)
260 | x = self.conv_conv_bn3(x)
261 | x = self.conv_conv_bn4(x)
262 | x = self.conv_conv_bn5(x)
263 | x_hallu = x
264 |
265 | ### ATTENTION BRANCH ###
266 | x = self.ca_conv1(xnow)
267 | # downsampling
268 | x = self.ca_down_block1(x)
269 | x = self.ca_down_block2(x)
270 |
271 | # bottleneck
272 | x = self.ca_conv_bn1(x)
273 | x, offset_flow = self.contextual_attention(x, x, mask)
274 | x = self.ca_conv_bn4(x)
275 | x = self.ca_conv_bn5(x)
276 | pm = x
277 |
278 | # concatenate outputs from both branches
279 | x = torch.cat([x_hallu, pm], dim=1)
280 |
281 | ### UNITED BRANCHES ###
282 | x = self.conv_bn6(x)
283 | x = self.conv_bn7(x)
284 |
285 | # upsampling
286 | x = self.up_block1(x)
287 | x = self.up_block2(x)
288 |
289 | # to RGB
290 | x = self.conv_to_rgb(x)
291 | x = self.tanh(x)
292 |
293 | return x, offset_flow
294 |
295 | # ----------------------------------------------------------------------------
296 |
297 | class Generator(nn.Module):
298 | def __init__(self, cnum_in=5, cnum=48, return_flow=False, checkpoint=None):
299 | super().__init__()
300 | self.stage1 = CoarseGenerator(cnum_in, cnum)
301 | self.stage2 = FineGenerator(cnum, return_flow)
302 | self.return_flow = return_flow
303 |
304 | if checkpoint is not None:
305 | generator_state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['G']
306 | self.load_state_dict(generator_state_dict, strict=True)
307 |
308 | self.eval()
309 |
310 | def forward(self, x, mask):
311 | xin = x
312 | # get coarse result
313 | x_stage1 = self.stage1(x)
314 | # inpaint input with coarse result
315 | x = x_stage1*mask + xin[:, 0:3, :, :]*(1.-mask)
316 | # get refined result
317 | x_stage2, offset_flow = self.stage2(x, mask)
318 |
319 | if self.return_flow:
320 | return x_stage1, x_stage2, offset_flow
321 |
322 | return x_stage1, x_stage2
323 |
324 | @torch.inference_mode()
325 | def infer(self,
326 | image,
327 | mask,
328 | return_vals=['inpainted', 'stage1'],
329 | device='cuda'):
330 | """
331 | Args:
332 | image: (C, H, W) image tensor with values in [0, 1]
333 | mask: (1, H, W) mask tensor, nonzero where the image should be inpainted
334 | return_vals: any of 'inpainted', 'stage1', 'stage2', 'flow'
335 | Returns: list of outputs, one entry per requested return value
336 | """
337 |
338 | _, h, w = image.shape
339 | grid = 8
340 |
341 | image = image[:3, :h//grid*grid, :w//grid*grid].unsqueeze(0)
342 | mask = mask[0:1, :h//grid*grid, :w//grid*grid].unsqueeze(0)
343 |
344 | image = (image*2 - 1.) # map image values to [-1, 1] range
345 | # 1.: masked 0.: unmasked
346 | mask = (mask > 0.).to(dtype=torch.float32)
347 |
348 | image_masked = image * (1.-mask) # mask image
349 |
350 | ones_x = torch.ones_like(image_masked)[:, 0:1, :, :] # sketch channel
351 | x = torch.cat([image_masked, ones_x, ones_x*mask],
352 | dim=1) # concatenate channels
353 |
354 | if self.return_flow:
355 | x_stage1, x_stage2, offset_flow = self.forward(x, mask)
356 | else:
357 | x_stage1, x_stage2 = self.forward(x, mask)
358 |
359 | image_compl = image * (1.-mask) + x_stage2 * mask
360 |
361 | output = []
362 | for return_val in return_vals:
363 | if return_val.lower() == 'stage1':
364 | output.append(output_to_image(x_stage1))
365 | elif return_val.lower() == 'stage2':
366 | output.append(output_to_image(x_stage2))
367 | elif return_val.lower() == 'inpainted':
368 | output.append(output_to_image(image_compl))
369 | elif return_val.lower() == 'flow' and self.return_flow:
370 | output.append(offset_flow)
371 | else:
372 | print(f'Invalid return value: {return_val}')
373 |
374 | return output
375 |
376 | # ----------------------------------------------------------------------------
377 |
378 | ####################################
379 | ####### CONTEXTUAL ATTENTION #######
380 | ####################################
381 |
382 | """
383 | adapted from: https://github.com/daa233/generative-inpainting-pytorch/blob/master/model/networks.py
384 | """
385 |
386 | class ContextualAttention(nn.Module):
387 | """ Contextual attention layer implementation. \\
388 | Contextual attention is first introduced in publication: \\
389 | `Generative Image Inpainting with Contextual Attention`, Yu et al \\
390 | Args:
391 | ksize: Kernel size for contextual attention
392 | stride: Stride for extracting patches from b
393 | rate: Dilation for matching
394 | softmax_scale: Scaled softmax for attention
395 | """
396 |
397 | def __init__(self,
398 | ksize=3,
399 | stride=1,
400 | rate=1,
401 | fuse_k=3,
402 | softmax_scale=10.,
403 | n_down=2,
404 | fuse=False,
405 | return_flow=False,
406 | device_ids=None):
407 | super(ContextualAttention, self).__init__()
408 | self.ksize = ksize
409 | self.stride = stride
410 | self.rate = rate
411 | self.fuse_k = fuse_k
412 | self.softmax_scale = softmax_scale
413 | self.fuse = fuse
414 | self.device_ids = device_ids
415 | self.n_down = n_down
416 | self.return_flow = return_flow
417 | self.register_buffer('fuse_weight', torch.eye(
418 | fuse_k).view(1, 1, fuse_k, fuse_k))
419 |
420 | def forward(self, f, b, mask=None):
421 | """
422 | Args:
423 | f: Input feature to match (foreground).
424 | b: Input feature for match (background).
425 | mask: Input mask for b, indicating patches not available.
426 | """
427 | device = f.device
428 | # get shapes
429 | raw_int_fs, raw_int_bs = list(f.size()), list(b.size()) # b*c*h*w
430 |
431 | # extract patches from background with stride and rate
432 | kernel = 2 * self.rate
433 | # raw_w is extracted for reconstruction
434 | raw_w = extract_image_patches(b, ksize=kernel,
435 | stride=self.rate*self.stride,
436 | rate=1, padding='auto') # [N, C*k*k, L]
437 | # raw_shape: [N, C, k, k, L]
438 | raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
439 | raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
440 | raw_w_groups = torch.split(raw_w, 1, dim=0)
441 |
442 | # downscaling foreground option: downscaling both foreground and
443 | # background for matching and use original background for reconstruction.
444 | f = F.interpolate(f, scale_factor=1./self.rate,
445 | mode='nearest', recompute_scale_factor=False)
446 | b = F.interpolate(b, scale_factor=1./self.rate,
447 | mode='nearest', recompute_scale_factor=False)
448 | int_fs, int_bs = list(f.size()), list(b.size()) # b*c*h*w
449 | # split tensors along the batch dimension
450 | f_groups = torch.split(f, 1, dim=0)
451 | # w shape: [N, C*k*k, L]
452 | w = extract_image_patches(b, ksize=self.ksize,
453 | stride=self.stride,
454 | rate=1, padding='auto')
455 | # w shape: [N, C, k, k, L]
456 | w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
457 | w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
458 | w_groups = torch.split(w, 1, dim=0)
459 |
460 | # process mask
461 | if mask is None:
462 | mask = torch.zeros(
463 | [int_bs[0], 1, int_bs[2], int_bs[3]], device=device)
464 | else:
465 | mask = F.interpolate(
466 | mask, scale_factor=1./((2**self.n_down)*self.rate), mode='nearest', recompute_scale_factor=False)
467 | int_ms = list(mask.size())
468 | # m shape: [N, C*k*k, L]
469 | m = extract_image_patches(mask, ksize=self.ksize,
470 | stride=self.stride,
471 | rate=1, padding='auto')
472 | # m shape: [N, C, k, k, L]
473 | m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
474 | m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k]
475 | m = m[0] # m shape: [L, C, k, k]
476 | # mm shape: [L, 1, 1, 1]
477 |
478 | mm = (torch.mean(m, dim=[1, 2, 3], keepdim=True) == 0.).to(
479 | torch.float32)
480 | mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]
481 |
482 | y = []
483 | offsets = []
484 | scale = self.softmax_scale # to fit the PyTorch tensor image value range
485 |
486 | for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
487 | '''
488 | O => output channel as a conv filter
489 | I => input channel as a conv filter
490 | xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
491 | wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
492 | raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
493 | '''
494 | # conv for compare
495 | wi = wi[0] # [L, C, k, k]
496 | max_wi = torch.sqrt(torch.sum(torch.square(wi), dim=[
497 | 1, 2, 3], keepdim=True)).clamp_min(1e-4)
498 | wi_normed = wi / max_wi
499 | # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
500 | yi = F.conv2d(xi, wi_normed, stride=1, padding=(
501 | self.ksize-1)//2) # [1, L, H, W]
502 | # conv implementation for fuse scores to encourage large patches
503 | if self.fuse:
504 | # make all of depth to spatial resolution
505 | # (B=1, I=1, H=32*32, W=32*32)
506 | yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
507 | # (B=1, C=1, H=32*32, W=32*32)
508 | yi = F.conv2d(yi, self.fuse_weight, stride=1,
509 | padding=(self.fuse_k-1)//2)
510 | # (B=1, 32, 32, 32, 32)
511 | yi = yi.contiguous().view(
512 | 1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])
513 | yi = yi.permute(0, 2, 1, 4, 3)
514 |
515 | yi = yi.contiguous().view(
516 | 1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
517 | yi = F.conv2d(yi, self.fuse_weight, stride=1,
518 | padding=(self.fuse_k-1)//2)
519 | yi = yi.contiguous().view(
520 | 1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
521 | yi = yi.permute(0, 2, 1, 4, 3).contiguous()
522 |
523 | # (B=1, C=32*32, H=32, W=32)
524 | yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])
525 | # softmax to match
526 | yi = yi * mm
527 | yi = F.softmax(yi*scale, dim=1)
528 | yi = yi * mm # [1, L, H, W]
529 |
530 | if self.return_flow:
531 | offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W
532 |
533 | if int_bs != int_fs:
534 | # Normalize the offset value to match foreground dimension
535 | times = (int_fs[2]*int_fs[3])/(int_bs[2]*int_bs[3])
536 | offset = ((offset + 1).float() * times - 1).to(torch.int64)
537 | offset = torch.cat([torch.div(offset, int_fs[3], rounding_mode='trunc'),
538 | offset % int_fs[3]], dim=1) # 1*2*H*W
539 | offsets.append(offset)
540 |
541 | # deconv for patch pasting
542 | wi_center = raw_wi[0]
543 | yi = F.conv_transpose2d(
544 | yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64)
545 | y.append(yi)
546 |
547 | y = torch.cat(y, dim=0) # back to the mini-batch
548 | y = y.contiguous().view(raw_int_fs)
549 |
550 | if not self.return_flow:
551 | return y, None
552 |
553 | offsets = torch.cat(offsets, dim=0)
554 | offsets = offsets.view(int_fs[0], 2, *int_fs[2:])
555 |
556 | # case1: visualize optical flow: minus current position
557 | h_add = torch.arange(int_fs[2], device=device).view(
558 | [1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3])
559 | w_add = torch.arange(int_fs[3], device=device).view(
560 | [1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1)
561 | offsets = offsets - torch.cat([h_add, w_add], dim=1)
562 | # to flow image
563 | flow = torch.from_numpy(flow_to_image(
564 | offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
565 | flow = flow.permute(0, 3, 1, 2)
566 | # case2: visualize which pixels are attended
567 | # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))
568 |
569 | if self.rate != 1:
570 | flow = F.interpolate(flow, scale_factor=self.rate,
571 | mode='bilinear', align_corners=True)
572 |
573 | return y, flow
574 |
575 | # ----------------------------------------------------------------------------
576 |
577 | def flow_to_image(flow):
578 | """Transfer flow map to image.
579 | Part of code forked from flownet.
580 | """
581 | out = []
582 | maxu = -999.
583 | maxv = -999.
584 | minu = 999.
585 | minv = 999.
586 | maxrad = -1
587 | for i in range(flow.shape[0]):
588 | u = flow[i, :, :, 0]
589 | v = flow[i, :, :, 1]
590 | idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
591 | u[idxunknow] = 0
592 | v[idxunknow] = 0
593 | maxu = max(maxu, np.max(u))
594 | minu = min(minu, np.min(u))
595 | maxv = max(maxv, np.max(v))
596 | minv = min(minv, np.min(v))
597 | rad = np.sqrt(u ** 2 + v ** 2)
598 | maxrad = max(maxrad, np.max(rad))
599 | u = u / (maxrad + np.finfo(float).eps)
600 | v = v / (maxrad + np.finfo(float).eps)
601 | img = compute_color(u, v)
602 | out.append(img)
603 | return np.float32(np.uint8(out))
604 |
605 | # ----------------------------------------------------------------------------
606 |
607 | def compute_color(u, v):
608 | h, w = u.shape
609 | img = np.zeros([h, w, 3])
610 | nanIdx = np.isnan(u) | np.isnan(v)
611 | u[nanIdx] = 0
612 | v[nanIdx] = 0
613 | # colorwheel = COLORWHEEL
614 | colorwheel = make_color_wheel()
615 | ncols = np.size(colorwheel, 0)
616 | rad = np.sqrt(u ** 2 + v ** 2)
617 | a = np.arctan2(-v, -u) / np.pi
618 | fk = (a + 1) / 2 * (ncols - 1) + 1
619 | k0 = np.floor(fk).astype(int)
620 | k1 = k0 + 1
621 | k1[k1 == ncols + 1] = 1
622 | f = fk - k0
623 | for i in range(np.size(colorwheel, 1)):
624 | tmp = colorwheel[:, i]
625 | col0 = tmp[k0 - 1] / 255
626 | col1 = tmp[k1 - 1] / 255
627 | col = (1 - f) * col0 + f * col1
628 | idx = rad <= 1
629 | col[idx] = 1 - rad[idx] * (1 - col[idx])
630 | notidx = np.logical_not(idx)
631 | col[notidx] *= 0.75
632 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
633 | return img
634 |
635 | # ----------------------------------------------------------------------------
636 |
637 | def make_color_wheel():
638 | RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
639 | ncols = RY + YG + GC + CB + BM + MR
640 | colorwheel = np.zeros([ncols, 3])
641 | col = 0
642 | # RY
643 | colorwheel[0:RY, 0] = 255
644 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
645 | col += RY
646 | # YG
647 | colorwheel[col:col + YG, 0] = 255 - \
648 | np.transpose(np.floor(255 * np.arange(0, YG) / YG))
649 | colorwheel[col:col + YG, 1] = 255
650 | col += YG
651 | # GC
652 | colorwheel[col:col + GC, 1] = 255
653 | colorwheel[col:col + GC,
654 | 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
655 | col += GC
656 | # CB
657 | colorwheel[col:col + CB, 1] = 255 - \
658 | np.transpose(np.floor(255 * np.arange(0, CB) / CB))
659 | colorwheel[col:col + CB, 2] = 255
660 | col += CB
661 | # BM
662 | colorwheel[col:col + BM, 2] = 255
663 | colorwheel[col:col + BM,
664 | 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
665 | col += BM
666 | # MR
667 | colorwheel[col:col + MR, 2] = 255 - \
668 | np.transpose(np.floor(255 * np.arange(0, MR) / MR))
669 | colorwheel[col:col + MR, 0] = 255
670 | return colorwheel
671 |
672 | # ----------------------------------------------------------------------------
673 |
674 |
675 | def extract_image_patches(images, ksize, stride, rate, padding='auto'):
676 | """
677 | Extracts sliding local blocks \\
678 | see also: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
679 | """
680 |
681 | padding = rate*(ksize-1)//2 if padding == 'auto' else padding
682 |
683 | unfold = torch.nn.Unfold(kernel_size=ksize,
684 | dilation=rate,
685 | padding=padding,
686 | stride=stride)
687 | patches = unfold(images)
688 | return patches # [N, C*k*k, L], L is the total number of such blocks
689 |
690 | # ----------------------------------------------------------------------------
691 |
692 | #################################
693 | ######### DISCRIMINATOR #########
694 | #################################
695 |
696 | class Conv2DSpectralNorm(nn.Conv2d):
697 | """Convolution layer that applies Spectral Normalization before every call."""
698 |
699 | def __init__(self, cnum_in,
700 | cnum_out, kernel_size, stride, padding=0, n_iter=1, eps=1e-12, bias=True):
701 | super().__init__(cnum_in,
702 | cnum_out, kernel_size=kernel_size,
703 | stride=stride, padding=padding, bias=bias)
704 | self.register_buffer("weight_u", torch.empty(self.weight.size(0), 1))
705 | nn.init.trunc_normal_(self.weight_u)
706 | self.n_iter = n_iter
707 | self.eps = eps
708 |
709 | def l2_norm(self, x):
710 | return F.normalize(x, p=2, dim=0, eps=self.eps)
711 |
712 | def forward(self, x):
713 |
714 | weight_orig = self.weight.flatten(1).detach()
715 |
716 | for _ in range(self.n_iter):
717 | v = self.l2_norm(weight_orig.t() @ self.weight_u)
718 | self.weight_u = self.l2_norm(weight_orig @ v)
719 |
720 | sigma = self.weight_u.t() @ weight_orig @ v
721 | self.weight.data.div_(sigma)
722 |
723 | x = super().forward(x)
724 |
725 | return x
726 |
727 | # ----------------------------------------------------------------------------
728 |
729 | class DConv(nn.Module):
730 | def __init__(self, cnum_in,
731 | cnum_out, ksize=5, stride=2, padding='auto'):
732 | super().__init__()
733 | padding = (ksize-1)//2 if padding == 'auto' else padding
734 | self.conv_sn = Conv2DSpectralNorm(
735 | cnum_in, cnum_out, ksize, stride, padding)
736 | #self.conv_sn = spectral_norm(nn.Conv2d(cnum_in, cnum_out, ksize, stride, padding))
737 | self.leaky = nn.LeakyReLU(negative_slope=0.2)
738 |
739 | def forward(self, x):
740 | x = self.conv_sn(x)
741 | x = self.leaky(x)
742 | return x
743 |
744 | # ----------------------------------------------------------------------------
745 |
746 | class Discriminator(nn.Module):
747 | def __init__(self, cnum_in, cnum):
748 | super().__init__()
749 | self.conv1 = DConv(cnum_in, cnum)
750 | self.conv2 = DConv(cnum, 2*cnum)
751 | self.conv3 = DConv(2*cnum, 4*cnum)
752 | self.conv4 = DConv(4*cnum, 4*cnum)
753 | self.conv5 = DConv(4*cnum, 4*cnum)
754 | self.conv6 = DConv(4*cnum, 4*cnum)
755 |
756 | def forward(self, x):
757 | x = self.conv1(x)
758 | x = self.conv2(x)
759 | x = self.conv3(x)
760 | x = self.conv4(x)
761 | x = self.conv5(x)
762 | x = self.conv6(x)
763 | x = nn.Flatten()(x)
764 |
765 | return x
766 |
--------------------------------------------------------------------------------
/src/objRemove.py:
--------------------------------------------------------------------------------
1 |
2 | import copy
3 | import cv2
4 | import numpy as np
5 | import torchvision.transforms as T
6 | from torchvision.io import read_image
7 |
8 | class ObjectRemove():
9 |
10 | def __init__(self, segmentModel = None, rcnn_transforms = None, inpaintModel= None, image_path = '') -> None:
11 | self.segmentModel = segmentModel
12 | self.inpaintModel = inpaintModel
13 | self.rcnn_transforms = rcnn_transforms
14 | self.image_path = image_path
15 | self.highest_prob_mask = None
16 | self.image_orig = None
17 | self.image_masked = None
18 | self.box = None
19 |
20 | def run(self):
21 | '''
22 | Main run program
23 | '''
24 | #read in image and transform
25 | print('Reading in image')
26 | images = self.preprocess_image()
27 | self.image_orig = images
28 |
29 | print("segmentation")
30 | #segmentation
31 | output = self.segment(images)
32 | out = output[0]
33 |
34 | print('user click')
35 | #user click
36 | ref_points = self.user_click()
37 | self.box = ref_points
38 | self.highest_prob_mask = self.find_mask(out, ref_points)
39 |
40 | self.highest_prob_mask[self.highest_prob_mask >= 0.1] = 1
41 | self.highest_prob_mask[self.highest_prob_mask < 0.1] = 0
42 | self.image_masked = (images[0]*(1-self.highest_prob_mask))
43 | print('inpaint')
44 | #inpaint
45 | output = self.inpaint()
46 |
47 | #return final inpainted image
48 | return output
49 |
50 | def percent_within(self,nonzeros, rectangle):
51 | '''
52 | Calculates percent of mask inside rectangle
53 | '''
54 | rect_ul, rect_br = rectangle
55 | inside_count = 0
56 | for _,y,x in nonzeros:
57 | if x >= rect_ul[0] and x<= rect_br[0] and y <= rect_br[1] and y>= rect_ul[1]:
58 | inside_count+=1
59 | return inside_count / len(nonzeros)
60 |
61 | def iou(self, boxes_a, boxes_b):
62 | '''
63 | Calculates IOU between all pairs of boxes
64 |
65 | boxes_a and boxes_b are matrices with each row representing the 4 coords of a box
66 | '''
67 |
68 | x1 = np.array([boxes_a[:,0], boxes_b[:,0]]).max(axis=0)
69 | y1 = np.array([boxes_a[:,1], boxes_b[:,1]]).max(axis=0)
70 | x2 = np.array([boxes_a[:,2], boxes_b[:,2]]).min(axis=0)
71 | y2 = np.array([boxes_a[:,3], boxes_b[:,3]]).min(axis=0)
72 |
73 | w = x2-x1
74 | h = y2-y1
75 | w[w<0] = 0
76 | h[h<0] = 0
77 |
78 | intersect = w* h
79 |
80 | area_a = (boxes_a[:,2] - boxes_a[:,0]) * (boxes_a[:,3] - boxes_a[:,1])
81 | area_b = (boxes_b[:,2] - boxes_b[:,0]) * (boxes_b[:,3] - boxes_b[:,1])
82 |
83 | union = area_a + area_b - intersect
84 |
85 | return intersect / (union + 0.00001)
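# Worked example with hypothetical boxes: for boxes_a = [[0, 0, 10, 10]] and
# boxes_b = [[5, 5, 15, 15]] the overlap is 5*5 = 25, the union is
# 100 + 100 - 25 = 175, and the returned IOU is roughly 25/175 = 0.143.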
86 |
87 | def find_mask(self, rcnn_output, rectangle):
88 | '''
89 | Finds the mask with highest probability in the rectangle given
90 |
91 | '''
92 | bounding_boxes= rcnn_output['boxes'].detach().numpy()
93 | masks = rcnn_output['masks']
94 |
95 | ref_boxes = np.array([rectangle], dtype=object)
96 | ref_boxes = np.repeat(ref_boxes, bounding_boxes.shape[0], axis=0)
97 |
98 | ious= self.iou(ref_boxes, bounding_boxes)
99 |
100 | best_ind = np.argmax(ious)
101 |
102 | return masks[best_ind]
103 |
104 |
105 | #compare masks pixelwise
106 | '''
107 | masks = rcnn_output['masks']
108 | #go through each nonzero point in the mask and count how many points are within the rectangles
109 | highest_prob_mask = None
110 | percent_within,min_diff = 0,float('inf')
111 | #print('masks lenght:', len(masks))
112 |
113 |
114 | for m in range(len(masks)):
115 | #masks[m][masks[m] > 0.5] = 255.0
116 | #masks[m][masks[m] < 0.5] = 0.0
117 | nonzeros = np.nonzero(masks[m])
118 | #diff = rect_area - len(nonzeros)
119 | p = self.percent_within(nonzeros, rectangle)
120 | if p > percent_within:
121 | highest_prob_mask = masks[m]
122 | percent_within = p
123 | print(p)
124 | return highest_prob_mask
125 | '''
126 |
127 | def preprocess_image(self):
128 | '''
129 | Read in image and prepare for segmentation
130 | '''
131 | img= [read_image(self.image_path)]
132 | _,h,w = img[0].shape
133 | size = min(h,w)
134 | if size > 512:
135 | img[0] = T.Resize(512, max_size=680, antialias=True)(img[0])
136 |
137 | images_transformed = [self.rcnn_transforms(d) for d in img]
138 | return images_transformed
139 |
140 |
141 | def segment(self,images):
142 | out = self.segmentModel(images)
143 | return out
144 |
145 | def user_click(self):
146 | '''
147 | Get user input for the object to remove
148 |
149 | Returns the bounding box drawn by the user as [x1, y1, x2, y2]
150 | '''
151 | ref_point = []
152 | cache=None
153 | draw = False
154 |
155 |
156 | def click(event, x, y, flags, param):
157 | nonlocal ref_point,cache,img, draw
158 | if event == cv2.EVENT_LBUTTONDOWN:
159 | draw = True
160 | ref_point = [x, y]
161 | cache = copy.deepcopy(img)
162 |
163 | elif event == cv2.EVENT_MOUSEMOVE:
164 | if draw:
165 | img = copy.deepcopy(cache)
166 | cv2.rectangle(img, (ref_point[0], ref_point[1]), (x,y), (0, 255, 0), 2)
167 | cv2.imshow('image',img)
168 |
169 |
170 | elif event == cv2.EVENT_LBUTTONUP:
171 | draw = False
172 | #complete the box as [x1, y1, x2, y2]
173 | ref_point += [x, y]
174 | cv2.rectangle(img, (ref_point[0], ref_point[1]), (ref_point[2], ref_point[3]), (0, 255, 0), 2)
175 | cv2.imshow("image", img)
176 |
177 |
178 | img = self.image_orig[0].permute(1,2,0).numpy()
179 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
180 | clone = img.copy()
181 |
182 | cv2.namedWindow("image")
183 |
184 | cv2.setMouseCallback('image', click)
185 |
186 | while True:
187 | cv2.imshow("image", img)
188 | key = cv2.waitKey(1) & 0xFF
189 |
190 | if key == ord("r"):
191 | img = clone.copy()
192 |
193 | elif key == ord("c"):
194 | break
195 | cv2.destroyAllWindows()
196 |
197 | return ref_point
198 |
199 | def inpaint(self):
200 | output = self.inpaintModel.infer(self.image_orig[0], self.highest_prob_mask, return_vals=['inpainted'])
201 | return output[0]
202 |
203 |
204 |
205 |
--------------------------------------------------------------------------------
/test_imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/.DS_Store
--------------------------------------------------------------------------------
/test_imgs/bikes.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/bikes.jpg
--------------------------------------------------------------------------------
/test_imgs/park.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/park.jpeg
--------------------------------------------------------------------------------
/test_imgs/test1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/treeebooor/object-remove/721364fe8c2634b81be9289ed8d27017520e05b2/test_imgs/test1.jpeg
--------------------------------------------------------------------------------