├── utils
│   ├── __init__.py
│   ├── region_fill.py
│   ├── Poisson_blend.py
│   └── Poisson_blend_img.py
├── edgeconnect
│   ├── __init__.py
│   ├── metrics.py
│   ├── config.py
│   ├── region_fill.py
│   ├── utils.py
│   ├── loss.py
│   ├── networks.py
│   ├── models.py
│   └── dataset.py
├── models
│   ├── DeepFill_Models
│   │   ├── util.py
│   │   ├── __init__.py
│   │   ├── DeepFill.py
│   │   └── ops.py
│   └── __init__.py
├── .gitignore
├── RAFT
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── frame_utils.py
│   │   ├── flow_viz.py
│   │   └── augmentor.py
│   ├── demo.py
│   ├── corr.py
│   ├── raft.py
│   ├── update.py
│   ├── extractor.py
│   └── datasets.py
├── requirements.txt
├── download_data_weights.sh
├── tool
│   ├── spatial_inpaint.py
│   ├── frame_inpaint.py
│   └── video_completion_modified.py
├── LICENSE
├── FAQ.md
├── evaluate.py
└── README.md
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/edgeconnect/__init__.py:
--------------------------------------------------------------------------------
1 | # empty
--------------------------------------------------------------------------------
/models/DeepFill_Models/util.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/DeepFill_Models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | *.pyc
3 | data
4 | weight
5 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .DeepFill_Models import DeepFill
2 |
--------------------------------------------------------------------------------
/RAFT/__init__.py:
--------------------------------------------------------------------------------
1 | # from .demo import RAFT_infer
2 | from .raft import RAFT
3 |
--------------------------------------------------------------------------------
/RAFT/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_viz import flow_to_image
2 | from .frame_utils import writeFlow
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | opencv-contrib-python
3 | imageio
4 | imageio-ffmpeg
5 | scipy
6 | scikit-image
7 |
--------------------------------------------------------------------------------
/download_data_weights.sh:
--------------------------------------------------------------------------------
1 |
2 | wget https://filebox.ece.vt.edu/~chengao/FGVC/data.zip
3 | unzip data.zip
4 | rm data.zip
5 |
6 | wget https://filebox.ece.vt.edu/~chengao/FGVC/weight.zip
7 | unzip weight.zip
8 | rm weight.zip
9 |
--------------------------------------------------------------------------------
/tool/spatial_inpaint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def spatial_inpaint(deepfill, mask, video_comp):
9 |
10 | keyFrameInd = np.argmax(np.sum(np.sum(mask, axis=0), axis=0))
11 | with torch.no_grad():
12 | img_res = deepfill.forward(video_comp[:, :, :, keyFrameInd] * 255., mask[:, :, keyFrameInd]) / 255.
13 | video_comp[mask[:, :, keyFrameInd], :, keyFrameInd] = img_res[mask[:, :, keyFrameInd], :]
14 | mask[:, :, keyFrameInd] = False
15 |
16 | return mask, video_comp
17 |
--------------------------------------------------------------------------------
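A minimal usage sketch for spatial_inpaint, assuming the repository root is on PYTHONPATH, a CUDA device, and a DeepFill checkpoint at a placeholder path. The array layout follows the indexing above: mask is H x W x T (bool), video_comp is H x W x 3 x T (float in [0, 1]), with H x W matching DeepFillv1's default image_shape of 512 x 960.

import numpy as np
from tool.frame_inpaint import DeepFillv1
from tool.spatial_inpaint import spatial_inpaint

# Placeholder weight path; any pretrained DeepFill checkpoint works here.
deepfill = DeepFillv1(pretrained_model='weight/imagenet_deepfill.pth')

H, W, T = 512, 960, 10                        # H x W must match DeepFillv1's image_shape
mask = np.zeros((H, W, T), dtype=bool)        # per-frame holes still to be filled
mask[100:200, 100:200, :] = True
video_comp = np.random.rand(H, W, 3, T).astype(np.float32)   # partially completed video

# Fills the frame with the largest remaining hole and clears its mask.
mask, video_comp = spatial_inpaint(deepfill, mask, video_comp)
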
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2020 Virginia Tech Vision and Learning Lab
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | --------------------------- LICENSE FOR EdgeConnect --------------------------------
25 |
26 | Attribution-NonCommercial 4.0 International
27 |
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | Q: video_0046 in the val set contains some all-white masks?
2 | A: The last few frames of video_0046 in the val set have masks that are mistakenly all white because the target moves out of frame. You can simply ignore or delete them.
3 |
4 | Q: Is online updating allowed?
5 | A: Yes.
6 |
7 | Q: Can you provide masks?
8 | A: Masks are not provided for the training set; during training, participants can cut out random masks, as many as they like.
9 |
10 | Q: Does the 8 GB memory / GPU-memory limit apply to training, or only to inference?
11 | A: The 8 GB limit applies to inference. There is no requirement for training, as long as it is not so large that the results cannot be reproduced.
12 |
13 | Q: Can the test-set masks also be used for training?
14 | A: Yes, the test-set masks can be used; val and test_a use the same set of watermark and dancer templates, so there is no real difference. However, the test-set videos must not be used for training.
15 |
16 | Q: The first few frames of video 0095 in test_a are completely black. Is that normal?
17 | A: That is fine; the original frames are also black, so it has no effect.
18 |
19 | Q: May I run a main thread that measures memory and runtime in real time, while a child thread does the actual computation?
20 | A: Yes. Doing the actual computation in a separate thread is allowed.
21 |
22 | Q: How small does the reproduction error have to be?
23 | A: We will reproduce the results from the submitted code. If the reproduction error is below 0.5%, the ranking is unchanged; otherwise the ranking will be corrected. If the error is too large and the participant cannot give a reasonable explanation, the result will be disqualified.
24 |
25 | Q: Can I simply tell you which method I use for which ids?
26 | A: No. You may not directly assign a special method to specific ids, but you may use an algorithm to decide which method to apply.
27 |
28 | Q: I trained several different models with the same method. May I pick the best combination, e.g. one model that works well on one type of mask in the validation set and another that works well on a different type, and submit their combined results?
29 | A: Same as above. You may not directly assign a special model to specific ids, but you may use an algorithm to decide which model to apply.
30 |
31 | Q: Does "single GPU" mean running on the server or locally? A V100 and a 3090 may run different models at different speeds.
32 | A: On the server; the GPU will be a V100 or a 3090. The 15-hour limit is deliberately loose, but participants should not aim to use all of it. We do not compare speed; staying within the limit is enough.
33 |
34 | Q: Does 8G mean 8000 MB or 8192 MB? Must memory stay under 8 GB at every moment during inference?
35 | A: 8192 MB. During testing we will give participants a server with slightly more than 8 GB.
36 |
37 | Q: The pre-trained models we use must be from before June, right? Is there a GitHub-star requirement?
38 | A: Pre-trained models must have been publicly released before June 1, 2021. There is no star requirement.
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/edgeconnect/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class EdgeAccuracy(nn.Module):
6 | """
7 | Measures the accuracy of the edge map
8 | """
9 | def __init__(self, threshold=0.5):
10 | super(EdgeAccuracy, self).__init__()
11 | self.threshold = threshold
12 |
13 | def __call__(self, inputs, outputs):
14 | labels = (inputs > self.threshold)
15 | outputs = (outputs > self.threshold)
16 |
17 | relevant = torch.sum(labels.float())
18 | selected = torch.sum(outputs.float())
19 |
20 | if relevant == 0 and selected == 0:
21 | return torch.tensor(1), torch.tensor(1)
22 |
23 | true_positive = ((outputs == labels) * labels).float()
24 | recall = torch.sum(true_positive) / (relevant + 1e-8)
25 | precision = torch.sum(true_positive) / (selected + 1e-8)
26 |
27 | return precision, recall
28 |
29 |
30 | class PSNR(nn.Module):
31 | def __init__(self, max_val):
32 | super(PSNR, self).__init__()
33 |
34 | base10 = torch.log(torch.tensor(10.0))
35 | max_val = torch.tensor(max_val).float()
36 |
37 | self.register_buffer('base10', base10)
38 | self.register_buffer('max_val', 20 * torch.log(max_val) / base10)
39 |
40 | def __call__(self, a, b):
41 | mse = torch.mean((a.float() - b.float()) ** 2)
42 |
43 | if mse == 0:
44 | return torch.tensor(0)
45 |
46 | return self.max_val - 10 * torch.log(mse) / self.base10
47 |
--------------------------------------------------------------------------------
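A brief usage sketch for the metrics above on dummy tensors, assuming the repository root is on PYTHONPATH:

import torch
from edgeconnect.metrics import EdgeAccuracy, PSNR

psnr = PSNR(255.0)
edge_acc = EdgeAccuracy(threshold=0.5)

a = torch.rand(1, 3, 256, 256) * 255
b = torch.rand(1, 3, 256, 256) * 255
print(psnr(a, b))                              # PSNR in dB between the two images

gt_edges = (torch.rand(1, 1, 256, 256) > 0.9).float()
pred_edges = torch.rand(1, 1, 256, 256)
precision, recall = edge_acc(gt_edges, pred_edges)
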
/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os
4 | import cv2
5 | import glob
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 | from .raft import RAFT
11 | from .utils import flow_viz
12 | from .utils.utils import InputPadder
13 |
14 |
15 |
16 | DEVICE = 'cuda'
17 |
18 | def load_image(imfile):
19 | img = np.array(Image.open(imfile)).astype(np.uint8)
20 | img = torch.from_numpy(img).permute(2, 0, 1).float()
21 | return img
22 |
23 |
24 | def load_image_list(image_files):
25 | images = []
26 | for imfile in sorted(image_files):
27 | images.append(load_image(imfile))
28 |
29 | images = torch.stack(images, dim=0)
30 | images = images.to(DEVICE)
31 |
32 | padder = InputPadder(images.shape)
33 | return padder.pad(images)[0]
34 |
35 |
36 | def viz(img, flo):
37 | img = img[0].permute(1,2,0).cpu().numpy()
38 | flo = flo[0].permute(1,2,0).cpu().numpy()
39 |
40 | # map flow to rgb image
41 | flo = flow_viz.flow_to_image(flo)
42 | # img_flo = np.concatenate([img, flo], axis=0)
43 | img_flo = flo
44 |
45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
47 | # cv2.waitKey()
48 |
49 |
50 | def demo(args):
51 | model = torch.nn.DataParallel(RAFT(args))
52 | model.load_state_dict(torch.load(args.model))
53 |
54 | model = model.module
55 | model.to(DEVICE)
56 | model.eval()
57 |
58 | with torch.no_grad():
59 | images = glob.glob(os.path.join(args.path, '*.png')) + \
60 | glob.glob(os.path.join(args.path, '*.jpg'))
61 |
62 | images = load_image_list(images)
63 | for i in range(images.shape[0]-1):
64 | image1 = images[i,None]
65 | image2 = images[i+1,None]
66 |
67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
68 | viz(image1, flow_up)
69 |
70 |
71 | def RAFT_infer(args):
72 | model = torch.nn.DataParallel(RAFT(args))
73 | model.load_state_dict(torch.load(args.model))
74 |
75 | model = model.module
76 | model.to(DEVICE)
77 | model.eval()
78 |
79 | return model
80 |
--------------------------------------------------------------------------------
/RAFT/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
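A short usage sketch for InputPadder, assuming a Sintel-sized frame pair; the model call is replaced by a stand-in tensor:

import torch
from RAFT.utils.utils import InputPadder

image1 = torch.zeros(1, 3, 436, 1024)         # 436 is not divisible by 8
image2 = torch.zeros(1, 3, 436, 1024)
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)   # both padded to 1 x 3 x 440 x 1024

flow = torch.zeros(1, 2, 440, 1024)           # stand-in for the model output on the padded pair
flow = padder.unpad(flow)                     # cropped back to 1 x 2 x 436 x 1024
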
/edgeconnect/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | class Config(dict):
5 | def __init__(self, config_path):
6 | with open(config_path, 'r') as f:
7 | self._yaml = f.read()
8 | self._dict = yaml.safe_load(self._yaml)
9 | self._dict['PATH'] = os.path.dirname(config_path)
10 |
11 | def __getattr__(self, name):
12 | if self._dict.get(name) is not None:
13 | return self._dict[name]
14 |
15 | if DEFAULT_CONFIG.get(name) is not None:
16 | return DEFAULT_CONFIG[name]
17 |
18 | return None
19 |
20 | def print(self):
21 | print('Model configurations:')
22 | print('---------------------------------')
23 | print(self._yaml)
24 | print('')
25 | print('---------------------------------')
26 | print('')
27 |
28 |
29 | DEFAULT_CONFIG = {
30 | 'MODE': 1, # 1: train, 2: test, 3: eval
31 | 'MODEL': 1, # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
32 | 'MASK': 3, # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
33 | 'EDGE': 1, # 1: canny, 2: external
34 | 'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
35 | 'SEED': 10, # random seed
36 | 'GPU': [0], # list of gpu ids
37 | 'DEBUG': 0, # turns on debugging mode
38 | 'VERBOSE': 0, # turns on verbose mode in the output console
39 |
40 | 'LR': 0.0001, # learning rate
41 | 'D2G_LR': 0.1, # discriminator/generator learning rate ratio
42 | 'BETA1': 0.0, # adam optimizer beta1
43 | 'BETA2': 0.9, # adam optimizer beta2
44 | 'BATCH_SIZE': 8, # input batch size for training
45 | 'INPUT_SIZE': 256, # input image size for training 0 for original size
46 | 'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
47 | 'MAX_ITERS': 2e6, # maximum number of iterations to train the model
48 |
49 | 'EDGE_THRESHOLD': 0.5, # edge detection threshold
50 | 'L1_LOSS_WEIGHT': 1, # l1 loss weight
51 | 'FM_LOSS_WEIGHT': 10, # feature-matching loss weight
52 | 'STYLE_LOSS_WEIGHT': 1, # style loss weight
53 | 'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight
54 | 'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight
55 |
56 | 'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge
57 | 'GAN_POOL_SIZE': 0, # fake images pool size
58 |
59 | 'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never)
60 | 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never)
61 | 'SAMPLE_SIZE': 12, # number of images to sample
62 | 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never)
63 | 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never)
64 | }
65 |
--------------------------------------------------------------------------------
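A minimal usage sketch for Config, assuming a hypothetical config.yml written on the fly; any key omitted from the file falls back to DEFAULT_CONFIG above:

from edgeconnect.config import Config

with open('config.yml', 'w') as f:            # hypothetical config file
    f.write('MODE: 2\nMODEL: 3\nINPUT_SIZE: 256\n')

config = Config('config.yml')
print(config.MODE)          # 2, read from the file
print(config.BATCH_SIZE)    # 8, falls back to DEFAULT_CONFIG
config.print()
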
/models/DeepFill_Models/DeepFill.py:
--------------------------------------------------------------------------------
1 | from .ops import *
2 |
3 |
4 | class Generator(nn.Module):
5 | def __init__(self, first_dim=32, isCheck=False, device=None):
6 | super(Generator, self).__init__()
7 | self.isCheck = isCheck
8 | self.device = device
9 | self.stage_1 = CoarseNet(5, first_dim, device=device)
10 | self.stage_2 = RefinementNet(5, first_dim, device=device)
11 |
12 | def forward(self, masked_img, mask, small_mask): # mask : 1 x 1 x H x W
13 |
14 | # broadcast the masks across the batch dimension
15 | mask = mask.expand(masked_img.size(0),1,masked_img.size(2),masked_img.size(3))
16 | small_mask = small_mask.expand(masked_img.size(0), 1, masked_img.size(2) // 8, masked_img.size(3) // 8)
17 | if self.device:
18 | ones = to_var(torch.ones(mask.size()), device=self.device)
19 | else:
20 | ones = to_var(torch.ones(mask.size()))
21 | # stage1
22 | stage1_input = torch.cat([masked_img, ones, ones*mask], dim=1)
23 | stage1_output, resized_mask = self.stage_1(stage1_input, mask)
24 | # stage2
25 | new_masked_img = stage1_output*mask.clone() + masked_img.clone()*(1.-mask.clone())
26 | stage2_input = torch.cat([new_masked_img, ones.clone(), ones.clone()*mask.clone()], dim=1)
27 | stage2_output, offset_flow = self.stage_2(stage2_input, small_mask)
28 |
29 | return stage1_output, stage2_output, offset_flow
30 |
31 |
32 | class CoarseNet(nn.Module):
33 | '''
34 | # input: B x 5 x W x H
35 | # after down: B x 128(32*4) x W/4 x H/4
36 | # after atrous: same with the output size of the down module
37 | # after up : same with the input size
38 | '''
39 | def __init__(self, in_ch, out_ch, device=None):
40 | super(CoarseNet,self).__init__()
41 | self.down = Down_Module(in_ch, out_ch)
42 | self.atrous = Dilation_Module(out_ch*4, out_ch*4)
43 | self.up = Up_Module(out_ch*4, 3)
44 | self.device=device
45 |
46 | def forward(self, x, mask):
47 | x = self.down(x)
48 | resized_mask = down_sample(mask, scale_factor=0.25, mode='nearest', device=self.device)
49 | x = self.atrous(x)
50 | x = self.up(x)
51 |
52 | return x, resized_mask
53 |
54 |
55 | class RefinementNet(nn.Module):
56 | '''
57 | # input: B x 5 x W x H
58 | # after down: B x 128(32*4) x W/4 x H/4
59 | # after atrous: same with the output size of the down module
60 | # after up : same with the input size
61 | '''
62 | def __init__(self, in_ch, out_ch, device=None):
63 | super(RefinementNet,self).__init__()
64 | self.down_conv_branch = Down_Module(in_ch, out_ch, isRefine=True)
65 | self.down_attn_branch = Down_Module(in_ch, out_ch, activation=nn.ReLU(), isRefine=True, isAttn=True)
66 | self.atrous = Dilation_Module(out_ch*4, out_ch*4)
67 | self.CAttn = Contextual_Attention_Module(out_ch*4, out_ch*4, device=device)
68 | self.up = Up_Module(out_ch*8, 3, isRefine=True)
69 |
70 | def forward(self, x, resized_mask):
71 | # conv branch
72 | conv_x = self.down_conv_branch(x)
73 | conv_x = self.atrous(conv_x)
74 |
75 | # attention branch
76 | attn_x = self.down_attn_branch(x)
77 |
78 | attn_x, offset_flow = self.CAttn(attn_x, attn_x, mask=resized_mask)
79 |
80 | # concat two branches
81 | deconv_x = torch.cat([conv_x, attn_x], dim=1) # deconv_x => B x 256 x W/4 x H/4
82 | x = self.up(deconv_x)
83 |
84 | return x, offset_flow
85 |
--------------------------------------------------------------------------------
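A dummy forward pass through the two-stage generator, mirroring how tool/frame_inpaint.py instantiates it; this sketch assumes a CUDA device and the layer definitions in ops.py (not shown here). mask is full resolution, small_mask is the 1/8-resolution mask used by the contextual-attention branch.

import torch
from models import DeepFill

device = torch.device('cuda:0')
net = DeepFill.Generator().to(device).eval()

masked_img = torch.zeros(1, 3, 512, 960).to(device)
mask = torch.zeros(1, 1, 512, 960).to(device)
small_mask = torch.zeros(1, 1, 512 // 8, 960 // 8).to(device)

with torch.no_grad():
    coarse, refined, offset_flow = net(masked_img, mask, small_mask)
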
/RAFT/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class CorrLayer(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, fmap1, fmap2, coords, r):
66 | fmap1 = fmap1.contiguous()
67 | fmap2 = fmap2.contiguous()
68 | coords = coords.contiguous()
69 | ctx.save_for_backward(fmap1, fmap2, coords)
70 | ctx.r = r
71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
72 | return corr
73 |
74 | @staticmethod
75 | def backward(ctx, grad_corr):
76 | fmap1, fmap2, coords = ctx.saved_tensors
77 | grad_corr = grad_corr.contiguous()
78 | fmap1_grad, fmap2_grad, coords_grad = \
79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
80 | return fmap1_grad, fmap2_grad, coords_grad, None
81 |
82 |
83 | class AlternateCorrBlock:
84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
85 | self.num_levels = num_levels
86 | self.radius = radius
87 |
88 | self.pyramid = [(fmap1, fmap2)]
89 | for i in range(self.num_levels):
90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
92 | self.pyramid.append((fmap1, fmap2))
93 |
94 | def __call__(self, coords):
95 |
96 | coords = coords.permute(0, 2, 3, 1)
97 | B, H, W, _ = coords.shape
98 |
99 | corr_list = []
100 | for i in range(self.num_levels):
101 | r = self.radius
102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
104 |
105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
107 | corr_list.append(corr.squeeze(1))
108 |
109 | corr = torch.stack(corr_list, dim=1)
110 | corr = corr.reshape(B, -1, H, W)
111 | return corr / 16.0
112 |
--------------------------------------------------------------------------------
/RAFT/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
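A small round-trip sketch for the .flo helpers above (the file name is a placeholder):

import numpy as np
from RAFT.utils.frame_utils import readFlow, writeFlow

flow = np.random.rand(120, 160, 2).astype(np.float32)
writeFlow('example.flo', flow)                # placeholder output path
restored = readFlow('example.flo')            # H x W x 2 float32
assert np.allclose(flow, restored)
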
/tool/frame_inpaint.py:
--------------------------------------------------------------------------------
1 | import sys, os, argparse
2 | sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
3 | import torch
4 | import numpy as np
5 | import cv2
6 |
7 | from models import DeepFill
8 |
9 |
10 | class DeepFillv1(object):
11 | def __init__(self,
12 | pretrained_model=None,
13 | image_shape=[512, 960],
14 | res_shape=None,
15 | device=torch.device('cuda:0')):
16 | self.image_shape = image_shape
17 | self.res_shape = res_shape
18 | self.device = device
19 |
20 | self.deepfill = DeepFill.Generator().to(device)
21 | model_weight = torch.load(pretrained_model)
22 | self.deepfill.load_state_dict(model_weight, strict=True)
23 | self.deepfill.eval()
24 | print('Load Deepfill Model from', pretrained_model)
25 |
26 | def forward(self, img, mask):
27 |
28 | img, mask, small_mask = self.data_preprocess(img, mask, size=self.image_shape)
29 |
30 | image = torch.stack([img])
31 | mask = torch.stack([mask])
32 | small_mask = torch.stack([small_mask])
33 |
34 | with torch.no_grad():
35 | _, inpaint_res, _ = self.deepfill(image.to(self.device), mask.to(self.device), small_mask.to(self.device))
36 |
37 | res_complete = self.data_postprocess(image, mask, inpaint_res)
38 |
39 | return res_complete
40 |
41 | def data_preprocess(self, img, mask, enlarge_kernel=0, size=[512, 960]):
42 | img = img / 127.5 - 1
43 | mask = (mask > 0).astype(np.uint8)
44 | img = cv2.resize(img, (size[1], size[0]))
45 | if enlarge_kernel > 0:
46 | kernel = np.ones((enlarge_kernel, enlarge_kernel), np.uint8)
47 | mask = cv2.dilate(mask, kernel, iterations=1)
48 | mask = (mask > 0).astype(np.uint8)
49 |
50 | small_mask = cv2.resize(mask, (size[1] // 8, size[0] // 8), interpolation=cv2.INTER_NEAREST)
51 | mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST)
52 |
53 | if len(mask.shape) == 3:
54 | mask = mask[:, :, 0:1]
55 | else:
56 | mask = np.expand_dims(mask, axis=2)
57 |
58 | if len(small_mask.shape) == 3:
59 | small_mask = small_mask[:, :, 0:1]
60 | else:
61 | small_mask = np.expand_dims(small_mask, axis=2)
62 |
63 | img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
64 | mask = torch.from_numpy(mask).permute(2, 0, 1).contiguous().float()
65 | small_mask = torch.from_numpy(small_mask).permute(2, 0, 1).contiguous().float()
66 |
67 | return img*(1-mask), mask, small_mask
68 |
69 | def data_postprocess(self, img, mask, res):
70 | img = img.cpu().data.numpy()[0]
71 | mask = mask.data.numpy()[0]
72 | res = res.cpu().data.numpy()[0]
73 |
74 | res_complete = res * mask + img * (1. - mask)
75 | res_complete = (res_complete + 1) * 127.5
76 | res_complete = res_complete.transpose(1, 2, 0)
77 | if self.res_shape is not None:
78 | res_complete = cv2.resize(res_complete,
79 | (self.res_shape[1], self.res_shape[0]))
80 |
81 | return res_complete
82 |
83 |
84 | def parse_args():
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument('--image_shape', type=int, nargs='+',
87 | default=[512, 960])
88 | parser.add_argument('--res_shape', type=int, nargs='+',
89 | default=None)
90 | parser.add_argument('--pretrained_model', type=str,
91 | default='/home/chengao/Weight/imagenet_deepfill.pth')
92 | parser.add_argument('--test_img', type=str,
93 | default='/work/cascades/chengao/DAVIS-540/bear_540p/00000.png')
94 | parser.add_argument('--test_mask', type=str,
95 | default='/work/cascades/chengao/DAVIS-540-baseline/mask_540p.png')
96 | parser.add_argument('--output_path', type=str,
97 | default='/home/chengao/res_00000.png')
98 |
99 | args = parser.parse_args()
100 |
101 | return args
102 |
103 |
104 | def main():
105 |
106 | args = parse_args()
107 |
108 | deepfill = DeepFillv1(pretrained_model=args.pretrained_model,
109 | image_shape=args.image_shape,
110 | res_shape=args.res_shape)
111 |
112 | test_image = cv2.imread(args.test_img)
113 | mask = cv2.imread(args.test_mask, cv2.IMREAD_UNCHANGED)
114 |
115 | with torch.no_grad():
116 | img_res = deepfill.forward(test_image, mask)
117 |
118 | cv2.imwrite(args.output_path, img_res)
119 | print('Result Saved')
120 |
121 |
122 | if __name__ == '__main__':
123 | main()
124 |
--------------------------------------------------------------------------------
/RAFT/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 | Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 | Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
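A short sketch that visualizes a synthetic flow field with flow_to_image; the radial flow below is only an assumption for illustration:

import numpy as np
from RAFT.utils.flow_viz import flow_to_image

h, w = 240, 320
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([x - w / 2, y - h / 2], axis=-1)   # H x W x 2, flow pointing away from the center
rgb = flow_to_image(flow)                          # H x W x 3 uint8, ready for cv2/imageio
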
/utils/region_fill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from scipy import sparse
4 | from scipy.sparse.linalg import spsolve
5 |
6 |
7 | def regionfill(I, mask, factor=1.0):
8 | if np.count_nonzero(mask) == 0:
9 | return I.copy()
10 | resize_mask = cv2.resize(
11 | mask.astype(float), (0, 0), fx=factor, fy=factor) > 0
12 | resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor)
13 | maskPerimeter = findBoundaryPixels(resize_mask)
14 | regionfillLaplace(resize_I, resize_mask, maskPerimeter)
15 | resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0]))
16 | resize_I[mask == 0] = I[mask == 0]
17 | return resize_I
18 |
19 |
20 | def findBoundaryPixels(mask):
21 | kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
22 | maskDilated = cv2.dilate(mask.astype(float), kernel)
23 | return (maskDilated > 0) & (mask == 0)
24 |
25 |
26 | def regionfillLaplace(I, mask, maskPerimeter):
27 | height, width = I.shape
28 | rightSide = formRightSide(I, maskPerimeter)
29 |
30 | # Location of mask pixels
31 | maskIdx = np.where(mask)
32 |
33 | # Only keep values for pixels that are in the mask
34 | rightSide = rightSide[maskIdx]
35 |
36 | # Number the mask pixels in a grid matrix
37 | grid = -np.ones((height, width))
38 | grid[maskIdx] = range(0, maskIdx[0].size)
39 | # Pad with zeros to avoid "index out of bounds" errors in the for loop
40 | grid = padMatrix(grid)
41 | gridIdx = np.where(grid >= 0)
42 |
43 | # Form the connectivity matrix D=sparse(i,j,s)
44 | # Connect each mask pixel to itself
45 | i = np.arange(0, maskIdx[0].size)
46 | j = np.arange(0, maskIdx[0].size)
47 | # The coefficient is the number of neighbors over which we average
48 | numNeighbors = computeNumberOfNeighbors(height, width)
49 | s = numNeighbors[maskIdx]
50 | # Now connect the N,E,S,W neighbors if they exist
51 | for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)):
52 | # Possible neighbors in the current direction
53 | neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]]
54 | # Connect mask points to neighbors with -1's
55 | index = (neighbors >= 0)
56 | i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]]))
57 | j = np.concatenate((j, neighbors[index]))
58 | s = np.concatenate((s, -np.ones(np.count_nonzero(index))))
59 |
60 | D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr()
61 | sol = spsolve(D, rightSide)
62 | I[maskIdx] = sol
63 | return I
64 |
65 |
66 | def formRightSide(I, maskPerimeter):
67 | height, width = I.shape
68 | perimeterValues = np.zeros((height, width))
69 | perimeterValues[maskPerimeter] = I[maskPerimeter]
70 | rightSide = np.zeros((height, width))
71 |
72 | rightSide[1:height - 1, 1:width - 1] = (
73 | perimeterValues[0:height - 2, 1:width - 1] +
74 | perimeterValues[2:height, 1:width - 1] +
75 | perimeterValues[1:height - 1, 0:width - 2] +
76 | perimeterValues[1:height - 1, 2:width])
77 |
78 | rightSide[1:height - 1, 0] = (
79 | perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] +
80 | perimeterValues[1:height - 1, 1])
81 |
82 | rightSide[1:height - 1, width - 1] = (
83 | perimeterValues[0:height - 2, width - 1] +
84 | perimeterValues[2:height, width - 1] +
85 | perimeterValues[1:height - 1, width - 2])
86 |
87 | rightSide[0, 1:width - 1] = (
88 | perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] +
89 | perimeterValues[0, 2:width])
90 |
91 | rightSide[height - 1, 1:width - 1] = (
92 | perimeterValues[height - 2, 1:width - 1] +
93 | perimeterValues[height - 1, 0:width - 2] +
94 | perimeterValues[height - 1, 2:width])
95 |
96 | rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0]
97 | rightSide[0, width - 1] = (
98 | perimeterValues[0, width - 2] + perimeterValues[1, width - 1])
99 | rightSide[height - 1, 0] = (
100 | perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1])
101 | rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] +
102 | perimeterValues[height - 1, width - 2])
103 | return rightSide
104 |
105 |
106 | def computeNumberOfNeighbors(height, width):
107 | # Initialize
108 | numNeighbors = np.zeros((height, width))
109 | # Interior pixels have 4 neighbors
110 | numNeighbors[1:height - 1, 1:width - 1] = 4
111 | # Border pixels have 3 neighbors
112 | numNeighbors[1:height - 1, (0, width - 1)] = 3
113 | numNeighbors[(0, height - 1), 1:width - 1] = 3
114 | # Corner pixels have 2 neighbors
115 | numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0,
116 | width - 1)] = 2
117 | return numNeighbors
118 |
119 |
120 | def padMatrix(grid):
121 | height, width = grid.shape
122 | gridPadded = -np.ones((height + 2, width + 2))
123 | gridPadded[1:height + 1, 1:width + 1] = grid
124 | gridPadded = gridPadded.astype(grid.dtype)
125 | return gridPadded
126 |
--------------------------------------------------------------------------------
/RAFT/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in args._get_kwargs():
42 | args.dropout = 0
43 |
44 | if 'alternate_corr' not in args._get_kwargs():
45 | args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 |
59 | def freeze_bn(self):
60 | for m in self.modules():
61 | if isinstance(m, nn.BatchNorm2d):
62 | m.eval()
63 |
64 | def initialize_flow(self, img):
65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
66 | N, C, H, W = img.shape
67 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
68 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
69 |
70 | # optical flow computed as difference: flow = coords1 - coords0
71 | return coords0, coords1
72 |
73 | def upsample_flow(self, flow, mask):
74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
75 | N, _, H, W = flow.shape
76 | mask = mask.view(N, 1, 9, 8, 8, H, W)
77 | mask = torch.softmax(mask, dim=2)
78 |
79 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
81 |
82 | up_flow = torch.sum(mask * up_flow, dim=2)
83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
84 | return up_flow.reshape(N, 2, 8*H, 8*W)
85 |
86 |
87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
88 | """ Estimate optical flow between pair of frames """
89 |
90 | image1 = 2 * (image1 / 255.0) - 1.0
91 | image2 = 2 * (image2 / 255.0) - 1.0
92 |
93 | image1 = image1.contiguous()
94 | image2 = image2.contiguous()
95 |
96 | hdim = self.hidden_dim
97 | cdim = self.context_dim
98 |
99 | # run the feature network
100 | with autocast(enabled=self.args.mixed_precision):
101 | fmap1, fmap2 = self.fnet([image1, image2])
102 |
103 | fmap1 = fmap1.float()
104 | fmap2 = fmap2.float()
105 | if self.args.alternate_corr:
106 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
107 | else:
108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
109 |
110 | # run the context network
111 | with autocast(enabled=self.args.mixed_precision):
112 | cnet = self.cnet(image1)
113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
114 | net = torch.tanh(net)
115 | inp = torch.relu(inp)
116 |
117 | coords0, coords1 = self.initialize_flow(image1)
118 |
119 | if flow_init is not None:
120 | coords1 = coords1 + flow_init
121 |
122 | flow_predictions = []
123 | for itr in range(iters):
124 | coords1 = coords1.detach()
125 | corr = corr_fn(coords1) # index correlation volume
126 |
127 | flow = coords1 - coords0
128 | with autocast(enabled=self.args.mixed_precision):
129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
130 |
131 | # F(t+1) = F(t) + \Delta(t)
132 | coords1 = coords1 + delta_flow
133 |
134 | # upsample predictions
135 | if up_mask is None:
136 | flow_up = upflow8(coords1 - coords0)
137 | else:
138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
139 |
140 | flow_predictions.append(flow_up)
141 |
142 | if test_mode:
143 | return coords1 - coords0, flow_up
144 |
145 | return flow_predictions
146 |
--------------------------------------------------------------------------------
/edgeconnect/region_fill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from scipy import sparse
4 | from scipy.sparse.linalg import spsolve
5 |
6 |
7 | def regionfill(I, mask, factor=1.0):
8 | if np.count_nonzero(mask) == 0:
9 | return I.copy()
10 | resize_mask = cv2.resize(
11 | mask.astype(float), (0, 0), fx=factor, fy=factor) > 0
12 | resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor)
13 | maskPerimeter = findBoundaryPixels(resize_mask)
14 | regionfillLaplace(resize_I, resize_mask, maskPerimeter)
15 | resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0]))
16 | resize_I[mask == 0] = I[mask == 0]
17 | return resize_I
18 |
19 |
20 | def findBoundaryPixels(mask):
21 | kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
22 | maskDilated = cv2.dilate(mask.astype(float), kernel)
23 | return (maskDilated > 0) & (mask == 0)
24 |
25 |
26 | def regionfillLaplace(I, mask, maskPerimeter):
27 | height, width = I.shape
28 | rightSide = formRightSide(I, maskPerimeter)
29 |
30 | # Location of mask pixels
31 | maskIdx = np.where(mask)
32 |
33 | # Only keep values for pixels that are in the mask
34 | rightSide = rightSide[maskIdx]
35 |
36 | # Number the mask pixels in a grid matrix
37 | grid = -np.ones((height, width))
38 | grid[maskIdx] = range(0, maskIdx[0].size)
39 | # Pad with zeros to avoid "index out of bounds" errors in the for loop
40 | grid = padMatrix(grid)
41 | gridIdx = np.where(grid >= 0)
42 |
43 | # Form the connectivity matrix D=sparse(i,j,s)
44 | # Connect each mask pixel to itself
45 | i = np.arange(0, maskIdx[0].size)
46 | j = np.arange(0, maskIdx[0].size)
47 | # The coefficient is the number of neighbors over which we average
48 | numNeighbors = computeNumberOfNeighbors(height, width)
49 | s = numNeighbors[maskIdx]
50 | # Now connect the N,E,S,W neighbors if they exist
51 | for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)):
52 | # Possible neighbors in the current direction
53 | neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]]
54 | # Connect mask points to neighbors with -1's
55 | index = (neighbors >= 0)
56 | i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]]))
57 | j = np.concatenate((j, neighbors[index]))
58 | s = np.concatenate((s, -np.ones(np.count_nonzero(index))))
59 |
60 | D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr()
61 | sol = spsolve(D, rightSide)
62 | I[maskIdx] = sol
63 | return I
64 |
65 |
66 | def formRightSide(I, maskPerimeter):
67 | height, width = I.shape
68 | perimeterValues = np.zeros((height, width))
69 | perimeterValues[maskPerimeter] = I[maskPerimeter]
70 | rightSide = np.zeros((height, width))
71 |
72 | rightSide[1:height - 1, 1:width - 1] = (
73 | perimeterValues[0:height - 2, 1:width - 1] +
74 | perimeterValues[2:height, 1:width - 1] +
75 | perimeterValues[1:height - 1, 0:width - 2] +
76 | perimeterValues[1:height - 1, 2:width])
77 |
78 | rightSide[1:height - 1, 0] = (
79 | perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] +
80 | perimeterValues[1:height - 1, 1])
81 |
82 | rightSide[1:height - 1, width - 1] = (
83 | perimeterValues[0:height - 2, width - 1] +
84 | perimeterValues[2:height, width - 1] +
85 | perimeterValues[1:height - 1, width - 2])
86 |
87 | rightSide[0, 1:width - 1] = (
88 | perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] +
89 | perimeterValues[0, 2:width])
90 |
91 | rightSide[height - 1, 1:width - 1] = (
92 | perimeterValues[height - 2, 1:width - 1] +
93 | perimeterValues[height - 1, 0:width - 2] +
94 | perimeterValues[height - 1, 2:width])
95 |
96 | rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0]
97 | rightSide[0, width - 1] = (
98 | perimeterValues[0, width - 2] + perimeterValues[1, width - 1])
99 | rightSide[height - 1, 0] = (
100 | perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1])
101 | rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] +
102 | perimeterValues[height - 1, width - 2])
103 | return rightSide
104 |
105 |
106 | def computeNumberOfNeighbors(height, width):
107 | # Initialize
108 | numNeighbors = np.zeros((height, width))
109 | # Interior pixels have 4 neighbors
110 | numNeighbors[1:height - 1, 1:width - 1] = 4
111 | # Border pixels have 3 neighbors
112 | numNeighbors[1:height - 1, (0, width - 1)] = 3
113 | numNeighbors[(0, height - 1), 1:width - 1] = 3
114 | # Corner pixels have 2 neighbors
115 | numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0,
116 | width - 1)] = 2
117 | return numNeighbors
118 |
119 |
120 | def padMatrix(grid):
121 | height, width = grid.shape
122 | gridPadded = -np.ones((height + 2, width + 2))
123 | gridPadded[1:height + 1, 1:width + 1] = grid
124 | gridPadded = gridPadded.astype(grid.dtype)
125 | return gridPadded
126 |
127 |
128 | if __name__ == '__main__':
129 | import time
130 | x = np.linspace(0, 255, 500)
131 | xv, _ = np.meshgrid(x, x)
132 | image = ((xv + np.transpose(xv)) / 2.0).astype(int)
133 | mask = np.zeros((500, 500))
134 | mask[100:259, 100:259] = 1
135 | mask = (mask > 0)
136 | image[mask] = 0
137 | st = time.time()
138 | inpaint = regionfill(image, mask, 0.5).astype(np.uint8)
139 | print(time.time() - st)
140 | cv2.imshow('img', np.concatenate((image.astype(np.uint8), inpaint)))
141 | cv2.waitKey()
142 |
--------------------------------------------------------------------------------
/RAFT/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 | # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | '''
2 | The online scoring uses the GPU version (based on the piq library). Here we also provide CPU-based test code (based on the skimage library).
3 | The CPU and GPU versions give slightly different scores: this baseline scores 68.6891 with the CPU evaluation and 68.7054 with the GPU evaluation.
4 |
5 | error_code:
6 | -1 error: video number unmatch
7 | -2 error: image not found
8 | -3 error: image size unmatch
9 | '''
10 | import os
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import json
14 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
15 | from skimage.metrics import structural_similarity as ssim
16 | import numpy as np
17 | import argparse
18 | import glob
19 | from PIL import Image
20 | import torch
21 | import cv2
22 | from piq import ssim as piq_ssim, SSIMLoss  # aliased so the skimage ssim imported above is not shadowed
23 | from piq import psnr
24 |
25 | # CPU version
26 | def PSNR(ximg,yimg):
27 | return compare_psnr(ximg,yimg,data_range=255)
28 |
29 | def SSIM(y,t,value_range=255):
30 | try:
31 | ssim_value = ssim(y, t, gaussian_weights=True, data_range=value_range, multichannel=True)
32 | except ValueError:
33 | #WinSize too small
34 | ssim_value = ssim(y, t, gaussian_weights=True, data_range=value_range, multichannel=True, win_size=3)
35 | return ssim_value
36 |
37 | # GPU version.
38 | # def PSNR(ximg,yimg):
39 | # gt_tensor = torch.from_numpy(yimg.transpose([2,0,1])).unsqueeze(0).cuda()
40 | # img_tensor = torch.from_numpy(ximg.transpose([2,0,1])).unsqueeze(0).cuda()
41 | # psnr_v = psnr(img_tensor, gt_tensor, data_range=255.).item()
42 | # #print(psnr_v)
43 | # return psnr_v
44 |
45 | # def SSIM(y,t,value_range=255):
46 | # gt_ss_tensor = torch.from_numpy(y.transpose([2,0,1])).unsqueeze(0).cuda()
47 | # img_ss_tensor = torch.from_numpy(t.transpose([2,0,1])).unsqueeze(0).cuda()
48 | # loss = SSIMLoss(data_range=255.).cuda()
49 | # loss_v = loss(gt_ss_tensor, img_ss_tensor).item()
50 | # return 1-loss_v
51 |
52 |
53 | def Evaluate(files_gt, files_pred, methods = [PSNR,SSIM]):
54 | score = {}
55 | for meth in methods:
56 | name = meth.__name__
57 | results = []
58 | res=0
59 | frame_num=len(files_gt)
60 | for frame in range(0,frame_num):
61 | # ignore some tiny crops
62 | if files_gt[frame].shape[0]*files_gt[frame].shape[1]<150:
63 | continue
64 |
65 | res = meth(files_pred[frame],files_gt[frame])
66 | results.append(res)
67 |
68 | mres = np.mean(results)
69 | stdres=np.std(results)
70 | # print(name+": "+str(mres)+" Std: "+str(stdres))
71 | score['mean']=mres
72 | score['std']=stdres
73 | return score
74 |
75 |
76 | def evaluate(args):
77 | error_code=0
78 | error_flag='successful.'
79 | final_result=[]
80 |
81 | # load video folder
82 | grountruth_folder_list = sorted(glob.glob(os.path.join(args.groundtruth_folder, 'video_0*')))
83 | prediction_folder_list = sorted(glob.glob(os.path.join(args.prediction_folder,'video_0*')))
84 |
85 | if len(grountruth_folder_list) != len(prediction_folder_list):
86 | error_code=-1
87 | error_flag='folder number unmatch.'
88 | return error_code, error_flag, 0
89 |
90 | for i in range(0,len(grountruth_folder_list)):
91 | # load video
92 | video_gt=[]
93 | video_predict=[]
94 | image_list = sorted(glob.glob(os.path.join(grountruth_folder_list[i],'gt_crop/*.png')))
95 | for image_gt in image_list:
96 | video_gt.append(np.array(Image.open(image_gt)).astype(np.uint8))
97 |
98 | try:
99 | image_predict=prediction_folder_list[i]+'/crop_'+image_gt[-10:]
100 | video_predict.append(np.array(Image.open(image_predict)).astype(np.uint8))
101 | except:
102 | error_code=-2
103 | error_flag= 'read ' + image_predict +' failed.'
104 | return error_code, error_flag, 0
105 |
106 | # check image size
107 | for j in range(0,len(image_list)):
108 | if video_gt[j].shape!=video_predict[j].shape:
109 | error_code=-3
110 | error_flag= 'image size unmatch. please check video_' + str(i).zfill(4)+'/crop_'+str(j).zfill(4)
111 | return error_code, error_flag, 0
112 |
113 |         # send in the whole video
114 | psnr_res = Evaluate(video_gt,video_predict, methods=[PSNR])
115 | ssim_res = Evaluate(video_gt,video_predict, methods=[SSIM])
116 |
117 | psnr_res_norm=min(80,psnr_res['mean'])
118 | ssim_res_norm=ssim_res['mean']*100
119 |
120 | result=psnr_res_norm+ssim_res_norm*0.5
121 | print(i,psnr_res_norm,ssim_res_norm,result)
122 |
123 | final_result.append(result)
124 | return error_code, error_flag, np.mean(final_result)
125 |
126 |
127 | if __name__ == "__main__":
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--groundtruth_folder',default='./')
130 | parser.add_argument('--prediction_folder',default='./')
131 |     # usage: python evaluate.py --groundtruth_folder ./val --prediction_folder ./result
132 |
133 |     # The groundtruth folder structure is as follows (same as after unzipping):
134 | #
135 | # |-- 000000.png
136 | # |-- video_0000 -- gt_crops --|-- 000001.png
137 | # | |-- ...
138 | # val -|-- video_0001
139 | # |
140 | # |-- ...
141 |
142 |     # The folder structure uploaded by participants is as follows:
143 | #
144 | # |-- crop_000000.png
145 | # |-- video_0000 --|-- crop_000001.png
146 | # | |-- ...
147 | # result -|-- video_0001
148 | # |
149 | # |-- ...
150 |
151 | args = parser.parse_args()
152 | error_code, error_flag, final_result = evaluate(args)
153 | print(final_result)
154 |
155 |
156 |
--------------------------------------------------------------------------------
/utils/Poisson_blend.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import scipy.ndimage
4 | from scipy.sparse.linalg import spsolve
5 | from scipy import sparse
6 | import scipy.io as sio
7 | import numpy as np
8 | from PIL import Image
9 | import copy
10 | import cv2
11 | import os
12 | import argparse
13 |
14 |
15 | def sub2ind(pi, pj, imgH, imgW):
16 | return pj + pi * imgW
17 |
18 |
19 | def Poisson_blend(imgTrg, imgSrc_gx, imgSrc_gy, holeMask, edge=None):
20 |
21 | imgH, imgW, nCh = imgTrg.shape
22 |
23 | if not isinstance(edge, np.ndarray):
24 | edge = np.zeros((imgH, imgW), dtype=np.float32)
25 |
26 | # Initialize the reconstructed image
27 | imgRecon = np.zeros((imgH, imgW, nCh), dtype=np.float32)
28 |
29 | # prepare discrete Poisson equation
30 | A, b = solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, edge)
31 |
32 | # Independently process each channel
33 | for ch in range(nCh):
34 |
35 | # solve Poisson equation
36 | x = scipy.sparse.linalg.lsqr(A, b[:, ch, None])[0]
37 | imgRecon[:, :, ch] = x.reshape(imgH, imgW)
38 |
39 |     # Combine the reconstruction with the known region in the target
40 | holeMaskC = np.tile(np.expand_dims(holeMask, axis=2), (1, 1, nCh))
41 | imgBlend = holeMaskC * imgRecon + (1 - holeMaskC) * imgTrg
42 |
43 | # Fill in edge pixel
44 | pi = np.expand_dims(np.where((holeMask * edge) == 1)[0], axis=1) # y, i
45 | pj = np.expand_dims(np.where((holeMask * edge) == 1)[1], axis=1) # x, j
46 |
47 | for k in range(len(pi)):
48 | if pi[k, 0] - 1 >= 0:
49 | if edge[pi[k, 0] - 1, pj[k, 0]] == 0:
50 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] - 1, pj[k, 0], :]
51 | continue
52 | if pi[k, 0] + 1 <= imgH - 1:
53 | if edge[pi[k, 0] + 1, pj[k, 0]] == 0:
54 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] + 1, pj[k, 0], :]
55 | continue
56 | if pj[k, 0] - 1 >= 0:
57 | if edge[pi[k, 0], pj[k, 0] - 1] == 0:
58 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] - 1, :]
59 | continue
60 | if pj[k, 0] + 1 <= imgW - 1:
61 | if edge[pi[k, 0], pj[k, 0] + 1] == 0:
62 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] + 1, :]
63 |
64 | return imgBlend
65 |
66 | def solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, edge):
67 |
68 | # Prepare the linear system of equations for Poisson blending
69 | imgH, imgW = holeMask.shape
70 | N = imgH * imgW
71 |
72 | # Number of unknown variables
73 | numUnknownPix = holeMask.sum()
74 |
75 | # 4-neighbors: dx and dy
76 | dx = [1, 0, -1, 0]
77 | dy = [0, 1, 0, -1]
78 |
79 | # 3
80 | # |
81 | # 2 -- * -- 0
82 | # |
83 | # 1
84 | #
85 |
86 | # Initialize (I, J, S), for sparse matrix A where A(I(k), J(k)) = S(k)
87 | I = np.empty((0, 1), dtype=np.float32)
88 | J = np.empty((0, 1), dtype=np.float32)
89 | S = np.empty((0, 1), dtype=np.float32)
90 |
91 | # Initialize b
92 | b = np.empty((0, 2), dtype=np.float32)
93 |
94 |     # Precompute unknown pixel positions
95 | pi = np.expand_dims(np.where(holeMask == 1)[0], axis=1) # y, i
96 | pj = np.expand_dims(np.where(holeMask == 1)[1], axis=1) # x, j
97 | pind = sub2ind(pi, pj, imgH, imgW)
98 |
99 | # |--------------------|
100 | # | y (i) |
101 | # | x (j) * |
102 | # | |
103 | # |--------------------|
104 |
105 | qi = np.concatenate((pi + dy[0],
106 | pi + dy[1],
107 | pi + dy[2],
108 | pi + dy[3]), axis=1)
109 |
110 | qj = np.concatenate((pj + dx[0],
111 | pj + dx[1],
112 | pj + dx[2],
113 | pj + dx[3]), axis=1)
114 |
115 | # Handling cases at image borders
116 | validN = (qi >= 0) & (qi <= imgH - 1) & (qj >= 0) & (qj <= imgW - 1)
117 | qind = np.zeros((validN.shape), dtype=np.float32)
118 | qind[validN] = sub2ind(qi[validN], qj[validN], imgH, imgW)
119 |
120 | e_start = 0 # equation counter start
121 | e_stop = 0 # equation stop
122 |
123 | # 4 neighbors
124 | I, J, S, b, e_start, e_stop = constructEquation(0, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
125 | I, J, S, b, e_start, e_stop = constructEquation(1, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
126 | I, J, S, b, e_start, e_stop = constructEquation(2, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
127 | I, J, S, b, e_start, e_stop = constructEquation(3, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
128 |
129 | nEqn = len(b)
130 | # Construct the sparse matrix A
131 | A = sparse.csr_matrix((S[:, 0], (I[:, 0], J[:, 0])), shape=(nEqn, N))
132 |
133 | return A, b
134 |
135 |
136 | def constructEquation(n, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop):
137 |
138 | # Pixel that has valid neighbors
139 | validNeighbor = validN[:, n]
140 |
141 | # Change the out-of-boundary value to 0, in order to run edge[y,x]
142 | # in the next line. It won't affect anything as validNeighbor is saved already
143 |
144 | qi_tmp = copy.deepcopy(qi)
145 | qj_tmp = copy.deepcopy(qj)
146 | qi_tmp[np.invert(validNeighbor), n] = 0
147 | qj_tmp[np.invert(validNeighbor), n] = 0
148 |
149 | # Not edge
150 | NotEdge = (edge[pi[:, 0], pj[:, 0]] == 0) * (edge[qi_tmp[:, n], qj_tmp[:, n]] == 0)
151 |
152 | # Boundary constraint
153 | Boundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 0
154 | valid = validNeighbor * NotEdge * Boundary
155 | J_tmp = pind[valid, :]
156 |
157 | # num of equations: len(J_tmp)
158 | e_stop = e_start + len(J_tmp)
159 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1)
160 | e_start = e_stop
161 |
162 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32)
163 |
164 | if n == 0:
165 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :]
166 | elif n == 2:
167 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + imgTrg[qi[valid, n], qj[valid, n], :]
168 | elif n == 1:
169 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :]
170 | elif n == 3:
171 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :]
172 |
173 | I = np.concatenate((I, I_tmp))
174 | J = np.concatenate((J, J_tmp))
175 | S = np.concatenate((S, S_tmp))
176 | b = np.concatenate((b, b_tmp))
177 |
178 |
179 | # Non-boundary constraint
180 | NonBoundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 1
181 | valid = validNeighbor * NotEdge * NonBoundary
182 |
183 | J_tmp = pind[valid, :]
184 |
185 | # num of equations: len(J_tmp)
186 | e_stop = e_start + len(J_tmp)
187 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1)
188 | e_start = e_stop
189 |
190 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32)
191 |
192 | if n == 0:
193 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :]
194 | elif n == 2:
195 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :]
196 | elif n == 1:
197 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :]
198 | elif n == 3:
199 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :]
200 |
201 | I = np.concatenate((I, I_tmp))
202 | J = np.concatenate((J, J_tmp))
203 | S = np.concatenate((S, S_tmp))
204 | b = np.concatenate((b, b_tmp))
205 |
206 | S_tmp = - np.ones(J_tmp.shape, dtype=np.float32)
207 | J_tmp = qind[valid, n, None]
208 |
209 | I = np.concatenate((I, I_tmp))
210 | J = np.concatenate((J, J_tmp))
211 | S = np.concatenate((S, S_tmp))
212 |
213 | return I, J, S, b, e_start, e_stop
214 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Baseline
2 | The basic pipeline of this code: compute dense optical flow (RAFT) -> compute edges (Canny) -> complete the edges (EdgeConnect) -> complete the flow (by solving Ax=b) -> propagate RGB values (pixels that can be filled from other frames along the flow are filled that way; otherwise the frame with the largest remaining hole is inpainted with Deepfill_V1 and the result is propagated again).
3 | On top of the original authors' code, I modified the flow completion step for speed. The original solves Ax=b over the whole frame, which is extremely slow; I crop around the hole first and solve inside the crop, which makes Ax=b considerably faster (see the sketch after the usage example below).
4 |
5 | Usage:
6 | ```bash
7 | python ./tool/video_completion_modified.py --mode object_removal --path ../data/test_a/video_0000/frames_corr --path_mask ../data/test_a/video_0000/masks --outroot ../data/result_test_a/video_0000 --seamless --edge_guide
8 | ```
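The crop-before-solve change mentioned above can be illustrated with a minimal sketch. The helper names `hole_bbox` and `solve_fn` below are hypothetical placeholders (the actual logic lives in tool/video_completion_modified.py): the idea is simply to build and solve the sparse system only inside a padded bounding box around the hole instead of over the full frame.

```python
import numpy as np

def hole_bbox(mask, pad=20):
    # Padded bounding box of the hole region, clipped to the image extent.
    ys, xs = np.where(mask > 0)
    y0, y1 = max(ys.min() - pad, 0), min(ys.max() + pad + 1, mask.shape[0])
    x0, x1 = max(xs.min() - pad, 0), min(xs.max() + pad + 1, mask.shape[1])
    return y0, y1, x0, x1

def complete_flow_cropped(flow, mask, solve_fn, pad=20):
    # Solve Ax=b only inside the crop, then paste the completed patch back.
    y0, y1, x0, x1 = hole_bbox(mask, pad)
    out = flow.copy()
    out[y0:y1, x0:x1] = solve_fn(flow[y0:y1, x0:x1], mask[y0:y1, x0:x1])
    return out
```

Because the number of unknowns in Ax=b then grows with the crop size rather than the full 576x1024 frame, the solve is much faster whenever the hole is small.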
9 |
10 |
11 | About this baseline:
12 | 1) On the test_a set, this baseline's final score is about 68.7054.
13 | 2) It is slow. With CPU: Gold5218@2.30GHz and GPU: NVIDIA-V100, a single-threaded run over the test set took more than a day.
14 | 3) RAM usage exceeds 8 GB; GPU memory stays under 8 GB.
15 | 4) Quality is uneven. It works well on some videos (e.g. dancing people and moving objects, where temporal information can fill the hole) and poorly on others (e.g. watermarks and objects at fixed positions).
16 | 5) If you plan to build directly on this baseline (tweak it and retrain), both speed and memory need to be optimized.
17 |
18 | Friendly reminders:
19 | 1) When submitting results on the competition website, a progress bar appears at the top and a prompt is shown once the submission succeeds. Do not close the page before you see the "submission successful" prompt.
20 | 2) If scoring fails with a "folder number unmatch" error, there are two likely causes: 1. the number of submitted video folders is wrong (there must be exactly 100 video_*** folders), so check for extra unrelated folders or missing video folders; 2. zip the folder contents directly, i.e. result.zip should unpack to result/video_**** rather than aaa/result/video_****.
21 | 3) If scoring fails with an "image not found" error, check that every folder contains the full number of images and that they are named correctly.
22 | 4) If scoring fails with an "image size unmatch" error, check that the image sizes match what bbox.txt specifies.
23 |
24 |
25 | I wrote a simple tutorial (experienced users can skip it): https://zhuanlan.zhihu.com/p/381449269
26 |
27 |
28 | # The 2nd "Malanshan Cup" International Audio and Video Algorithm Competition - Video Completion Track
29 |
30 | ## Data description
31 | **Training set**: 2194 videos in mp4 format, 2s~8s long, mostly 720x1280 and 1080x1920.
32 | train1 (843 videos, 1.98 GB zipped): the same as last year's point-tracking track research_1 video set (duplicate videos in last year's data have now been removed).
33 | train2 (1010 videos, 2.18 GB zipped): the same as last year's point-tracking track research_2 video set (duplicate videos in last year's data have now been removed).
34 | train3 (341 videos, 646 MB zipped): last year's point-tracking track val video set plus new videos for this track.
35 |
36 | **Validation set**: 50 videos stored as png images of size 576x1024 (corner logos and subtitle regions have been cropped away as much as possible), 120 frames or fewer per video. Each sample consists of the original video and one type of mask (solid cut-out, local watermark, random noise blocks, or human figure).
37 |
38 | **Test set a**: 100 cut-out videos stored as png images of size 576x1024 (corner logos and subtitle regions cropped away as much as possible), 120 frames or fewer per video. Each sample contains the cut-out video paired with its corresponding mask.
39 | **Test set b**: same format as test set a.
40 | (The proportion of each mask type in the test sets is similar to the validation set.)
41 |
42 | No third-party data may be used at any point in the competition; open-source pretrained models are allowed (released as open source before June 2021).
43 | The validation set may be added to training; the test sets are **forbidden** from being used for training.
44 | The data provided for this competition is copyrighted by Mango TV (MGTV); participants must not leak it or use it for any purpose outside this competition.
45 |
46 | ## Data download links
47 | Baidu Cloud link: https://pan.baidu.com/s/1xUvLiAw3YGv6s1Ll780yUg
48 | Extraction code: mgtv
49 |
50 | Direct download links:
51 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/train_1.zip
52 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/train_2.zip
53 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/train_3.zip
54 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/val.zip
55 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/test_a.zip
56 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/b/test_b.zip
57 |
58 | MD5 checksums to verify that the files transferred completely:
59 | train_1.zip: 83ced2b4e80231105eb6dc8d2fae9e29
60 | train_2.zip: 792a21b974d83084c9e2bf81af0c9e10
61 | train_3.zip: e07011f6f1d149c6508a0ae50a76e18b
62 | val.zip: 6d16b879cd618358f941477eeed9a4bd
63 | test_a.zip: b146ed76a53f556fe36c9012da03bf94
64 | test_b.zip: f2f8b00df148621f25270ad1fd5e2362
65 |
66 | Linux command: md5sum test_a.zip
67 | Windows command: certutil -hashfile test_a.zip MD5
68 |
69 | ## Evaluation metrics
70 | The preliminary and final rounds are scored by evaluating the submitted results, using two metrics: PSNR and SSIM. For the uploaded results, the evaluation program computes PSNR and SSIM over the cut-out regions, per frame and then averaged. PSNR and SSIM are then combined with weights to give the final competition score. PSNR lies in [0, 80], with good values roughly in [30, 50]; SSIM lies in [0, 1], with good values roughly in [0.8, 1].
71 | ```bash
72 | score = PSNR*2*0.5 + SSIM*100*0.5
73 | ```
74 | See evaluate.py for the exact PSNR and SSIM computation. A worked example of the weighting follows.
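As a quick sanity check of the weighting (illustrative numbers only, not real results): a video whose cut-out region averages 35 dB PSNR and 0.92 SSIM would score 35*2*0.5 + 0.92*100*0.5 = 81.

```python
# Illustrative numbers only; the real per-frame averaging is done in evaluate.py.
psnr_mean, ssim_mean = 35.0, 0.92
score = min(80, psnr_mean) * 2 * 0.5 + ssim_mean * 100 * 0.5  # PSNR clipped to 80, as in evaluate.py
print(score)  # 81.0
```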
75 |
76 | ## Compute requirements
77 | 1) During inference, RAM usage must not exceed 8 GB and GPU memory must not exceed 8 GB.
78 | 2) On a CPU (dual core, 2.30 GHz) and a single GPU (NVIDIA V100 or NVIDIA 3090), running the whole test set sequentially (100 videos, no multithreading) must take no more than 15 hours.
79 |
80 | For teams that do not satisfy the above limits, the organizers reserve the right to declare the final overall score invalid, with lower-ranked teams moving up accordingly.
81 |
82 | ## Submission requirements
83 | The preliminary and final rounds use the same submission format: cropped result images. To keep the upload small, crop each result image with the crop box [x,y,w,h] given by the bbox.txt inside each video folder, then package the crops and upload them.
84 | Here the x and y indices start from 0. For example (a per-video sketch follows this snippet):
85 | ```bash
86 | import cv2
87 | img = cv2.imread("result_000000.png")
88 | crop_img = img[y:y+h, x:x+w, :]
89 | cv2.imwrite("crop_000000.png", crop_img)
90 | ```
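A minimal per-video sketch of the cropping step above, assuming the result folder contains frames named result_000000.png, result_000001.png, ... alongside a copy of the corresponding bbox.txt whose first four numbers are x, y, w, h (these naming assumptions should be checked against the released data):

```python
import glob
import os

import cv2

def crop_video_results(result_dir, out_dir):
    # Read the crop box once, then save crop_XXXXXX.png for every result frame.
    with open(os.path.join(result_dir, 'bbox.txt')) as f:
        x, y, w, h = [int(v) for v in f.read().replace(',', ' ').split()[:4]]
    os.makedirs(out_dir, exist_ok=True)
    for path in sorted(glob.glob(os.path.join(result_dir, 'result_*.png'))):
        img = cv2.imread(path)
        crop = img[y:y + h, x:x + w, :]
        name = 'crop_' + os.path.basename(path).split('_')[-1]
        cv2.imwrite(os.path.join(out_dir, name), crop)
```

Run something like this once per video_0*** folder before zipping the upload.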
91 | Put the cropped images into their respective video folders (video_0***) and zip everything into a single *.zip for upload (normally no larger than 2 GB). Zip the folder contents directly, i.e. result.zip should unpack to result/video_0*** rather than aaa/result/video_0***.
92 | The folder structure and naming convention are as follows (using test_a as an example; "result" can be named anything, e.g. result_0608 or aaa_123):
93 | ```bash
94 | |—— crop_000000.png
95 | |—— video_0000 |—— crop_000001.png
96 | | |—— crop_...
97 | result |—— video_...
98 | | |—— crop_000000.png
99 | |—— video_0099 |—— crop_000001.png
100 | |—— crop_...
101 | ```
102 | ## Evaluation and ranking
103 | 1) Data is provided for download in both the preliminary and final rounds; participants develop locally and submit results on the competition page;
104 | 2) Both rounds use an A/B leaderboard: the A-board score is visible to teams during the competition, while the final ranking (for both rounds) uses the best B-board score;
105 | 3) The top-10 teams of the final round must provide their source code and a docker image so the organizing committee can verify the method and results; a team's score is valid only if the metrics of the reproduced results and of the best submitted results differ within the allowed margin;
106 | 4) The top-10 teams of the final round must also submit a **technical report** describing their method in detail (suggested length: 1-4 pages);
107 | 5) Each team may submit at most 3 times per day;
108 | 6) The leaderboard ranks teams by score from high to low, using each team's best historical score.
109 |
110 |
111 | Below is the original authors' README.md.
112 | # [ECCV 2020] Flow-edge Guided Video Completion
113 |
114 | ### [[Paper](https://arxiv.org/abs/2009.01835)] [[Project Website](http://chengao.vision/FGVC/)] [[Google Colab](https://colab.research.google.com/drive/1pb6FjWdwq_q445rG2NP0dubw7LKNUkqc?usp=sharing)]
115 |
116 |
117 |
118 |
119 |
120 | We present a new flow-based video completion algorithm. Previous flow completion methods are often unable to retain the sharpness of motion boundaries. Our method first extracts and completes motion edges, and then uses them to guide piecewise-smooth flow completion with sharp edges. Existing methods propagate colors among local flow connections between adjacent frames. However, not all missing regions in a video can be reached in this way because the motion boundaries form impenetrable barriers. Our method alleviates this problem by introducing non-local flow connections to temporally distant frames, enabling propagating video content over motion boundaries. We validate our approach on the DAVIS dataset. Both visual and quantitative results show that our method compares favorably against the state-of-the-art algorithms.
121 |
122 |
123 | **[ECCV 2020] Flow-edge Guided Video Completion**
124 |
125 | [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/), and [Johannes Kopf](https://johanneskopf.de/)
126 |
127 | In European Conference on Computer Vision (ECCV), 2020
128 |
129 | ## Prerequisites
130 |
131 | - Linux (tested on CentOS Linux release 7.4.1708)
132 | - Anaconda
133 | - Python 3.8 (tested on 3.8.5)
134 | - PyTorch 1.6.0
135 |
136 | and the Python dependencies listed in requirements.txt
137 |
138 | - To get started, please run the following commands:
139 | ```
140 | conda create -n FGVC
141 | conda activate FGVC
142 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
143 | conda install matplotlib scipy
144 | pip install -r requirements.txt
145 | ```
146 |
147 | - Next, please download the model weight and demo data using the following command:
148 | ```
149 | chmod +x download_data_weights.sh
150 | ./download_data_weights.sh
151 | ```
152 |
153 | ## Quick start
154 |
155 | - Object removal:
156 | ```bash
157 | cd tool
158 | python video_completion.py \
159 | --mode object_removal \
160 | --path ../data/tennis \
161 | --path_mask ../data/tennis_mask \
162 | --outroot ../result/tennis_removal \
163 | --seamless
164 | ```
165 |
166 | - FOV extrapolation:
167 | ```bash
168 | cd tool
169 | python video_completion.py \
170 | --mode video_extrapolation \
171 | --path ../data/tennis \
172 | --outroot ../result/tennis_extrapolation \
173 | --H_scale 2 \
174 | --W_scale 2 \
175 | --seamless
176 | ```
177 |
178 | You can remove the `--seamless` flag for a faster processing time.
179 |
180 |
181 | ## License
182 | This work is licensed under MIT License. See [LICENSE](LICENSE) for details.
183 |
184 | If you find this code useful for your research, please consider citing the following paper:
185 |
186 | @inproceedings{Gao-ECCV-FGVC,
187 | author = {Gao, Chen and Saraf, Ayush and Huang, Jia-Bin and Kopf, Johannes},
188 | title = {Flow-edge Guided Video Completion},
189 | booktitle = {European Conference on Computer Vision},
190 | year = {2020}
191 | }
192 |
193 | ## Acknowledgments
194 | - Our flow edge completion network builds upon [EdgeConnect](https://github.com/knazeri/edge-connect).
195 | - Our image inpainting network is modified from [DFVI](https://github.com/nbei/Deep-Flow-Guided-Video-Inpainting).
196 |
--------------------------------------------------------------------------------
/edgeconnect/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 |
9 |
10 | def create_dir(dir):
11 | if not os.path.exists(dir):
12 | os.makedirs(dir)
13 |
14 |
15 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
16 | mask = np.zeros((height, width))
17 | mask_x = x if x is not None else random.randint(0, width - mask_width)
18 | mask_y = y if y is not None else random.randint(0, height - mask_height)
19 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
20 | return mask
21 |
22 |
23 | def stitch_images(inputs, *outputs, img_per_row=2):
24 | gap = 5
25 | columns = len(outputs) + 1
26 |
27 | height, width = inputs[0][:, :, 0].shape
28 | img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
29 | images = [inputs, *outputs]
30 |
31 | for ix in range(len(inputs)):
32 | xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
33 | yoffset = int(ix / img_per_row) * height
34 |
35 | for cat in range(len(images)):
36 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
37 | im = Image.fromarray(im)
38 | img.paste(im, (xoffset + cat * width, yoffset))
39 |
40 | return img
41 |
42 |
43 | def imshow(img, title=''):
44 | fig = plt.gcf()
45 | fig.canvas.set_window_title(title)
46 | plt.axis('off')
47 | plt.imshow(img, interpolation='none')
48 | plt.show()
49 |
50 |
51 | def imsave(img, path):
52 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze())
53 | im.save(path)
54 |
55 |
56 | class Progbar(object):
57 | """Displays a progress bar.
58 |
59 | Arguments:
60 | target: Total number of steps expected, None if unknown.
61 | width: Progress bar width on screen.
62 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
63 | stateful_metrics: Iterable of string names of metrics that
64 | should *not* be averaged over time. Metrics in this list
65 | will be displayed as-is. All others will be averaged
66 | by the progbar before display.
67 | interval: Minimum visual progress update interval (in seconds).
68 | """
69 |
70 | def __init__(self, target, width=25, verbose=1, interval=0.05,
71 | stateful_metrics=None):
72 | self.target = target
73 | self.width = width
74 | self.verbose = verbose
75 | self.interval = interval
76 | if stateful_metrics:
77 | self.stateful_metrics = set(stateful_metrics)
78 | else:
79 | self.stateful_metrics = set()
80 |
81 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
82 | sys.stdout.isatty()) or
83 | 'ipykernel' in sys.modules or
84 | 'posix' in sys.modules)
85 | self._total_width = 0
86 | self._seen_so_far = 0
87 | # We use a dict + list to avoid garbage collection
88 | # issues found in OrderedDict
89 | self._values = {}
90 | self._values_order = []
91 | self._start = time.time()
92 | self._last_update = 0
93 |
94 | def update(self, current, values=None):
95 | """Updates the progress bar.
96 |
97 | Arguments:
98 | current: Index of current step.
99 | values: List of tuples:
100 | `(name, value_for_last_step)`.
101 | If `name` is in `stateful_metrics`,
102 | `value_for_last_step` will be displayed as-is.
103 | Else, an average of the metric over time will be displayed.
104 | """
105 | values = values or []
106 | for k, v in values:
107 | if k not in self._values_order:
108 | self._values_order.append(k)
109 | if k not in self.stateful_metrics:
110 | if k not in self._values:
111 | self._values[k] = [v * (current - self._seen_so_far),
112 | current - self._seen_so_far]
113 | else:
114 | self._values[k][0] += v * (current - self._seen_so_far)
115 | self._values[k][1] += (current - self._seen_so_far)
116 | else:
117 | self._values[k] = v
118 | self._seen_so_far = current
119 |
120 | now = time.time()
121 | info = ' - %.0fs' % (now - self._start)
122 | if self.verbose == 1:
123 | if (now - self._last_update < self.interval and
124 | self.target is not None and current < self.target):
125 | return
126 |
127 | prev_total_width = self._total_width
128 | if self._dynamic_display:
129 | sys.stdout.write('\b' * prev_total_width)
130 | sys.stdout.write('\r')
131 | else:
132 | sys.stdout.write('\n')
133 |
134 | if self.target is not None:
135 | numdigits = int(np.floor(np.log10(self.target))) + 1
136 | barstr = '%%%dd/%d [' % (numdigits, self.target)
137 | bar = barstr % current
138 | prog = float(current) / self.target
139 | prog_width = int(self.width * prog)
140 | if prog_width > 0:
141 | bar += ('=' * (prog_width - 1))
142 | if current < self.target:
143 | bar += '>'
144 | else:
145 | bar += '='
146 | bar += ('.' * (self.width - prog_width))
147 | bar += ']'
148 | else:
149 | bar = '%7d/Unknown' % current
150 |
151 | self._total_width = len(bar)
152 | sys.stdout.write(bar)
153 |
154 | if current:
155 | time_per_unit = (now - self._start) / current
156 | else:
157 | time_per_unit = 0
158 | if self.target is not None and current < self.target:
159 | eta = time_per_unit * (self.target - current)
160 | if eta > 3600:
161 | eta_format = '%d:%02d:%02d' % (eta // 3600,
162 | (eta % 3600) // 60,
163 | eta % 60)
164 | elif eta > 60:
165 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
166 | else:
167 | eta_format = '%ds' % eta
168 |
169 | info = ' - ETA: %s' % eta_format
170 | else:
171 | if time_per_unit >= 1:
172 | info += ' %.0fs/step' % time_per_unit
173 | elif time_per_unit >= 1e-3:
174 | info += ' %.0fms/step' % (time_per_unit * 1e3)
175 | else:
176 | info += ' %.0fus/step' % (time_per_unit * 1e6)
177 |
178 | for k in self._values_order:
179 | info += ' - %s:' % k
180 | if isinstance(self._values[k], list):
181 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
182 | if abs(avg) > 1e-3:
183 | info += ' %.4f' % avg
184 | else:
185 | info += ' %.4e' % avg
186 | else:
187 | info += ' %s' % self._values[k]
188 |
189 | self._total_width += len(info)
190 | if prev_total_width > self._total_width:
191 | info += (' ' * (prev_total_width - self._total_width))
192 |
193 | if self.target is not None and current >= self.target:
194 | info += '\n'
195 |
196 | sys.stdout.write(info)
197 | sys.stdout.flush()
198 |
199 | elif self.verbose == 2:
200 | if self.target is None or current >= self.target:
201 | for k in self._values_order:
202 | info += ' - %s:' % k
203 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
204 | if avg > 1e-3:
205 | info += ' %.4f' % avg
206 | else:
207 | info += ' %.4e' % avg
208 | info += '\n'
209 |
210 | sys.stdout.write(info)
211 | sys.stdout.flush()
212 |
213 | self._last_update = now
214 |
215 | def add(self, n, values=None):
216 | self.update(self._seen_so_far + n, values)
217 |
--------------------------------------------------------------------------------
/edgeconnect/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 |
6 |
7 | class TotalVariationalLoss(nn.Module):
8 | def __init__(self):
9 | super(TotalVariationalLoss, self).__init__()
10 |
11 | def _tensor_size(self, x):
12 | return x.size()[1] * x.size()[2] * x.size()[3]
13 |
14 | def __call__(self, x):
15 |
16 | batch_size = x.size()[0]
17 | h_x = x.size()[2]
18 | w_x = x.size()[3]
19 | count_h = self._tensor_size(x[:, :, 1:, :])
20 | count_w = self._tensor_size(x[:, :, :, 1:])
21 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
22 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
23 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
24 |
25 |
26 | class AdversarialLoss(nn.Module):
27 | r"""
28 | Adversarial loss
29 | https://arxiv.org/abs/1711.10337
30 | """
31 |
32 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
33 | r"""
34 | type = nsgan | lsgan | hinge
35 | """
36 | super(AdversarialLoss, self).__init__()
37 |
38 | self.type = type
39 | self.register_buffer('real_label', torch.tensor(target_real_label))
40 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
41 |
42 | if type == 'nsgan':
43 | self.criterion = nn.BCELoss()
44 |
45 | elif type == 'lsgan':
46 | self.criterion = nn.MSELoss()
47 |
48 | elif type == 'hinge':
49 | self.criterion = nn.ReLU()
50 |
51 | def __call__(self, outputs, is_real, is_disc=None):
52 | if self.type == 'hinge':
53 | if is_disc:
54 | if is_real:
55 | outputs = -outputs
56 | return self.criterion(1 + outputs).mean()
57 | else:
58 | return (-outputs).mean()
59 |
60 | else:
61 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
62 | loss = self.criterion(outputs, labels)
63 | return loss
64 |
65 |
66 | class StyleLoss(nn.Module):
67 | r"""
68 |     Style loss, VGG-based (Gram matrices of VGG19 features)
69 | https://arxiv.org/abs/1603.08155
70 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
71 | """
72 |
73 | def __init__(self):
74 | super(StyleLoss, self).__init__()
75 | self.add_module('vgg', VGG19())
76 | self.criterion = torch.nn.L1Loss()
77 |
78 | def compute_gram(self, x):
79 | b, ch, h, w = x.size()
80 | f = x.view(b, ch, w * h)
81 | f_T = f.transpose(1, 2)
82 | G = f.bmm(f_T) / (h * w * ch)
83 |
84 | return G
85 |
86 | def __call__(self, x, y):
87 | # Compute features
88 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
89 |
90 | # Compute loss
91 | style_loss = 0.0
92 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
93 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
94 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
95 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
96 |
97 | return style_loss
98 |
99 |
100 |
101 | class PerceptualLoss(nn.Module):
102 | r"""
103 | Perceptual loss, VGG-based
104 | https://arxiv.org/abs/1603.08155
105 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
106 | """
107 |
108 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
109 | super(PerceptualLoss, self).__init__()
110 | self.add_module('vgg', VGG19())
111 | self.criterion = torch.nn.L1Loss()
112 | self.weights = weights
113 |
114 | def __call__(self, x, y):
115 | # Compute features
116 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
117 |
118 | content_loss = 0.0
119 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
120 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
121 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
122 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
123 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
124 |
125 |
126 | return content_loss
127 |
128 |
129 |
130 | class VGG19(torch.nn.Module):
131 | def __init__(self):
132 | super(VGG19, self).__init__()
133 | features = models.vgg19(pretrained=True).features
134 | self.relu1_1 = torch.nn.Sequential()
135 | self.relu1_2 = torch.nn.Sequential()
136 |
137 | self.relu2_1 = torch.nn.Sequential()
138 | self.relu2_2 = torch.nn.Sequential()
139 |
140 | self.relu3_1 = torch.nn.Sequential()
141 | self.relu3_2 = torch.nn.Sequential()
142 | self.relu3_3 = torch.nn.Sequential()
143 | self.relu3_4 = torch.nn.Sequential()
144 |
145 | self.relu4_1 = torch.nn.Sequential()
146 | self.relu4_2 = torch.nn.Sequential()
147 | self.relu4_3 = torch.nn.Sequential()
148 | self.relu4_4 = torch.nn.Sequential()
149 |
150 | self.relu5_1 = torch.nn.Sequential()
151 | self.relu5_2 = torch.nn.Sequential()
152 | self.relu5_3 = torch.nn.Sequential()
153 | self.relu5_4 = torch.nn.Sequential()
154 |
155 | for x in range(2):
156 | self.relu1_1.add_module(str(x), features[x])
157 |
158 | for x in range(2, 4):
159 | self.relu1_2.add_module(str(x), features[x])
160 |
161 | for x in range(4, 7):
162 | self.relu2_1.add_module(str(x), features[x])
163 |
164 | for x in range(7, 9):
165 | self.relu2_2.add_module(str(x), features[x])
166 |
167 | for x in range(9, 12):
168 | self.relu3_1.add_module(str(x), features[x])
169 |
170 | for x in range(12, 14):
171 | self.relu3_2.add_module(str(x), features[x])
172 |
173 | for x in range(14, 16):
174 | self.relu3_3.add_module(str(x), features[x])
175 |
176 | for x in range(16, 18):
177 | self.relu3_4.add_module(str(x), features[x])
178 |
179 | for x in range(18, 21):
180 | self.relu4_1.add_module(str(x), features[x])
181 |
182 | for x in range(21, 23):
183 | self.relu4_2.add_module(str(x), features[x])
184 |
185 | for x in range(23, 25):
186 | self.relu4_3.add_module(str(x), features[x])
187 |
188 | for x in range(25, 27):
189 | self.relu4_4.add_module(str(x), features[x])
190 |
191 | for x in range(27, 30):
192 | self.relu5_1.add_module(str(x), features[x])
193 |
194 | for x in range(30, 32):
195 | self.relu5_2.add_module(str(x), features[x])
196 |
197 | for x in range(32, 34):
198 | self.relu5_3.add_module(str(x), features[x])
199 |
200 | for x in range(34, 36):
201 | self.relu5_4.add_module(str(x), features[x])
202 |
203 | # don't need the gradients, just want the features
204 | for param in self.parameters():
205 | param.requires_grad = False
206 |
207 | def forward(self, x):
208 | relu1_1 = self.relu1_1(x)
209 | relu1_2 = self.relu1_2(relu1_1)
210 |
211 | relu2_1 = self.relu2_1(relu1_2)
212 | relu2_2 = self.relu2_2(relu2_1)
213 |
214 | relu3_1 = self.relu3_1(relu2_2)
215 | relu3_2 = self.relu3_2(relu3_1)
216 | relu3_3 = self.relu3_3(relu3_2)
217 | relu3_4 = self.relu3_4(relu3_3)
218 |
219 | relu4_1 = self.relu4_1(relu3_4)
220 | relu4_2 = self.relu4_2(relu4_1)
221 | relu4_3 = self.relu4_3(relu4_2)
222 | relu4_4 = self.relu4_4(relu4_3)
223 |
224 | relu5_1 = self.relu5_1(relu4_4)
225 | relu5_2 = self.relu5_2(relu5_1)
226 | relu5_3 = self.relu5_3(relu5_2)
227 | relu5_4 = self.relu5_4(relu5_3)
228 |
229 | out = {
230 | 'relu1_1': relu1_1,
231 | 'relu1_2': relu1_2,
232 |
233 | 'relu2_1': relu2_1,
234 | 'relu2_2': relu2_2,
235 |
236 | 'relu3_1': relu3_1,
237 | 'relu3_2': relu3_2,
238 | 'relu3_3': relu3_3,
239 | 'relu3_4': relu3_4,
240 |
241 | 'relu4_1': relu4_1,
242 | 'relu4_2': relu4_2,
243 | 'relu4_3': relu4_3,
244 | 'relu4_4': relu4_4,
245 |
246 | 'relu5_1': relu5_1,
247 | 'relu5_2': relu5_2,
248 | 'relu5_3': relu5_3,
249 | 'relu5_4': relu5_4,
250 | }
251 | return out
252 |
--------------------------------------------------------------------------------
/edgeconnect/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BaseNetwork(nn.Module):
6 | def __init__(self):
7 | super(BaseNetwork, self).__init__()
8 |
9 | def init_weights(self, init_type='normal', gain=0.02):
10 | '''
11 | initialize network's weights
12 | init_type: normal | xavier | kaiming | orthogonal
13 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
14 | '''
15 |
16 | def init_func(m):
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | nn.init.normal_(m.weight.data, 0.0, gain)
21 | elif init_type == 'xavier':
22 | nn.init.xavier_normal_(m.weight.data, gain=gain)
23 | elif init_type == 'kaiming':
24 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | nn.init.orthogonal_(m.weight.data, gain=gain)
27 |
28 | if hasattr(m, 'bias') and m.bias is not None:
29 | nn.init.constant_(m.bias.data, 0.0)
30 |
31 | elif classname.find('BatchNorm2d') != -1:
32 | nn.init.normal_(m.weight.data, 1.0, gain)
33 | nn.init.constant_(m.bias.data, 0.0)
34 |
35 | self.apply(init_func)
36 |
37 |
38 | class InpaintGenerator(BaseNetwork):
39 | def __init__(self, config, residual_blocks=8, init_weights=True):
40 | super(InpaintGenerator, self).__init__()
41 | self.config = config
42 | if config.FLO == 1:
43 | if config.PASSMASK == 0:
44 | in_channels = 3
45 | elif config.PASSMASK == 1:
46 | in_channels = 4
47 | else:
48 | assert(0)
49 | out_channels = 2
50 | elif config.FLO == 0:
51 | in_channels = 4
52 | out_channels = 3
53 | else:
54 | assert(0)
55 | self.encoder = nn.Sequential(
56 | nn.ReflectionPad2d(3),
57 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0),
58 | nn.InstanceNorm2d(64, track_running_stats=False),
59 | nn.ReLU(True),
60 |
61 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
62 | nn.InstanceNorm2d(128, track_running_stats=False),
63 | nn.ReLU(True),
64 |
65 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
66 | nn.InstanceNorm2d(256, track_running_stats=False),
67 | nn.ReLU(True)
68 | )
69 |
70 | blocks = []
71 | for _ in range(residual_blocks):
72 | block = ResnetBlock(256, 2)
73 | blocks.append(block)
74 |
75 | self.middle = nn.Sequential(*blocks)
76 |
77 | self.decoder = nn.Sequential(
78 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
79 | nn.InstanceNorm2d(128, track_running_stats=False),
80 | nn.ReLU(True),
81 |
82 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
83 | nn.InstanceNorm2d(64, track_running_stats=False),
84 | nn.ReLU(True),
85 |
86 | nn.ReflectionPad2d(3),
87 | nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0),
88 | )
89 |
90 | if init_weights:
91 | self.init_weights()
92 |
93 | def forward(self, input):
94 | x = self.encoder(input)
95 | x = self.middle(x)
96 | x = self.decoder(x)
97 |
98 | if self.config.FLO == 0:
99 | x = (torch.tanh(x) + 1) / 2
100 | elif self.config.FLO == 1 and self.config.NORM == 1:
101 | if self.config.RESIDUAL == 1:
102 | assert(self.config.FILL == 1)
103 | x = torch.tanh(x + input[:, :2, :, :])
104 | elif self.config.RESIDUAL == 0:
105 | x = torch.tanh(x)
106 | else:
107 | assert(0)
108 | return x
109 |
110 |
111 | class EdgeGenerator_(BaseNetwork):
112 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
113 | super(EdgeGenerator_, self).__init__()
114 |
115 | self.encoder = nn.Sequential(
116 | nn.ReflectionPad2d(3),
117 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
118 | nn.InstanceNorm2d(64, track_running_stats=False),
119 | nn.ReLU(True),
120 |
121 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
122 | nn.InstanceNorm2d(128, track_running_stats=False),
123 | nn.ReLU(True),
124 |
125 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
126 | nn.InstanceNorm2d(256, track_running_stats=False),
127 | nn.ReLU(True)
128 | )
129 |
130 | blocks = []
131 | for _ in range(residual_blocks):
132 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
133 | blocks.append(block)
134 |
135 | self.middle = nn.Sequential(*blocks)
136 |
137 | self.decoder = nn.Sequential(
138 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
139 | nn.InstanceNorm2d(128, track_running_stats=False),
140 | nn.ReLU(True),
141 |
142 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
143 | nn.InstanceNorm2d(64, track_running_stats=False),
144 | nn.ReLU(True),
145 |
146 | nn.ReflectionPad2d(3),
147 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
148 | )
149 |
150 | if init_weights:
151 | self.init_weights()
152 |
153 | def forward(self, x):
154 | x = self.encoder(x)
155 | x = self.middle(x)
156 | x = self.decoder(x)
157 | x = torch.sigmoid(x)
158 | return x
159 |
160 |
161 | class Discriminator(BaseNetwork):
162 | def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
163 | super(Discriminator, self).__init__()
164 | self.use_sigmoid = use_sigmoid
165 |
166 | self.conv1 = self.features = nn.Sequential(
167 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
168 | nn.LeakyReLU(0.2, inplace=True),
169 | )
170 |
171 | self.conv2 = nn.Sequential(
172 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
173 | nn.LeakyReLU(0.2, inplace=True),
174 | )
175 |
176 | self.conv3 = nn.Sequential(
177 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
178 | nn.LeakyReLU(0.2, inplace=True),
179 | )
180 |
181 | self.conv4 = nn.Sequential(
182 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
183 | nn.LeakyReLU(0.2, inplace=True),
184 | )
185 |
186 | self.conv5 = nn.Sequential(
187 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
188 | )
189 |
190 | if init_weights:
191 | self.init_weights()
192 |
193 | def forward(self, x):
194 | conv1 = self.conv1(x)
195 | conv2 = self.conv2(conv1)
196 | conv3 = self.conv3(conv2)
197 | conv4 = self.conv4(conv3)
198 | conv5 = self.conv5(conv4)
199 |
200 | outputs = conv5
201 | if self.use_sigmoid:
202 | outputs = torch.sigmoid(conv5)
203 |
204 | return outputs, [conv1, conv2, conv3, conv4, conv5]
205 |
206 |
207 | class ResnetBlock(nn.Module):
208 | def __init__(self, dim, dilation=1, use_spectral_norm=False):
209 | super(ResnetBlock, self).__init__()
210 | self.conv_block = nn.Sequential(
211 | nn.ReflectionPad2d(dilation),
212 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
213 | nn.InstanceNorm2d(dim, track_running_stats=False),
214 | nn.ReLU(True),
215 |
216 | nn.ReflectionPad2d(1),
217 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
218 | nn.InstanceNorm2d(dim, track_running_stats=False),
219 | )
220 |
221 | def forward(self, x):
222 | out = x + self.conv_block(x)
223 |
224 | # Remove ReLU at the end of the residual block
225 | # http://torch.ch/blog/2016/02/04/resnets.html
226 |
227 | return out
228 |
229 |
230 | def spectral_norm(module, mode=True):
231 | if mode:
232 | return nn.utils.spectral_norm(module)
233 |
234 | return module
235 |
--------------------------------------------------------------------------------
/RAFT/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
--------------------------------------------------------------------------------
/RAFT/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
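# NOTE: resize_sparse_flow_map deliberately avoids cv2.resize, because interpolating a
# KITTI-style sparse flow map would blend valid and invalid pixels. Instead, the valid
# samples are re-projected: a pixel at (x, y) carrying flow (u, v) lands at
# (round(x*fx), round(y*fy)) with flow (u*fx, v*fy), and both the flow map and the
# validity map are re-rasterised onto the resized grid.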
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
--------------------------------------------------------------------------------
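
A minimal sketch of driving the dense augmentor above outside of a dataset class; the crop size and the random frames/flow below are illustrative stand-ins (not values from this repo), and the import path assumes the repository root is on PYTHONPATH:

import numpy as np
from RAFT.utils.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 768), min_scale=-0.2, max_scale=0.5, do_flip=True)

# two uint8 RGB frames and a float32 (H, W, 2) flow field, larger than the crop size
img1 = np.random.randint(0, 255, (480, 1024, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (480, 1024, 3), dtype=np.uint8)
flow = np.random.randn(480, 1024, 2).astype(np.float32)

# photometric jitter + random eraser + random scale/flip/crop, applied jointly
img1, img2, flow = aug(img1, img2, flow)
assert img1.shape[:2] == (368, 768) and flow.shape == (368, 768, 2)
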
/RAFT/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 | """ Create the data loader for the corresponding trainign set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
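
fetch_dataloader mixes datasets by integer weighting: `100 * sintel_clean` works because FlowDataset.__rmul__ simply repeats the file lists, and `a + b` falls back to torch.utils.data.ConcatDataset. A usage sketch, assuming it is run from inside the RAFT/ directory (datasets.py imports the sibling utils package), that the hard-coded datasets/ roots above exist on disk, and that chairs_split.txt sits in the working directory; the image_size and batch_size values are illustrative:

from argparse import Namespace
import datasets

args = Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
loader = datasets.fetch_dataloader(args)
for img1, img2, flow, valid in loader:   # (B,3,H,W), (B,3,H,W), (B,2,H,W), (B,H,W)
    break
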
/utils/Poisson_blend_img.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import scipy.ndimage
4 | from scipy.sparse.linalg import spsolve
5 | from scipy import sparse
6 | import scipy.io as sio
7 | import numpy as np
8 | from PIL import Image
9 | import copy
10 | import cv2
11 | import os
12 | import argparse
13 |
14 |
15 | def sub2ind(pi, pj, imgH, imgW):
16 | return pj + pi * imgW
17 |
18 |
19 | def Poisson_blend_img(imgTrg, imgSrc_gx, imgSrc_gy, holeMask, gradientMask=None, edge=None):
20 |
21 | imgH, imgW, nCh = imgTrg.shape
22 |
23 | if not isinstance(gradientMask, np.ndarray):
24 | gradientMask = np.zeros((imgH, imgW), dtype=np.float32)
25 |
26 | if not isinstance(edge, np.ndarray):
27 | edge = np.zeros((imgH, imgW), dtype=np.float32)
28 |
29 | # Initialize the reconstructed image
30 | imgRecon = np.zeros((imgH, imgW, nCh), dtype=np.float32)
31 |
32 | # prepare discrete Poisson equation
33 | A, b, UnfilledMask = solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg,
34 | gradientMask, edge)
35 |
36 | # Independently process each channel
37 | for ch in range(nCh):
38 |
39 | # solve Poisson equation
40 | x = scipy.sparse.linalg.lsqr(A, b[:, ch])[0]
41 |
42 | imgRecon[:, :, ch] = x.reshape(imgH, imgW)
43 |
44 | # Combine with the known region in the target
45 | holeMaskC = np.tile(np.expand_dims(holeMask, axis=2), (1, 1, nCh))
46 | imgBlend = holeMaskC * imgRecon + (1 - holeMaskC) * imgTrg
47 |
48 |
49 | # while((UnfilledMask * edge).sum() != 0):
50 | # # Fill in edge pixel
51 | # pi = np.expand_dims(np.where((UnfilledMask * edge) == 1)[0], axis=1) # y, i
52 | # pj = np.expand_dims(np.where((UnfilledMask * edge) == 1)[1], axis=1) # x, j
53 | #
54 | # for k in range(len(pi)):
55 | # if pi[k, 0] - 1 >= 0:
56 | # if (UnfilledMask * edge)[pi[k, 0] - 1, pj[k, 0]] == 0:
57 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] - 1, pj[k, 0], :]
58 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0
59 | # continue
60 | # if pi[k, 0] + 1 <= imgH - 1:
61 | # if (UnfilledMask * edge)[pi[k, 0] + 1, pj[k, 0]] == 0:
62 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] + 1, pj[k, 0], :]
63 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0
64 | # continue
65 | # if pj[k, 0] - 1 >= 0:
66 | # if (UnfilledMask * edge)[pi[k, 0], pj[k, 0] - 1] == 0:
67 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] - 1, :]
68 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0
69 | # continue
70 | # if pj[k, 0] + 1 <= imgW - 1:
71 | # if (UnfilledMask * edge)[pi[k, 0], pj[k, 0] + 1] == 0:
72 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] + 1, :]
73 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0
74 |
75 | return imgBlend, UnfilledMask
76 |
77 | def solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg,
78 | gradientMask, edge):
79 |
80 | # UnfilledMask indicates the region that is not completed
81 | UnfilledMask_topleft = copy.deepcopy(holeMask)
82 | UnfilledMask_bottomright = copy.deepcopy(holeMask)
83 |
84 | # Prepare the linear system of equations for Poisson blending
85 | imgH, imgW = holeMask.shape
86 | N = imgH * imgW
87 |
88 | # Number of unknown variables
89 | numUnknownPix = holeMask.sum()
90 |
91 | # 4-neighbors: dx and dy
92 | dx = [1, 0, -1, 0]
93 | dy = [0, 1, 0, -1]
94 |
95 | # 3
96 | # |
97 | # 2 -- * -- 0
98 | # |
99 | # 1
100 | #
101 |
102 | # Initialize (I, J, S), for sparse matrix A where A(I(k), J(k)) = S(k)
103 | I = np.empty((0, 1), dtype=np.float32)
104 | J = np.empty((0, 1), dtype=np.float32)
105 | S = np.empty((0, 1), dtype=np.float32)
106 |
107 | # Initialize b
108 | b = np.empty((0, 3), dtype=np.float32)
109 |
110 | # Precompute unknown pixel positions
111 | pi = np.expand_dims(np.where(holeMask == 1)[0], axis=1) # y, i
112 | pj = np.expand_dims(np.where(holeMask == 1)[1], axis=1) # x, j
113 | pind = sub2ind(pi, pj, imgH, imgW)
114 |
115 | # |--------------------|
116 | # | y (i) |
117 | # | x (j) * |
118 | # | |
119 | # |--------------------|
120 | # p[y, x]
121 |
122 | qi = np.concatenate((pi + dy[0],
123 | pi + dy[1],
124 | pi + dy[2],
125 | pi + dy[3]), axis=1)
126 |
127 | qj = np.concatenate((pj + dx[0],
128 | pj + dx[1],
129 | pj + dx[2],
130 | pj + dx[3]), axis=1)
131 |
132 | # Handling cases at image borders
133 | validN = (qi >= 0) & (qi <= imgH - 1) & (qj >= 0) & (qj <= imgW - 1)
134 | qind = np.zeros((validN.shape), dtype=np.float32)
135 | qind[validN] = sub2ind(qi[validN], qj[validN], imgH, imgW)
136 |
137 | e_start = 0 # equation counter start
138 | e_stop = 0 # equation stop
139 |
140 | # 4 neighbors
141 | I, J, S, b, e_start, e_stop = constructEquation(0, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
142 | I, J, S, b, e_start, e_stop = constructEquation(1, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
143 | I, J, S, b, e_start, e_stop = constructEquation(2, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
144 | I, J, S, b, e_start, e_stop = constructEquation(3, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop)
145 |
146 | nEqn = len(b)
147 | # Construct the sparse matrix A
148 | A = sparse.csr_matrix((S[:, 0], (I[:, 0], J[:, 0])), shape=(nEqn, N))
149 |
150 | # Check connected pixels
151 | for ind in range(0, len(pi), 1):
152 | ii = pi[ind, 0]
153 | jj = pj[ind, 0]
154 |
155 | # check up (3)
156 | if ii - 1 >= 0:
157 | if UnfilledMask_topleft[ii - 1, jj] == 0 and gradientMask[ii - 1, jj] == 0:
158 | UnfilledMask_topleft[ii, jj] = 0
159 | # check left (2)
160 | if jj - 1 >= 0:
161 | if UnfilledMask_topleft[ii, jj - 1] == 0 and gradientMask[ii, jj - 1] == 0:
162 | UnfilledMask_topleft[ii, jj] = 0
163 |
164 |
165 | for ind in range(len(pi) - 1, -1, -1):
166 | ii = pi[ind, 0]
167 | jj = pj[ind, 0]
168 |
169 | # check bottom (1)
170 | if ii + 1 <= imgH - 1:
171 | if UnfilledMask_bottomright[ii + 1, jj] == 0 and gradientMask[ii, jj] == 0:
172 | UnfilledMask_bottomright[ii, jj] = 0
173 | # check right (0)
174 | if jj + 1 <= imgW - 1:
175 | if UnfilledMask_bottomright[ii, jj + 1] == 0 and gradientMask[ii, jj] == 0:
176 | UnfilledMask_bottomright[ii, jj] = 0
177 |
178 | UnfilledMask = UnfilledMask_topleft * UnfilledMask_bottomright
179 |
180 | return A, b, UnfilledMask
181 |
182 |
183 | def constructEquation(n, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop):
184 |
185 | # Pixel that has valid neighbors
186 | validNeighbor = validN[:, n]
187 |
188 | # Clamp out-of-boundary neighbor indices to 0 so that edge[y, x] below can be evaluated;
189 | # this does not affect the result, since validNeighbor already records which neighbors are valid
190 |
191 | qi_tmp = copy.deepcopy(qi)
192 | qj_tmp = copy.deepcopy(qj)
193 | qi_tmp[np.invert(validNeighbor), n] = 0
194 | qj_tmp[np.invert(validNeighbor), n] = 0
195 |
196 | NotEdge = (edge[pi[:, 0], pj[:, 0]] == 0) * (edge[qi_tmp[:, n], qj_tmp[:, n]] == 0)
197 |
198 | # Have gradient
199 | if n == 0:
200 | HaveGrad = gradientMask[pi[:, 0], pj[:, 0]] == 0
201 | elif n == 2:
202 | HaveGrad = gradientMask[pi[:, 0], pj[:, 0] - 1] == 0
203 | elif n == 1:
204 | HaveGrad = gradientMask[pi[:, 0], pj[:, 0]] == 0
205 | elif n == 3:
206 | HaveGrad = gradientMask[pi[:, 0] - 1, pj[:, 0]] == 0
207 |
208 | # Boundary constraint
209 | Boundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 0
210 |
211 | valid = validNeighbor * NotEdge * HaveGrad * Boundary
212 |
213 | J_tmp = pind[valid, :]
214 |
215 | # num of equations: len(J_tmp)
216 | e_stop = e_start + len(J_tmp)
217 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1)
218 | e_start = e_stop
219 |
220 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32)
221 |
222 | if n == 0:
223 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :]
224 | elif n == 2:
225 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + imgTrg[qi[valid, n], qj[valid, n], :]
226 | elif n == 1:
227 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :]
228 | elif n == 3:
229 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :]
230 |
231 | I = np.concatenate((I, I_tmp))
232 | J = np.concatenate((J, J_tmp))
233 | S = np.concatenate((S, S_tmp))
234 | b = np.concatenate((b, b_tmp))
235 |
236 | # Non-boundary constraint
237 | NonBoundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 1
238 | valid = validNeighbor * NotEdge * HaveGrad * NonBoundary
239 |
240 | J_tmp = pind[valid, :]
241 |
242 | # num of equations: len(J_tmp)
243 | e_stop = e_start + len(J_tmp)
244 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1)
245 | e_start = e_stop
246 |
247 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32)
248 |
249 | if n == 0:
250 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :]
251 | elif n == 2:
252 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :]
253 | elif n == 1:
254 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :]
255 | elif n == 3:
256 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :]
257 |
258 | I = np.concatenate((I, I_tmp))
259 | J = np.concatenate((J, J_tmp))
260 | S = np.concatenate((S, S_tmp))
261 | b = np.concatenate((b, b_tmp))
262 |
263 | S_tmp = - np.ones(J_tmp.shape, dtype=np.float32)
264 | J_tmp = qind[valid, n, None]
265 |
266 | I = np.concatenate((I, I_tmp))
267 | J = np.concatenate((J, J_tmp))
268 | S = np.concatenate((S, S_tmp))
269 |
270 | return I, J, S, b, e_start, e_stop
271 |
--------------------------------------------------------------------------------
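
sub2ind flattens (row i, column j) to the row-major index j + i*imgW, so the unknowns in A line up with imgTrg.reshape(-1) per channel; each row of A encodes either a gradient-matching equation between two hole pixels or a boundary equation against a known target pixel, and the over-determined system is solved per channel with lsqr. A small smoke-test sketch, run from the repository root (all arrays here are synthetic; in the pipeline the gradients come from the flow-guided frame propagation):

import numpy as np
from utils.Poisson_blend_img import Poisson_blend_img

H, W = 32, 32
imgTrg = np.random.rand(H, W, 3).astype(np.float32)     # target frame with a hole to fill
gx = np.zeros((H, W, 3), dtype=np.float32)               # desired x-gradients inside the hole
gy = np.zeros((H, W, 3), dtype=np.float32)               # desired y-gradients inside the hole
holeMask = np.zeros((H, W), dtype=np.float32)
holeMask[12:20, 12:20] = 1                                # interior square hole

imgBlend, unfilled = Poisson_blend_img(imgTrg, gx, gy, holeMask)
assert imgBlend.shape == imgTrg.shape and unfilled.sum() == 0   # an interior hole is fully reachable
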
/edgeconnect/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from .networks import InpaintGenerator, EdgeGenerator, Discriminator
6 | from .loss import AdversarialLoss, PerceptualLoss, StyleLoss, TotalVariationalLoss
7 |
8 |
9 | class BaseModel(nn.Module):
10 | def __init__(self, name, config):
11 | super(BaseModel, self).__init__()
12 |
13 | self.name = name
14 | self.config = config
15 | self.iteration = 0
16 |
17 | self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
18 | self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')
19 |
20 | def load(self):
21 | if os.path.exists(self.gen_weights_path):
22 | print('Loading %s generator...' % self.name)
23 |
24 | if torch.cuda.is_available():
25 | data = torch.load(self.gen_weights_path)
26 | else:
27 | data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)
28 |
29 | self.generator.load_state_dict(data['generator'])
30 | self.iteration = data['iteration']
31 |
32 | # load discriminator only when training
33 | if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
34 | print('Loading %s discriminator...' % self.name)
35 |
36 | if torch.cuda.is_available():
37 | data = torch.load(self.dis_weights_path)
38 | else:
39 | data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)
40 |
41 | self.discriminator.load_state_dict(data['discriminator'])
42 |
43 | def save(self):
44 | print('\nsaving %s...\n' % self.name)
45 | torch.save({
46 | 'iteration': self.iteration,
47 | 'generator': self.generator.state_dict()
48 | }, self.gen_weights_path)
49 |
50 | torch.save({
51 | 'discriminator': self.discriminator.state_dict()
52 | }, self.dis_weights_path)
53 |
54 |
55 | class EdgeModel(BaseModel):
56 | def __init__(self, config):
57 | super(EdgeModel, self).__init__('EdgeModel', config)
58 |
59 | # generator input: [grayscale(1) + edge(1) + mask(1)]
60 | # discriminator input: (grayscale(1) + edge(1))
61 | generator = EdgeGenerator(use_spectral_norm=True)
62 | discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge')
63 | if len(config.GPU) > 1:
64 | generator = nn.DataParallel(generator, config.GPU)
65 | discriminator = nn.DataParallel(discriminator, config.GPU)
66 | l1_loss = nn.L1Loss()
67 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
68 |
69 | self.add_module('generator', generator)
70 | self.add_module('discriminator', discriminator)
71 |
72 | self.add_module('l1_loss', l1_loss)
73 | self.add_module('adversarial_loss', adversarial_loss)
74 |
75 | self.gen_optimizer = optim.Adam(
76 | params=generator.parameters(),
77 | lr=float(config.LR),
78 | betas=(config.BETA1, config.BETA2)
79 | )
80 |
81 | self.dis_optimizer = optim.Adam(
82 | params=discriminator.parameters(),
83 | lr=float(config.LR) * float(config.D2G_LR),
84 | betas=(config.BETA1, config.BETA2)
85 | )
86 |
87 | def process(self, images, edges, masks):
88 | self.iteration += 1
89 |
90 |
91 | # zero optimizers
92 | self.gen_optimizer.zero_grad()
93 | self.dis_optimizer.zero_grad()
94 |
95 |
96 | # process outputs
97 | outputs = self(images, edges, masks)
98 | gen_loss = 0
99 | dis_loss = 0
100 |
101 |
102 | # discriminator loss
103 | dis_input_real = torch.cat((images, edges), dim=1)
104 | dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
105 | dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1))
106 | dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1))
107 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
108 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
109 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
110 |
111 |
112 | # generator adversarial loss
113 | gen_input_fake = torch.cat((images, outputs), dim=1)
114 | gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1))
115 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
116 | gen_loss += gen_gan_loss
117 |
118 |
119 | # generator feature matching loss
120 | gen_fm_loss = 0
121 | for i in range(len(dis_real_feat)):
122 | gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
123 | gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT
124 | gen_loss += gen_fm_loss
125 |
126 |
127 | # create logs
128 | logs = [
129 | ("l_d1", dis_loss.item()),
130 | ("l_g1", gen_gan_loss.item()),
131 | ("l_fm", gen_fm_loss.item()),
132 | ]
133 |
134 | return outputs, gen_loss, dis_loss, logs
135 |
136 | def forward(self, images, edges, masks):
137 | edges_masked = (edges * (1 - masks))
138 | images_masked = (images * (1 - masks)) + masks
139 | inputs = torch.cat((images_masked, edges_masked, masks), dim=1)
140 | outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)]
141 | return outputs
142 |
143 | def backward(self, gen_loss=None, dis_loss=None):
144 | if dis_loss is not None:
145 | dis_loss.backward()
146 | self.dis_optimizer.step()
147 |
148 | if gen_loss is not None:
149 | gen_loss.backward()
150 | self.gen_optimizer.step()
151 |
152 |
153 | class InpaintingModel(BaseModel):
154 | def __init__(self, config):
155 | super(InpaintingModel, self).__init__('InpaintingModel', config)
156 |
157 | # generator input: [rgb(3) + edge(1)]
158 | # discriminator input: [rgb(3)]
159 | generator = InpaintGenerator(config)
160 | self.config = config
161 | if config.FLO == 1:
162 | in_channels = 2
163 | elif config.FLO == 0:
164 | in_channels = 3
165 | else:
166 | assert(0)
167 | discriminator = Discriminator(in_channels=in_channels, use_sigmoid=config.GAN_LOSS != 'hinge')
168 | if len(config.GPU) > 1:
169 | generator = nn.DataParallel(generator, config.GPU)
170 | discriminator = nn.DataParallel(discriminator , config.GPU)
171 |
172 | l1_loss = nn.L1Loss()
173 | tv_loss = TotalVariationalLoss()
174 | perceptual_loss = PerceptualLoss()
175 | style_loss = StyleLoss()
176 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
177 |
178 | self.add_module('generator', generator)
179 | self.add_module('discriminator', discriminator)
180 |
181 | self.add_module('l1_loss', l1_loss)
182 | self.add_module('tv_loss', tv_loss)
183 | self.add_module('perceptual_loss', perceptual_loss)
184 | self.add_module('style_loss', style_loss)
185 | self.add_module('adversarial_loss', adversarial_loss)
186 |
187 | self.gen_optimizer = optim.Adam(
188 | params=generator.parameters(),
189 | lr=float(config.LR),
190 | betas=(config.BETA1, config.BETA2)
191 | )
192 |
193 | self.dis_optimizer = optim.Adam(
194 | params=discriminator.parameters(),
195 | lr=float(config.LR) * float(config.D2G_LR),
196 | betas=(config.BETA1, config.BETA2)
197 | )
198 |
199 | def process(self, images, images_filled, edges, masks):
200 | self.iteration += 1
201 |
202 | # zero optimizers
203 | self.gen_optimizer.zero_grad()
204 | self.dis_optimizer.zero_grad()
205 |
206 | # process outputs
207 | outputs = self(images, images_filled, edges, masks)
208 |
209 | gen_loss = 0
210 | dis_loss = 0
211 | gen_gan_loss = 0
212 |
213 | if self.config.GAN == 1:
214 | # discriminator loss
215 | dis_input_real = images
216 | dis_input_fake = outputs.detach()
217 | dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)]
218 | dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)]
219 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
220 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
221 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
222 |
223 |
224 | # generator adversarial loss
225 | gen_input_fake = outputs
226 | gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)]
227 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
228 | gen_loss += gen_gan_loss
229 |
230 |
231 | # generator l1 loss
232 | gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
233 | gen_loss += gen_l1_loss
234 |
235 | if self.config.ENFORCE == 1:
236 | gen_l1_masked_loss = self.l1_loss(outputs * masks, images * masks) * 10 * self.config.L1_LOSS_WEIGHT
237 | gen_loss += gen_l1_masked_loss
238 | elif self.config.ENFORCE != 0:
239 | assert(0)
240 |
241 | if self.config.TV == 1:
242 | # generator tv loss
243 | gen_tv_loss = self.tv_loss(outputs) * self.config.TV_LOSS_WEIGHT
244 | gen_loss += gen_tv_loss
245 |
246 | if self.config.FLO != 1:
247 | # generator perceptual loss
248 | gen_content_loss = self.perceptual_loss(outputs, images)
249 | gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
250 | gen_loss += gen_content_loss
251 |
252 | # generator style loss
253 | gen_style_loss = self.style_loss(outputs * masks, images * masks)
254 | gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
255 | gen_loss += gen_style_loss
256 |
257 | # create logs
258 | logs = [
259 | ("l_d2", dis_loss.item()),
260 | ("l_g2", gen_gan_loss.item()),
261 | ("l_l1", gen_l1_loss.item()),
262 | ("l_per", gen_content_loss.item()),
263 | ("l_sty", gen_style_loss.item()),
264 | ]
265 | else:
266 | logs = []
267 | logs.append(("l_l1", gen_l1_loss.item()))
268 | logs.append(("l_gen", gen_loss.item()))
269 |
270 | if self.config.GAN == 1:
271 | logs.append(("l_d2", dis_loss.item()))
272 | logs.append(("l_g2", gen_gan_loss.item()))
273 |
274 | if self.config.TV == 1:
275 | logs.append(("l_tv", gen_tv_loss.item()))
276 |
277 | if self.config.ENFORCE == 1:
278 | logs.append(("l_masked_l1", gen_l1_masked_loss.item()))
279 |
280 | return outputs, gen_loss, dis_loss, logs
281 |
282 | def forward(self, images, images_filled, edges, masks):
283 |
284 | if self.config.FILL == 1:
285 | images_masked = images_filled
286 | elif self.config.FILL == 0:
287 | images_masked = (images * (1 - masks).float()) # + masks
288 | else:
289 | assert(0)
290 |
291 | if self.config.PASSMASK == 1:
292 | inputs = torch.cat((images_masked, edges, masks), dim=1)
293 | elif self.config.PASSMASK == 0:
294 | inputs = torch.cat((images_masked, edges), dim=1)
295 | else:
296 | assert(0)
297 |
298 | outputs = self.generator(inputs)
299 | # if self.config.RESIDUAL == 1:
300 | # assert(self.config.PASSMASK == 1)
301 | # outputs = self.generator(inputs) + images_filled
302 | # elif self.config.RESIDUAL == 0:
303 | # outputs = self.generator(inputs)
304 | # else:
305 | # assert(0)
306 |
307 | return outputs
308 |
309 | def backward(self, gen_loss=None, dis_loss=None):
310 |
311 | if self.config.GAN == 1:
312 | dis_loss.backward()
313 | self.dis_optimizer.step()
314 |
315 | gen_loss.backward()
316 | self.gen_optimizer.step()
317 |
--------------------------------------------------------------------------------
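
InpaintingModel's behaviour is driven almost entirely by flags on its config object: FLO switches between 2-channel flow and 3-channel RGB discriminators, FILL/PASSMASK control what the generator sees, and GAN/TV/ENFORCE toggle the corresponding loss terms. A sketch of the attributes this file reads, using a plain namespace in place of the repo's Config class (the numeric values are illustrative, not the shipped defaults; InpaintGenerator in networks.py reads further fields not listed here):

from types import SimpleNamespace

config = SimpleNamespace(
    PATH='./checkpoints', MODE=1, GPU=[0], GAN_LOSS='nsgan',
    LR=1e-4, D2G_LR=0.1, BETA1=0.0, BETA2=0.9,
    FLO=1,                      # complete 2-channel flow fields rather than RGB frames
    FILL=1, PASSMASK=1,         # feed the region-filled input and append the mask channel
    GAN=1, TV=0, ENFORCE=0,     # loss-term switches
    L1_LOSS_WEIGHT=1.0, INPAINT_ADV_LOSS_WEIGHT=0.1,
    TV_LOSS_WEIGHT=0.1, CONTENT_LOSS_WEIGHT=0.1, STYLE_LOSS_WEIGHT=250,
    FM_LOSS_WEIGHT=10.0,
)
# model = InpaintingModel(config)   # would also need the generator fields from networks.py
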
/edgeconnect/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import scipy
5 | import torch
6 | import random
7 | import numpy as np
8 | import torchvision.transforms.functional as F
9 | from torch.utils.data import DataLoader
10 | from PIL import Image
11 | from scipy.misc import imread
12 | from skimage.feature import canny
13 | from skimage.color import rgb2gray, gray2rgb
14 | from .utils import create_mask
15 | import src.region_fill as rf
16 |
17 | class Dataset(torch.utils.data.Dataset):
18 | def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True):
19 | super(Dataset, self).__init__()
20 |
21 | self.augment = augment
22 | self.training = training
23 | self.flo = config.FLO
24 | self.norm = config.NORM
25 | self.data = self.load_flist(flist, self.flo)
26 | self.edge_data = self.load_flist(edge_flist, 0)
27 | self.mask_data = self.load_flist(mask_flist, 0)
28 |
29 | self.input_size = config.INPUT_SIZE
30 | self.sigma = config.SIGMA
31 | self.edge = config.EDGE
32 | self.mask = config.MASK
33 | self.nms = config.NMS
34 |
35 |
36 |
37 | # in test mode, there's a one-to-one relationship between mask and image;
38 | # masks are loaded non-randomly
39 | if config.MODE == 2:
40 | self.mask = 6
41 |
42 | def __len__(self):
43 | return len(self.data)
44 |
45 | def __getitem__(self, index):
46 | try:
47 | item = self.load_item(index)
48 | except:
49 | print('loading error: ' + self.data[index])
50 | item = self.load_item(0)
51 |
52 | return item
53 |
54 | def load_name(self, index):
55 | name = self.data[index]
56 | return os.path.basename(name)
57 |
58 | def load_item(self, index):
59 | size = self.input_size
60 | factor = 1.
61 | if self.flo == 0:
62 |
63 | # load image
64 | img = imread(self.data[index])
65 |
66 | # gray to rgb
67 | if len(img.shape) < 3:
68 | img = gray2rgb(img)
69 |
70 | # resize/crop if needed
71 | if size != 0:
72 | img = self.resize(img, size[0], size[1])
73 |
74 | # create grayscale image
75 | img_gray = rgb2gray(img)
76 |
77 | # load mask
78 | mask = self.load_mask(img, index)
79 |
80 | edge = self.load_edge(img_gray, index, mask)
81 |
82 | img_filled = img
83 |
84 | else:
85 |
86 | img = self.readFlow(self.data[index])
87 |
88 | # resize/crop if needed
89 | if size != 0:
90 | img = self.flow_tf(img, [size[0], size[1]])
91 |
92 | img_gray = (img[:, :, 0] ** 2 + img[:, :, 1] ** 2) ** 0.5
93 |
94 | if self.norm == 1:
95 | # normalization
96 | # factor = (np.abs(img[:, :, 0]).max() ** 2 + np.abs(img[:, :, 1]).max() ** 2) ** 0.5
97 | factor = img_gray.max()
98 | img /= factor
99 |
100 | # load mask
101 | mask = self.load_mask(img, index)
102 |
103 | edge = self.load_edge(img_gray, index, mask)
104 | img_gray = img_gray / img_gray.max()
105 |
106 | img_filled = np.zeros(img.shape)
107 | img_filled[:, :, 0] = rf.regionfill(img[:, :, 0], mask)
108 | img_filled[:, :, 1] = rf.regionfill(img[:, :, 1], mask)
109 |
110 |
111 | # augment data
112 | if self.augment and np.random.binomial(1, 0.5) > 0:
113 | img = img[:, ::-1, ...].copy()
114 | img_filled = img_filled[:, ::-1, ...].copy()
115 | img_gray = img_gray[:, ::-1, ...]
116 | edge = edge[:, ::-1, ...]
117 | mask = mask[:, ::-1, ...]
118 |
119 | return self.to_tensor(img), self.to_tensor(img_filled), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask), factor
120 |
121 | def load_edge(self, img, index, mask):
122 | sigma = self.sigma
123 |
124 | # in test mode images are masked (they contain masked-out regions);
125 | # passing the 'mask' parameter prevents canny from detecting edges inside the masked regions
126 | mask = None if self.training else (1 - mask / 255).astype(np.bool)
127 |
128 | # canny
129 | if self.edge == 1:
130 | # no edge
131 | if sigma == -1:
132 | return np.zeros(img.shape).astype(np.float)
133 |
134 | # random sigma
135 | if sigma == 0:
136 | sigma = random.randint(1, 4)
137 | return canny(img, sigma=sigma, mask=mask).astype(np.float)
138 |
139 | # external
140 | else:
141 | imgh, imgw = img.shape[0:2]
142 | edge = imread(self.edge_data[index])
143 | edge = self.resize(edge, imgh, imgw)
144 |
145 | # non-max suppression
146 | if self.nms == 1:
147 | edge = edge * canny(img, sigma=sigma, mask=mask)
148 |
149 | return edge
150 |
151 | def load_mask(self, img, index):
152 | imgh, imgw = img.shape[0:2]
153 | mask_type = self.mask
154 |
155 | # external + random block
156 | if mask_type == 4:
157 | mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3
158 |
159 | # external + random block + half
160 | elif mask_type == 5:
161 | mask_type = np.random.randint(1, 4)
162 |
163 | # random block
164 | if mask_type == 1:
165 | return create_mask(imgw, imgh, imgw // 2, imgh // 2)
166 |
167 | # half
168 | if mask_type == 2:
169 | # randomly choose right or left
170 | return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
171 |
172 | # external
173 | if mask_type == 3:
174 | mask_index = random.randint(0, len(self.mask_data) - 1)
175 | mask = imread(self.mask_data[mask_index])
176 | mask = self.resize(mask, imgh, imgw, centerCrop=False)
177 | mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation
178 | return mask
179 |
180 | # test mode: load mask non-randomly
181 |
182 | if mask_type == 6:
183 | mask = imread(self.mask_data[index])
184 | mask = self.resize(mask, imgh, imgw, centerCrop=False)
185 | mask = rgb2gray(mask)
186 | mask = (mask > 0).astype(np.uint8) * 255
187 | return mask
188 |
189 | def to_tensor(self, img):
190 | if (len(img.shape) == 3 and img.shape[2] == 2):
191 | return F.to_tensor(img).float()
192 | img = Image.fromarray(img)
193 | img_t = F.to_tensor(img).float()
194 | return img_t
195 |
196 | def resize(self, img, height, width, centerCrop=True):
197 | imgh, imgw = img.shape[0:2]
198 |
199 | if centerCrop and imgh != imgw:
200 | # center crop
201 | side = np.minimum(imgh, imgw)
202 | j = (imgh - side) // 2
203 | i = (imgw - side) // 2
204 | img = img[j:j + side, i:i + side, ...]
205 |
206 | img = scipy.misc.imresize(img, [height, width])
207 |
208 | return img
209 |
210 | def load_flist(self, flist, flo=0):
211 | if isinstance(flist, list):
212 | return flist
213 |
214 | # flist: image file path, image directory path, text file flist path
215 | if flo == 0:
216 | if isinstance(flist, str):
217 | if os.path.isdir(flist):
218 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
219 | flist.sort()
220 | return flist
221 |
222 | if os.path.isfile(flist):
223 | try:
224 | return np.genfromtxt(flist, dtype=np.str, encoding='utf-8')
225 | except:
226 | return [flist]
227 | else:
228 | if isinstance(flist, str):
229 | if os.path.isdir(flist):
230 | flist = list(glob.glob(flist + '/*.flo'))
231 | flist.sort()
232 | return flist
233 |
234 | if os.path.isfile(flist):
235 | try:
236 | return np.genfromtxt(flist, dtype=np.str, encoding='utf-8')
237 | except:
238 | return [flist]
239 |
240 | return []
241 |
242 | def create_iterator(self, batch_size):
243 | while True:
244 | sample_loader = DataLoader(
245 | dataset=self,
246 | batch_size=batch_size,
247 | drop_last=True
248 | )
249 |
250 | for item in sample_loader:
251 | yield item
252 |
253 | def readFlow(self, fn):
254 | with open(fn, 'rb') as f:
255 | magic = np.fromfile(f, np.float32, count=1)
256 | if 202021.25 != magic:
257 | print('Magic number incorrect. Invalid .flo file')
258 | return None
259 | else:
260 | w = np.fromfile(f, np.int32, count=1)
261 | h = np.fromfile(f, np.int32, count=1)
262 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
263 | # Reshape data into a 3D array (rows, columns, bands)
264 | # Note: this returns (h, w, 2); the original reference code used (w, h, 2)
265 | return np.resize(data, (int(h), int(w), 2))
266 |
267 | def flow_to_image(self, flow):
268 |
269 | UNKNOWN_FLOW_THRESH = 1e7
270 |
271 | u = flow[:, :, 0]
272 | v = flow[:, :, 1]
273 |
274 | maxu = -999.
275 | maxv = -999.
276 | minu = 999.
277 | minv = 999.
278 |
279 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
280 | u[idxUnknow] = 0
281 | v[idxUnknow] = 0
282 |
283 | maxu = max(maxu, np.max(u))
284 | minu = min(minu, np.min(u))
285 |
286 | maxv = max(maxv, np.max(v))
287 | minv = min(minv, np.min(v))
288 |
289 | rad = np.sqrt(u ** 2 + v ** 2)
290 | maxrad = max(-1, np.max(rad))
291 |
292 | # print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv))
293 |
294 | u = u/(maxrad + np.finfo(float).eps)
295 | v = v/(maxrad + np.finfo(float).eps)
296 |
297 | img = self.compute_color(u, v)
298 |
299 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
300 | img[idx] = 0
301 |
302 | return np.uint8(img)
303 |
304 |
305 | def compute_color(self, u, v):
306 | """
307 | compute optical flow color map
308 | :param u: optical flow horizontal map
309 | :param v: optical flow vertical map
310 | :return: optical flow in color code
311 | """
312 | [h, w] = u.shape
313 | img = np.zeros([h, w, 3])
314 | nanIdx = np.isnan(u) | np.isnan(v)
315 | u[nanIdx] = 0
316 | v[nanIdx] = 0
317 |
318 | colorwheel = self.make_color_wheel()
319 | ncols = np.size(colorwheel, 0)
320 |
321 | rad = np.sqrt(u**2+v**2)
322 |
323 | a = np.arctan2(-v, -u) / np.pi
324 |
325 | fk = (a+1) / 2 * (ncols - 1) + 1
326 |
327 | k0 = np.floor(fk).astype(int)
328 |
329 | k1 = k0 + 1
330 | k1[k1 == ncols+1] = 1
331 | f = fk - k0
332 |
333 | for i in range(0, np.size(colorwheel,1)):
334 | tmp = colorwheel[:, i]
335 | col0 = tmp[k0-1] / 255
336 | col1 = tmp[k1-1] / 255
337 | col = (1-f) * col0 + f * col1
338 |
339 | idx = rad <= 1
340 | col[idx] = 1-rad[idx]*(1-col[idx])
341 | notidx = np.logical_not(idx)
342 |
343 | col[notidx] *= 0.75
344 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
345 |
346 | return img
347 |
348 |
349 | def make_color_wheel(self):
350 | """
351 | Generate color wheel according Middlebury color code
352 | :return: Color wheel
353 | """
354 | RY = 15
355 | YG = 6
356 | GC = 4
357 | CB = 11
358 | BM = 13
359 | MR = 6
360 |
361 | ncols = RY + YG + GC + CB + BM + MR
362 |
363 | colorwheel = np.zeros([ncols, 3])
364 |
365 | col = 0
366 |
367 | # RY
368 | colorwheel[0:RY, 0] = 255
369 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
370 | col += RY
371 |
372 | # YG
373 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
374 | colorwheel[col:col+YG, 1] = 255
375 | col += YG
376 |
377 | # GC
378 | colorwheel[col:col+GC, 1] = 255
379 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
380 | col += GC
381 |
382 | # CB
383 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
384 | colorwheel[col:col+CB, 2] = 255
385 | col += CB
386 |
387 | # BM
388 | colorwheel[col:col+BM, 2] = 255
389 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
390 | col += BM
391 |
392 | # MR
393 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
394 | colorwheel[col:col+MR, 0] = 255
395 |
396 | return colorwheel
397 |
398 | def flow_tf(self, flow, size):
399 | flow_shape = flow.shape
400 | flow_resized = cv2.resize(flow, (size[1], size[0]))
401 | flow_resized[:, :, 0] *= (float(size[1]) / float(flow_shape[1]))
402 | flow_resized[:, :, 1] *= (float(size[0]) / float(flow_shape[0]))
403 |
404 | return flow_resized
405 |
--------------------------------------------------------------------------------
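
readFlow above expects the Middlebury .flo layout: a float32 magic number 202021.25, then int32 width and height, then width*height interleaved (u, v) float32 pairs. A round-trip sketch with a throw-away writer (write_flo is a hypothetical helper written to mirror readFlow, not a function from this repo):

import numpy as np

def write_flo(fn, flow):
    # flow: (H, W, 2) float32; stored row-major, which matches readFlow's np.resize to (h, w, 2)
    h, w = flow.shape[:2]
    with open(fn, 'wb') as f:
        np.array([202021.25], dtype=np.float32).tofile(f)
        np.array([w, h], dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)

flow = np.random.randn(48, 64, 2).astype(np.float32)
write_flo('/tmp/sample.flo', flow)
# Dataset.readFlow('/tmp/sample.flo') then returns the same (48, 64, 2) float32 array
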
/models/DeepFill_Models/ops.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 |
9 | def weights_init(init_type='gaussian'):
10 | def init_fun(m):
11 | classname = m.__class__.__name__
12 | if (classname.find('Conv') == 0 or classname.find(
13 | 'Linear') == 0) and hasattr(m, 'weight'):
14 | if init_type == 'gaussian':
15 | nn.init.normal_(m.weight, 0.0, 0.02)
16 | elif init_type == 'xavier':
17 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
18 | elif init_type == 'kaiming':
19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20 | elif init_type == 'orthogonal':
21 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
22 | elif init_type == 'default':
23 | pass
24 | else:
25 | assert 0, "Unsupported initialization: {}".format(init_type)
26 | if hasattr(m, 'bias') and m.bias is not None:
27 | nn.init.constant_(m.bias, 0.0)
28 |
29 | return init_fun
30 |
31 |
32 | class Conv(nn.Module):
33 | def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU(), isGated=False):
34 | super(Conv, self).__init__()
35 | if activation is not None:
36 | self.conv = nn.Sequential(
37 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D),
38 | activation
39 | )
40 | else:
41 | self.conv = nn.Sequential(
42 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D)
43 | )
44 |
45 | for m in self.modules():
46 | if isinstance(m, nn.Conv2d):
47 | m.apply(weights_init('kaiming'))
48 |
49 | def forward(self, x):
50 | x = self.conv(x)
51 | return x
52 |
53 |
54 | class Conv_Downsample(nn.Module):
55 | def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU()):
56 | super(Conv_Downsample, self).__init__()
57 |
58 | PaddingLayer = torch.nn.ZeroPad2d((0, (K-1)//2, 0, (K-1)//2))
59 |
60 | if activation is not None:
61 | self.conv = nn.Sequential(
62 | PaddingLayer,
63 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=0, dilation=D),
64 | activation
65 | )
66 | else:
67 | self.conv = nn.Sequential(
68 | PaddingLayer,
69 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=0, dilation=D)
70 | )
71 |
72 | for m in self.modules():
73 | if isinstance(m, nn.Conv2d):
74 | m.apply(weights_init('kaiming'))
75 |
76 | def forward(self, x):
77 | x = self.conv(x)
78 | return x
79 |
80 |
81 | class Down_Module(nn.Module):
82 | def __init__(self, in_ch, out_ch, activation=nn.ELU(), isRefine=False,
83 | isAttn=False, ):
84 | super(Down_Module, self).__init__()
85 | layers = []
86 | layers.append(Conv(in_ch, out_ch, K=5, P=2))
87 | # curr_dim = out_ch
88 | # layers.append(Conv_Downsample(curr_dim, curr_dim * 2, K=3, S=2, isGated=isGated))
89 |
90 | curr_dim = out_ch
91 | if isRefine:
92 | if isAttn:
93 | layers.append(Conv_Downsample(curr_dim, curr_dim, K=3, S=2))
94 | layers.append(Conv(curr_dim, 2*curr_dim, K=3, S=1))
95 | layers.append(Conv_Downsample(2*curr_dim, 4*curr_dim, K=3, S=2))
96 | layers.append(Conv(4 * curr_dim, 4 * curr_dim, K=3, S=1))
97 | curr_dim *= 4
98 | else:
99 | for i in range(2):
100 | layers.append(Conv_Downsample(curr_dim, curr_dim, K=3, S=2))
101 | layers.append(Conv(curr_dim, curr_dim*2))
102 | curr_dim *= 2
103 | else:
104 | for i in range(2):
105 | layers.append(Conv_Downsample(curr_dim, curr_dim*2, K=3, S=2))
106 | layers.append(Conv(curr_dim * 2, curr_dim * 2))
107 | curr_dim *= 2
108 |
109 | layers.append(Conv(curr_dim, curr_dim, activation=activation))
110 |
111 | self.out = nn.Sequential(*layers)
112 |
113 | def forward(self, x):
114 | return self.out(x)
115 |
116 |
117 | class Dilation_Module(nn.Module):
118 | def __init__(self, in_ch, out_ch):
119 | super(Dilation_Module, self).__init__()
120 | layers = []
121 | dilation = 1
122 | for i in range(4):
123 | dilation *= 2
124 | layers.append(Conv(in_ch, out_ch, D=dilation, P=dilation))
125 | self.out = nn.Sequential(*layers)
126 |
127 | def forward(self, x):
128 | return self.out(x)
129 |
130 |
131 | class Up_Module(nn.Module):
132 | def __init__(self, in_ch, out_ch, isRefine=False):
133 | super(Up_Module, self).__init__()
134 | layers = []
135 | curr_dim = in_ch
136 | if isRefine:
137 | layers.append(Conv(curr_dim, curr_dim//2))
138 | curr_dim //= 2
139 | else:
140 | layers.append(Conv(curr_dim, curr_dim))
141 |
142 | # conv 12~15
143 | for i in range(2):
144 | layers.append(Conv(curr_dim, curr_dim))
145 | layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
146 | layers.append(Conv(curr_dim, curr_dim//2))
147 | curr_dim //= 2
148 |
149 | layers.append(Conv(curr_dim, curr_dim//2))
150 | layers.append(Conv(curr_dim//2, out_ch, activation=None))
151 |
152 | self.out = nn.Sequential(*layers)
153 |
154 | def forward(self, x):
155 | output = self.out(x)
156 | return torch.clamp(output, min=-1., max=1.)
157 |
158 |
159 | class Up_Module_CNet(nn.Module):
160 | def __init__(self, in_ch, out_ch, isRefine=False, isGated=False):
161 | super(Up_Module_CNet, self).__init__()
162 | layers = []
163 | curr_dim = in_ch
164 | if isRefine:
165 | layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated))
166 | curr_dim //= 2
167 | else:
168 | layers.append(Conv(curr_dim, curr_dim, isGated=isGated))
169 |
170 | # conv 12~15
171 | for i in range(2):
172 | layers.append(Conv(curr_dim, curr_dim, isGated=isGated))
173 | layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
174 | layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated))
175 | curr_dim //= 2
176 |
177 | layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated))
178 | layers.append(Conv(curr_dim//2, out_ch, activation=None, isGated=isGated))
179 |
180 | self.out = nn.Sequential(*layers)
181 |
182 | def forward(self, x):
183 | output = self.out(x)
184 | return output
185 |
186 |
187 | class Flatten_Module(nn.Module):
188 | def __init__(self, in_ch, out_ch, isLocal=True):
189 | super(Flatten_Module, self).__init__()
190 | layers = []
191 | layers.append(Conv(in_ch, out_ch, K=5, S=2, P=2, activation=nn.LeakyReLU()))
192 | curr_dim = out_ch
193 |
194 | for i in range(2):
195 | layers.append(Conv(curr_dim, curr_dim*2, K=5, S=2, P=2, activation=nn.LeakyReLU()))
196 | curr_dim *= 2
197 |
198 | if isLocal:
199 | layers.append(Conv(curr_dim, curr_dim*2, K=5, S=2, P=2, activation=nn.LeakyReLU()))
200 | else:
201 | layers.append(Conv(curr_dim, curr_dim, K=5, S=2, P=2, activation=nn.LeakyReLU()))
202 |
203 | self.out = nn.Sequential(*layers)
204 |
205 | def forward(self, x):
206 | x = self.out(x)
207 | return x.view(x.size(0),-1) # 2B x 256*(256 or 512); front 256:16*16
208 |
209 |
210 | class Contextual_Attention_Module(nn.Module):
211 | def __init__(self, in_ch, out_ch, rate=2, stride=1, isCheck=False, device=None):
212 | super(Contextual_Attention_Module, self).__init__()
213 | self.rate = rate
214 | self.padding = nn.ZeroPad2d(1)
215 | self.up_sample = nn.Upsample(scale_factor=self.rate, mode='nearest')
216 | layers = []
217 | for i in range(2):
218 | layers.append(Conv(in_ch, out_ch))
219 | self.out = nn.Sequential(*layers)
220 | self.isCheck = isCheck
221 | self.device = device
222 |
223 | def forward(self, f, b, mask=None, ksize=3, stride=1,
224 | fuse_k=3, softmax_scale=10., training=True, fuse=True):
225 |
226 | """ Contextual attention layer implementation.
227 |
228 | Contextual attention is first introduced in publication:
229 | Generative Image Inpainting with Contextual Attention, Yu et al.
230 |
231 | Args:
232 | f: Input feature to match (foreground).
233 | b: Input feature to match against (background).
234 | mask: Input mask for b, indicating patches not available.
235 | ksize: Kernel size for contextual attention.
236 | stride: Stride for extracting patches from b.
237 | rate: Dilation for matching.
238 | softmax_scale: Scaled softmax for attention.
239 | training: Indicating if current graph is training or inference.
240 |
241 | Returns:
242 | torch.Tensor: output
243 |
244 | """
245 |
246 | # get shapes
247 | raw_fs = f.size() # B x 128 x 64 x 64
248 | raw_int_fs = list(f.size())
249 | raw_int_bs = list(b.size())
250 |
251 | # extract patches from background with stride and rate
252 | kernel = 2*self.rate
253 | raw_w = self.extract_patches(b, kernel=kernel, stride=self.rate)
254 | raw_w = raw_w.permute(0, 2, 3, 4, 5, 1)
255 | raw_w = raw_w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1)
256 | raw_w = raw_w.contiguous().view(raw_int_bs[0], -1, kernel, kernel, raw_int_bs[1])
257 | raw_w = raw_w.permute(0, 1, 4, 2, 3)
258 |
259 | f = down_sample(f, scale_factor=1/self.rate, mode='nearest', device=self.device)
260 | b = down_sample(b, scale_factor=1/self.rate, mode='nearest', device=self.device)
261 |
262 | fs = f.size() # B x 128 x 32 x 32
263 | int_fs = list(f.size())
264 | f_groups = torch.split(f, 1, dim=0) # Split tensors by batch dimension; tuple is returned
265 |
266 | # from b(B*H*W*C) to w(b*k*k*c*h*w)
267 | bs = b.size() # B x 128 x 32 x 32
268 | int_bs = list(b.size())
269 | w = self.extract_patches(b)
270 | w = w.permute(0, 2, 3, 4, 5, 1)
271 | w = w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1)
272 | w = w.contiguous().view(raw_int_bs[0], -1, ksize, ksize, raw_int_bs[1])
273 | w = w.permute(0, 1, 4, 2, 3)
274 | # process mask
275 | if mask is not None:
276 | mask = mask.clone()  # clone only when a mask is actually provided
277 | if mask.size(2) != b.size(2):
278 | mask = down_sample(mask, scale_factor=1./self.rate, mode='nearest', device=self.device)
279 | else:
280 | mask = torch.zeros([1, 1, bs[2], bs[3]])
281 |
282 | m = self.extract_patches(mask)
283 |
284 | m = m.permute(0, 2, 3, 4, 5, 1)
285 | m = m.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1)
286 | m = m.contiguous().view(raw_int_bs[0], -1, ksize, ksize, 1)
287 | m = m.permute(0, 4, 1, 2, 3)
288 |
289 | m = m[0] # (1, 32*32, 3, 3)
290 | m = reduce_mean(m) # smoothing, maybe
291 | mm = m.eq(0.).float() # (1, 32*32, 1, 1)
292 |
293 | w_groups = torch.split(w, 1, dim=0) # Split tensors by batch dimension; tuple is returned
294 | raw_w_groups = torch.split(raw_w, 1, dim=0) # Split tensors by batch dimension; tuple is returned
295 | y = []
296 | offsets = []
297 | k = fuse_k
298 | scale = softmax_scale
299 | fuse_weight = Variable(torch.eye(k).view(1, 1, k, k)).cuda(self.device) # 1 x 1 x K x K
300 | y_test = []
301 | for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
302 | '''
303 | O => output channel as a conv filter
304 | I => input channel as a conv filter
305 | xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
306 | wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
307 | raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
308 | '''
309 | # conv for compare
310 | wi = wi[0]
311 | escape_NaN = Variable(torch.FloatTensor([1e-4])).cuda(self.device)
312 | wi_normed = wi / torch.max(l2_norm(wi), escape_NaN)
313 | yi = F.conv2d(xi, wi_normed, stride=1, padding=1) # yi => (B=1, C=32*32, H=32, W=32)
314 | y_test.append(yi)
315 | # conv implementation for fuse scores to encourage large patches
316 | if fuse:
317 | yi = yi.permute(0, 2, 3, 1)
318 | yi = yi.contiguous().view(1, fs[2] * fs[3], bs[2] * bs[3], 1)
319 | yi = yi.permute(0, 3, 1, 2) # make all of depth to spatial resolution, (B=1, I=1, H=32*32, W=32*32)
320 | yi = F.conv2d(yi, fuse_weight, stride=1, padding=1) # (B=1, C=1, H=32*32, W=32*32)
321 |
322 | yi = yi.permute(0, 2, 3, 1)
323 | yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3])
324 | # yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3]) # (B=1, 32, 32, 32, 32)
325 | yi = yi.permute(0, 2, 1, 4, 3)
326 | yi = yi.contiguous().view(1, fs[2] * fs[3], bs[2] * bs[3], 1)
327 | yi = yi.permute(0, 3, 1, 2)
328 |
329 | yi = F.conv2d(yi, fuse_weight, stride=1, padding=1)
330 | yi = yi.permute(0, 2, 3, 1)
331 | yi = yi.contiguous().view(1, fs[3], fs[2], bs[3], bs[2])
332 | yi = yi.permute(0, 2, 1, 4, 3)
333 | yi = yi.contiguous().view(1, fs[2], fs[3], bs[2] * bs[3])
334 | yi = yi.permute(0, 3, 1, 2)
335 | else:
336 | yi = yi.permute(0, 2, 3, 1)
337 | yi = yi.contiguous().view(1, fs[2], fs[3], bs[2] * bs[3])
338 | yi = yi.permute(0, 3, 1, 2) # (B=1, C=32*32, H=32, W=32)
339 | # yi = yi.contiguous().view(1, bs[2] * bs[3], fs[2], fs[3])
340 |
341 | # softmax to match
342 | yi = yi * mm # mm => (1, 32*32, 1, 1)
343 | yi = F.softmax(yi*scale, dim=1)
344 | yi = yi * mm # mask
345 |
346 | _, offset = torch.max(yi, dim=1) # argmax; index
347 | division = torch.true_divide(offset, fs[3]).long()
348 |             offset = torch.stack([division, offset - division * fs[3]], dim=-1)  # (row, col) index of the attended patch
349 |
350 | wi_center = raw_wi[0]
351 |
352 | yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64)
353 | y.append(yi)
354 | offsets.append(offset)
355 |
356 | y = torch.cat(y, dim=0) # back to the mini-batch
357 |         y = y.contiguous().view(raw_int_fs)
358 | # wi_patched = y
359 | offsets = torch.cat(offsets, dim=0)
360 | offsets = offsets.view([int_bs[0]] + [2] + int_bs[2:])
361 |
362 | # case1: visualize optical flow: minus current position
363 | h_add = Variable(torch.arange(0,float(bs[2]))).cuda(self.device).view([1, 1, bs[2], 1])
364 | h_add = h_add.expand(bs[0], 1, bs[2], bs[3])
365 | w_add = Variable(torch.arange(0,float(bs[3]))).cuda(self.device).view([1, 1, 1, bs[3]])
366 | w_add = w_add.expand(bs[0], 1, bs[2], bs[3])
367 |
368 | offsets = offsets - torch.cat([h_add, w_add], dim=1).long()
369 |
370 | # # case2: visualize which pixels are attended
371 | # flow = torch.from_numpy(highlight_flow((offsets * mask.int()).numpy()))
372 | y = self.out(y)
373 |
374 | return y, offsets
375 |
376 | def extract_patches(self, x, kernel=3, stride=1):
377 | x = self.padding(x)
378 | all_patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
379 |
380 | return all_patches
381 |
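# Note: x.unfold(2, kernel, stride).unfold(3, kernel, stride) above returns a view of
# shape (B, C, H', W', kernel, kernel); the caller reorders and flattens it into
# per-location convolution filters.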
382 |
383 | def reduce_mean(x):
384 | for i in range(4):
385 | if i==1: continue
386 | x = torch.mean(x, dim=i, keepdim=True)
387 | return x
388 |
389 |
390 | def l2_norm(x):
391 | def reduce_sum(x):
392 | for i in range(4):
393 | if i==0: continue
394 | x = torch.sum(x, dim=i, keepdim=True)
395 | return x
396 |
397 | x = x**2
398 | x = reduce_sum(x)
399 | return torch.sqrt(x)
400 |
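# l2_norm(x) above sums over dims 1-3 with keepdim, i.e. it returns the per-filter
# L2 norm of shape (O, 1, 1, 1); it is used to normalize the matching kernels
# before the correlation conv in the contextual attention layer.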
401 |
402 | def down_sample(x, size=None, scale_factor=None, mode='nearest', device=None):
403 | # define size if user has specified scale_factor
404 | if size is None: size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3)))
405 | # create coordinates
406 | # size_origin = [x.size[2], x.size[3]]
407 | h = torch.true_divide(torch.arange(0, size[0]), (size[0])) * 2 - 1
408 | w = torch.true_divide(torch.arange(0, size[1]), (size[1])) * 2 - 1
409 | # create grid
410 | grid = torch.zeros(size[0],size[1],2)
411 | grid[:,:,0] = w.unsqueeze(0).repeat(size[0],1)
412 | grid[:,:,1] = h.unsqueeze(0).repeat(size[1],1).transpose(0,1)
413 | # expand to match batch size
414 | grid = grid.unsqueeze(0).repeat(x.size(0),1,1,1)
415 | if x.is_cuda:
416 | if device:
417 | grid = Variable(grid).cuda(device)
418 | else:
419 | grid = Variable(grid).cuda()
420 | # do sampling
421 |
422 | return F.grid_sample(x, grid, mode=mode)
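# Sketch (assumption, not from the original code): for a 4-D tensor x,
# down_sample(x, scale_factor=0.5, mode='nearest') behaves roughly like
# F.interpolate(x, scale_factor=0.5, mode='nearest'), but builds an explicit
# normalized sampling grid in [-1, 1] and resamples with F.grid_sample.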
423 |
424 |
425 | def to_var(x, volatile=False, device=None):
426 | if torch.cuda.is_available():
427 | if device:
428 | x = x.cuda(device)
429 | else:
430 | x = x.cuda()
431 | return Variable(x, volatile=volatile)
432 |
--------------------------------------------------------------------------------
/tool/video_completion_modified.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
4 |
5 | import argparse
7 | import cv2
8 | import glob
9 | import copy
10 | import numpy as np
11 | import cupy as cp
12 | import torch
13 | import imageio
14 | from PIL import Image
15 | import scipy.ndimage
16 | from skimage.feature import canny
17 | from skimage.transform import integral
18 | import torchvision.transforms.functional as F
19 |
20 | from RAFT import utils
21 | from RAFT import RAFT
22 |
23 | import utils.region_fill as rf
24 | from utils.Poisson_blend import Poisson_blend
25 | from utils.Poisson_blend_img import Poisson_blend_img
26 | from get_flowNN import get_flowNN
27 | from get_flowNN_gradient import get_flowNN_gradient
28 | from utils.common_utils import flow_edge
29 | from spatial_inpaint import spatial_inpaint
30 | from frame_inpaint import DeepFillv1
31 | from edgeconnect.networks import EdgeGenerator_
32 |
33 | import time
34 |
35 | def find_minbbox(masks):
36 |     # find the minimum bounding box of the hole mask
37 | minbbox_tl=[] # top left point of the minimum bounding box
38 | minbbox_br=[] # bottom right point of the minimum bounding box
39 |     for i in range(len(masks)):
40 | non_zeros=cv2.findNonZero(np.array(masks[i]*255))
41 | min_rect=cv2.boundingRect(non_zeros)
42 |
43 |         # expand by 10 pixels on each side
44 | x1=max(0,min_rect[0]-10)
45 | y1=max(0,min_rect[1]-10)
46 | x2=min(masks[i].shape[1],min_rect[0]+min_rect[2]+10)
47 | y2=min(masks[i].shape[0],min_rect[1]+min_rect[3]+10)
48 |
49 | minbbox_tl.append([x1,y1])
50 | minbbox_br.append([x2,y2])
51 | return minbbox_tl,minbbox_br
52 |
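# Note (summary of find_minbbox above): the tight bounding box of each mask is padded
# by 10 px on every side and clipped to the frame, so the later flow completion and
# Poisson blending only need to operate on this crop.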
53 |
54 | def to_tensor(img):
55 | img = Image.fromarray(img)
56 | img_t = F.to_tensor(img).float()
57 | return img_t
58 |
59 |
60 | def infer(args, EdgeGenerator, device, flow_img_gray, edge, mask):
61 |
62 | # Add a pytorch dataloader
63 | flow_img_gray_tensor = to_tensor(flow_img_gray)[None, :, :].float().to(device)
64 | edge_tensor = to_tensor(edge)[None, :, :].float().to(device)
65 | mask_tensor = torch.from_numpy(mask.astype(np.float64))[None, None, :, :].float().to(device)
66 |
67 | # Complete the edges
68 | edges_masked = (edge_tensor * (1 - mask_tensor))
69 | images_masked = (flow_img_gray_tensor * (1 - mask_tensor)) + mask_tensor
70 | inputs = torch.cat((images_masked, edges_masked, mask_tensor), dim=1)
71 | with torch.no_grad():
72 | edges_completed = EdgeGenerator(inputs) # in: [grayscale(1) + edge(1) + mask(1)]
73 | edges_completed = edges_completed * mask_tensor + edge_tensor * (1 - mask_tensor)
74 | edge_completed = edges_completed[0, 0].data.cpu().numpy()
75 | edge_completed[edge_completed < 0.5] = 0
76 | edge_completed[edge_completed >= 0.5] = 1
77 |
78 | return edge_completed
79 |
80 |
81 | def gradient_mask(mask):
82 |
83 | gradient_mask = np.logical_or.reduce((mask,
84 | np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), axis=0),
85 | np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), axis=1)))
86 |
87 | return gradient_mask
88 |
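# gradient_mask ORs the input mask with its one-pixel shifts along y and x, so a
# forward-difference gradient is marked missing whenever either endpoint lies in the hole.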
89 |
90 | def create_dir(dir):
91 | """Creates a directory if not exist.
92 | """
93 | if not os.path.exists(dir):
94 | os.makedirs(dir)
95 |
96 |
97 | def initialize_RAFT(args):
98 | """Initializes the RAFT model.
99 | """
100 | model = torch.nn.DataParallel(RAFT(args))
101 | model.load_state_dict(torch.load(args.model))
102 |
103 | model = model.module
104 | model.to('cuda')
105 | model.eval()
106 |
107 | return model
108 |
109 |
110 | def calculate_flow(args, model, video, mode):
111 | """Calculates optical flow.
112 | """
113 | if mode not in ['forward', 'backward']:
114 | raise NotImplementedError
115 |
116 | nFrame, _, imgH, imgW = video.shape
117 | Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)
118 |
119 | # if os.path.isdir(os.path.join(args.outroot, 'flow', mode + '_flo')):
120 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow', mode + '_flo', '*.flo'))):
121 | # print("Loading {0}".format(flow_name), '\r', end='')
122 | # flow = utils.frame_utils.readFlow(flow_name)
123 | # Flow = np.concatenate((Flow, flow[..., None]), axis=-1)
124 | # return Flow
125 |
126 | # create_dir(os.path.join(args.outroot, 'flow', mode + '_flo'))
127 | # create_dir(os.path.join(args.outroot, 'flow', mode + '_png'))
128 |
129 | with torch.no_grad():
130 | for i in range(video.shape[0] - 1):
131 | print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
132 | if mode == 'forward':
133 | # Flow i -> i + 1
134 | image1 = video[i, None]
135 | image2 = video[i + 1, None]
136 | elif mode == 'backward':
137 | # Flow i + 1 -> i
138 | image1 = video[i + 1, None]
139 | image2 = video[i, None]
140 | else:
141 | raise NotImplementedError
142 |
143 | _, flow = model(image1, image2, iters=20, test_mode=True)
144 | flow = flow[0].permute(1, 2, 0).cpu().numpy()
145 | Flow = np.concatenate((Flow, flow[..., None]), axis=-1)
146 |
147 | # # Flow visualization.
148 | # flow_img = utils.flow_viz.flow_to_image(flow)
149 | # flow_img = Image.fromarray(flow_img)
150 |
151 | # # Saves the flow and flow_img.
152 | # flow_img.save(os.path.join(args.outroot, 'flow', mode + '_png', '%05d.png'%i))
153 | # utils.frame_utils.writeFlow(os.path.join(args.outroot, 'flow', mode + '_flo', '%05d.flo'%i), flow)
154 |
155 | return Flow
156 |
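# The returned Flow array has shape (imgH, imgW, 2, nFrame - 1): one (u, v) field per
# consecutive frame pair, stacked along the last axis.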
157 |
158 |
159 | def complete_flow(args, corrFlow, flow_mask, mode, minbbox_tl, minbbox_br, edge=None):
160 | """Completes flow.
161 | """
162 | if mode not in ['forward', 'backward']:
163 | raise NotImplementedError
164 |
165 | imgH, imgW, _, nFrame = corrFlow.shape
166 |
167 | # if os.path.isdir(os.path.join(args.outroot, 'flow_comp', mode + '_flo')):
168 | # compFlow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)
169 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow_comp', mode + '_flo', '*.flo'))):
170 | # print("Loading {0}".format(flow_name), '\r', end='')
171 | # flow = utils.frame_utils.readFlow(flow_name)
172 | # compFlow = np.concatenate((compFlow, flow[..., None]), axis=-1)
173 | # return compFlow
174 |
175 | # create_dir(os.path.join(args.outroot, 'flow_comp', mode + '_flo'))
176 | # create_dir(os.path.join(args.outroot, 'flow_comp', mode + '_png'))
177 |
178 | compFlow=corrFlow.copy()
179 | for i in range(nFrame):
180 | print("Completing {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
181 | flow = corrFlow[:, :, :, i]
182 | if mode == 'forward':
183 | flow_crop=flow[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0],:]
184 | flow_mask_img = flow_mask[:, :, i]
185 | flow_mask_img_crop=flow_mask_img[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0]]
186 | else:
187 | flow_crop=flow[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0],:]
188 | flow_mask_img = flow_mask[:, :, i+1]
189 | flow_mask_img_crop=flow_mask_img[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0]]
190 |
191 | # cv2.imwrite("./flow_mask_img_crop.png",flow_mask_img_crop*255)
192 | flow_mask_gradient_img = gradient_mask(flow_mask_img)
193 | flow_mask_gradient_img_crop = gradient_mask(flow_mask_img_crop)
194 |
195 | if edge is not None:
196 | # imgH x (imgW - 1 + 1) x 2
197 | gradient_x = np.concatenate((np.diff(flow_crop, axis=1), np.zeros((flow_crop.shape[0], 1, 2), dtype=np.float32)), axis=1)
198 | # (imgH - 1 + 1) x imgW x 2
199 | gradient_y = np.concatenate((np.diff(flow_crop, axis=0), np.zeros((1, flow_crop.shape[1], 2), dtype=np.float32)), axis=0)
200 |
201 | # concatenate gradient_x and gradient_y
202 | gradient = np.concatenate((gradient_x, gradient_y), axis=2)
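            # Channel layout of `gradient` (from the two concatenations above):
            # [:, :, 0:2] are the x-gradients of (u, v); [:, :, 2:4] are the y-gradients,
            # which is why imgSrc_gx / imgSrc_gy below slice channels 0:2 and 2:4.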
203 |
204 |             # We can trust the gradient outside of flow_mask_gradient_img_crop;
205 |             # inside it, we assume the gradient is 0.
206 | gradient[flow_mask_gradient_img_crop, :] = 0
207 |
208 | # Complete the flow
209 | imgSrc_gy = gradient[:, :, 2 : 4]
210 | imgSrc_gy = imgSrc_gy[0 : flow_crop.shape[0] - 1, :, :]
211 | imgSrc_gx = gradient[:, :, 0 : 2]
212 | imgSrc_gx = imgSrc_gx[:, 0 : flow_crop.shape[1] - 1, :]
213 | if mode == 'forward':
214 | edge_crop=edge[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0],i]
215 | else:
216 | edge_crop=edge[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0],i]
217 | compFlow_crop = Poisson_blend(flow_crop, imgSrc_gx, imgSrc_gy, flow_mask_img_crop, edge_crop)
218 |
219 | #return original size
220 | if mode == 'forward':
221 | compFlow[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0], :, i] = compFlow_crop
222 | else:
223 | compFlow[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0], :, i] = compFlow_crop
224 |
225 | else:
226 | flow[:, :, 0] = rf.regionfill(flow[:, :, 0], flow_mask_img)
227 | flow[:, :, 1] = rf.regionfill(flow[:, :, 1], flow_mask_img)
228 | compFlow[:, :, :, i] = flow
229 |
230 | ## Flow visualization.
231 | # flow_img = utils.flow_viz.flow_to_image(compFlow[:, :, :, i])
232 | # flow_img = Image.fromarray(flow_img)
233 |
234 | ## Saves the flow and flow_img.
235 | # flow_img.save(os.path.join(args.outroot, 'flow_comp', mode + '_png', '%05d.png'%i))
236 | # utils.frame_utils.writeFlow(os.path.join(args.outroot, 'flow_comp', mode + '_flo', '%05d.flo'%i), compFlow[:, :, :, i])
237 |
238 | return compFlow
239 |
240 |
241 | def edge_completion(args, EdgeGenerator, corrFlow, flow_mask, mode):
242 | """Calculate flow edge and complete it.
243 | """
244 |
245 | if mode not in ['forward', 'backward']:
246 | raise NotImplementedError
247 |
248 | imgH, imgW, _, nFrame = corrFlow.shape
249 | Edge = np.empty(((imgH, imgW, 0)), dtype=np.float32)
250 |
251 | for i in range(nFrame):
252 | print("Completing {0} flow edge {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
253 | flow_mask_img = flow_mask[:, :, i] if mode == 'forward' else flow_mask[:, :, i + 1]
254 |
255 | flow_img_gray = (corrFlow[:, :, 0, i] ** 2 + corrFlow[:, :, 1, i] ** 2) ** 0.5
256 | flow_img_gray = flow_img_gray / flow_img_gray.max()
257 |
258 | edge_corr = canny(flow_img_gray, sigma=2, mask=(1 - flow_mask_img).astype(np.bool))
259 | edge_completed = infer(args, EdgeGenerator, torch.device('cuda:1'), flow_img_gray, edge_corr, flow_mask_img)
260 | Edge = np.concatenate((Edge, edge_completed[..., None]), axis=-1)
261 |
262 | return Edge
263 |
264 |
265 | def video_completion_seamless(args):
266 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
267 |
268 | # Flow model.
269 | RAFT_model = initialize_RAFT(args)
270 |
271 | # Loads frames.
272 | filename_list = glob.glob(os.path.join(args.path, '*.png')) + \
273 | glob.glob(os.path.join(args.path, '*.jpg'))
274 |
275 | # Obtains imgH, imgW and nFrame.
276 | imgH, imgW = np.array(Image.open(filename_list[0])).shape[:2]
277 | nFrame = len(filename_list)
278 |
279 | # Loads video.
280 | video = []
281 | for filename in sorted(filename_list):
282 | video.append(torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float())
283 | video = torch.stack(video, dim=0)
284 | video = video.to('cuda')
285 |
286 | # Loads masks.
287 | filename_list = glob.glob(os.path.join(args.path_mask, '*.png')) + \
288 | glob.glob(os.path.join(args.path_mask, '*.jpg'))
289 | mask = []
290 | mask_dilated = []
291 | flow_mask = []
292 | for filename in sorted(filename_list):
293 | mask_img = np.array(Image.open(filename).convert('L'))
294 | flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=3)
295 | # Close the small holes inside the foreground objects
296 | flow_mask_img = cv2.morphologyEx(flow_mask_img.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((11, 11),np.uint8)).astype(np.bool)
297 | flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.bool)
298 | flow_mask.append(flow_mask_img)
299 |
300 | # Dilate a little bit
301 | mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=3)
302 | mask_img = scipy.ndimage.binary_fill_holes(mask_img).astype(np.bool)
303 | mask.append(mask_img)
304 | mask_dilated.append(gradient_mask(mask_img))
305 |
306 | minbbox_tl,minbbox_br=find_minbbox(flow_mask)
307 |
308 | # Image inpainting model.
309 | deepfill = DeepFillv1(pretrained_model=args.deepfill_model, image_shape=[imgH, imgW])
310 |
311 | # timer
312 | time_start=time.time()
313 |
314 |     # Calculates the corrupted flow.
315 | corrFlowF = calculate_flow(args, RAFT_model, video, 'forward')
316 | corrFlowB = calculate_flow(args, RAFT_model, video, 'backward')
317 | print('\nFinish flow prediction.')
318 |
319 | # Makes sure video is in BGR (opencv) format.
320 | video = video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255.
321 |
322 | # mask indicating the missing region in the video.
323 | mask = np.stack(mask, -1).astype(np.bool)
324 | mask_dilated = np.stack(mask_dilated, -1).astype(np.bool)
325 | flow_mask = np.stack(flow_mask, -1).astype(np.bool)
326 |
327 | if args.edge_guide:
328 | # Edge completion model.
329 | EdgeGenerator = EdgeGenerator_()
330 | EdgeComp_ckpt = torch.load(args.edge_completion_model)
331 | EdgeGenerator.load_state_dict(EdgeComp_ckpt['generator'])
332 | EdgeGenerator.to(torch.device('cuda:1'))
333 | EdgeGenerator.eval()
334 |
335 | # Edge completion.
336 | FlowF_edge = edge_completion(args, EdgeGenerator, corrFlowF, flow_mask, 'forward')
337 | FlowB_edge = edge_completion(args, EdgeGenerator, corrFlowB, flow_mask, 'backward')
338 | print('\nFinish edge completion.')
339 | else:
340 | FlowF_edge, FlowB_edge = None, None
341 |
342 |
343 | # Completes the flow.
344 | videoFlowF = complete_flow(args, corrFlowF, flow_mask, 'forward', minbbox_tl, minbbox_br, FlowF_edge)
345 | videoFlowB = complete_flow(args, corrFlowB, flow_mask, 'backward', minbbox_tl, minbbox_br, FlowB_edge)
346 | print('\nFinish flow completion.')
347 |
348 | # Prepare gradients
349 | gradient_x = np.empty(((imgH, imgW, 3, 0)), dtype=np.float32)
350 | gradient_y = np.empty(((imgH, imgW, 3, 0)), dtype=np.float32)
351 |
352 | for indFrame in range(nFrame):
353 | img = video[:, :, :, indFrame]
354 | img[mask[:, :, indFrame], :] = 0
355 |
356 | img = cv2.inpaint((img * 255).astype(np.uint8), mask[:, :, indFrame].astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255.
357 |
358 | gradient_x_ = np.concatenate((np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1)
359 | gradient_y_ = np.concatenate((np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0)
360 | gradient_x = np.concatenate((gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1)
361 | gradient_y = np.concatenate((gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1)
362 |
363 | gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0
364 | gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0
365 |
366 |
367 | iter = 0
368 | mask_tofill = mask
369 | gradient_x_filled = gradient_x # corrupted gradient_x, mask_gradient indicates the missing gradient region
370 | gradient_y_filled = gradient_y # corrupted gradient_y, mask_gradient indicates the missing gradient region
371 | mask_gradient = mask_dilated
372 | video_comp = video
373 |
374 | # We iteratively complete the video.
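    # Each iteration: (1) propagate known gradients along the completed forward/backward
    # flow (get_flowNN_gradient); (2) Poisson-blend the propagated gradients back to RGB
    # inside each frame's bounding box; (3) inpaint any still-unfilled pixels of a key
    # frame with DeepFill (spatial_inpaint); then recompute gradients and repeat until
    # no masked pixels remain.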
375 |     while np.sum(mask) > 0:
376 | # create_dir(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter)))
377 |
378 | # Gradient propagation.
379 | gradient_x_filled, gradient_y_filled, mask_gradient = \
380 | get_flowNN_gradient(args,
381 | gradient_x_filled,
382 | gradient_y_filled,
383 | mask,
384 | mask_gradient,
385 | videoFlowF,
386 | videoFlowB,
387 | None,
388 | None)
389 |
390 |         # If holes remain inside mask_gradient, Poisson blending will fail, so we fill them here at the cost of some accuracy; an alternative would be to modify the Poisson blending itself.
391 | for indFrame in range(nFrame):
392 | mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(mask_gradient[:, :, indFrame]).astype(np.bool)
393 |
394 | # After one gradient propagation iteration
395 | # gradient --> RGB
396 | for indFrame in range(nFrame):
397 | print("Poisson blending frame {0:3d}".format(indFrame))
398 |
399 | if mask[:, :, indFrame].sum() > 0:
400 | try:
401 | video_comp_crop=video_comp[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], :, indFrame]
402 | gradient_x_filled_crop=gradient_x_filled[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0]-1, :, indFrame]
403 | gradient_y_filled_crop=gradient_y_filled[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1]-1,minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], :, indFrame]
404 | mask_crop=mask[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], indFrame]
405 |                     mask_gradient_crop=mask_gradient[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], indFrame]
406 | frameBlend_crop, UnfilledMask_crop = Poisson_blend_img(video_comp_crop, gradient_x_filled_crop, gradient_y_filled_crop, mask_crop, mask_gradient_crop)
407 |
408 | frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame]
409 | frameBlend[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0],:]=frameBlend_crop
410 | UnfilledMask[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0]]=UnfilledMask_crop
411 | # UnfilledMask = scipy.ndimage.binary_fill_holes(UnfilledMask).astype(np.bool)
412 |
413 | # frameBlend, UnfilledMask = Poisson_blend_img(video_comp[:, :, :, indFrame], gradient_x_filled[:, 0 : imgW - 1, :, indFrame], gradient_y_filled[0 : imgH - 1, :, :, indFrame], mask[:, :, indFrame], mask_gradient[:, :, indFrame])
414 | # UnfilledMask = scipy.ndimage.binary_fill_holes(UnfilledMask).astype(np.bool)
415 |                 except Exception:
416 | frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame]
417 |
418 | frameBlend = np.clip(frameBlend, 0, 1.0)
419 | tmp = cv2.inpaint((frameBlend * 255).astype(np.uint8), UnfilledMask.astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255.
420 | frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :]
421 |
422 | video_comp[:, :, :, indFrame] = frameBlend
423 | mask[:, :, indFrame] = UnfilledMask
424 |
425 | frameBlend_ = copy.deepcopy(frameBlend)
426 | # Green indicates the regions that are not filled yet.
427 | frameBlend_[mask[:, :, indFrame], :] = [0, 1., 0]
428 | else:
429 | frameBlend_ = video_comp[:, :, :, indFrame]
430 |
431 | # cv2.imwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter), '%05d.png'%indFrame), frameBlend_ * 255.)
432 |
433 | # video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1]
434 | # imageio.mimwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter), 'intermediate_{0}.mp4'.format(str(iter))), video_comp_, fps=12, quality=8, macro_block_size=1)
435 | # imageio.mimsave(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter), 'intermediate_{0}.gif'.format(str(iter))), video_comp_, format='gif', fps=12)
436 |
437 | mask, video_comp = spatial_inpaint(deepfill, mask, video_comp)
438 | iter += 1
439 |
440 | # Re-calculate gradient_x/y_filled and mask_gradient
441 | for indFrame in range(nFrame):
442 | mask_gradient[:, :, indFrame] = gradient_mask(mask[:, :, indFrame])
443 |
444 | gradient_x_filled[:, :, :, indFrame] = np.concatenate((np.diff(video_comp[:, :, :, indFrame], axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1)
445 | gradient_y_filled[:, :, :, indFrame] = np.concatenate((np.diff(video_comp[:, :, :, indFrame], axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0)
446 |
447 | gradient_x_filled[mask_gradient[:, :, indFrame], :, indFrame] = 0
448 | gradient_y_filled[mask_gradient[:, :, indFrame], :, indFrame] = 0
449 |
450 | video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1]
451 |
452 | time_end=time.time()
453 | print('time cost',time_end-time_start,'s')
454 |
455 | # write out
456 | create_dir(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final'))
457 | for i in range(nFrame):
458 | img = video_comp[:, :, :, i] * 255
459 | cv2.imwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final', '%06d.png'%i), img)
460 | # imageio.mimwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final', 'final.mp4'), video_comp_, fps=12, quality=8, macro_block_size=1)
461 | # imageio.mimsave(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final', 'final.gif'), video_comp_, format='gif', fps=12)
462 |
463 |
464 | def main(args):
465 |
466 | assert args.mode in ('object_removal', 'video_extrapolation'), (
467 | "Accepted modes: 'object_removal', 'video_extrapolation', but input is %s"
468 |     ) % args.mode
469 |
470 | if args.seamless:
471 | video_completion_seamless(args)
472 | else:
473 |         video_completion(args)  # note: not defined in this modified script; run with --seamless
474 |
475 |
476 | if __name__ == '__main__':
477 | parser = argparse.ArgumentParser()
478 |
479 | # video completion
480 |     parser.add_argument('--seamless', action='store_true', help='Whether to operate in the gradient domain')
481 |     parser.add_argument('--edge_guide', action='store_true', help='Whether to use edges as guidance for flow completion')
482 | parser.add_argument('--mode', default='object_removal', help="modes: object_removal / video_extrapolation")
483 | parser.add_argument('--path', default='../data/frames_corr', help="dataset for evaluation")
484 | parser.add_argument('--path_mask', default='../data/masks', help="mask for object removal")
485 | parser.add_argument('--outroot', default='../result/', help="output directory")
486 | parser.add_argument('--consistencyThres', dest='consistencyThres', default=np.inf, type=float, help='flow consistency error threshold')
487 | parser.add_argument('--alpha', dest='alpha', default=0.1, type=float)
488 | parser.add_argument('--Nonlocal', dest='Nonlocal', default=False, type=bool)
489 |
490 | # RAFT
491 | parser.add_argument('--model', default='./weight/raft-things.pth', help="restore checkpoint")
492 | parser.add_argument('--small', action='store_true', help='use small model')
493 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
494 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
495 |
496 | # Deepfill
497 | parser.add_argument('--deepfill_model', default='./weight/imagenet_deepfill.pth', help="restore checkpoint")
498 |
499 | # Edge completion
500 | parser.add_argument('--edge_completion_model', default='./weight/edge_completion.pth', help="restore checkpoint")
501 |
502 | # extrapolation
503 | parser.add_argument('--H_scale', dest='H_scale', default=2, type=float, help='H extrapolation scale')
504 | parser.add_argument('--W_scale', dest='W_scale', default=2, type=float, help='W extrapolation scale')
505 |
506 | args = parser.parse_args()
507 |
508 | main(args)
509 |
510 |
--------------------------------------------------------------------------------