├── utils ├── __init__.py ├── region_fill.py ├── Poisson_blend.py └── Poisson_blend_img.py ├── edgeconnect ├── __init__.py ├── metrics.py ├── config.py ├── region_fill.py ├── utils.py ├── loss.py ├── networks.py ├── models.py └── dataset.py ├── models ├── DeepFill_Models │ ├── util.py │ ├── __init__.py │ ├── DeepFill.py │ └── ops.py └── __init__.py ├── .gitignore ├── RAFT ├── __init__.py ├── utils │ ├── __init__.py │ ├── utils.py │ ├── frame_utils.py │ ├── flow_viz.py │ └── augmentor.py ├── demo.py ├── corr.py ├── raft.py ├── update.py ├── extractor.py └── datasets.py ├── requirements.txt ├── download_data_weights.sh ├── tool ├── spatial_inpaint.py ├── frame_inpaint.py └── video_completion_modified.py ├── LICENSE ├── FAQ.md ├── evaluate.py └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /edgeconnect/__init__.py: -------------------------------------------------------------------------------- 1 | # empty -------------------------------------------------------------------------------- /models/DeepFill_Models/util.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/DeepFill_Models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | *.pyc 3 | data 4 | weight 5 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DeepFill_Models import DeepFill 2 | -------------------------------------------------------------------------------- /RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # from .demo import RAFT_infer 2 | from .raft import RAFT 3 | -------------------------------------------------------------------------------- /RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_viz import flow_to_image 2 | from .frame_utils import writeFlow 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | opencv-contrib-python 3 | imageio 4 | imageio-ffmpeg 5 | scipy 6 | scikit-image 7 | -------------------------------------------------------------------------------- /download_data_weights.sh: -------------------------------------------------------------------------------- 1 | 2 | wget https://filebox.ece.vt.edu/~chengao/FGVC/data.zip 3 | unzip data.zip 4 | rm data.zip 5 | 6 | wget https://filebox.ece.vt.edu/~chengao/FGVC/weight.zip 7 | unzip weight.zip 8 | rm weight.zip 9 | -------------------------------------------------------------------------------- /tool/spatial_inpaint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def spatial_inpaint(deepfill, mask, video_comp): 9 | 10 | keyFrameInd = 
np.argmax(np.sum(np.sum(mask, axis=0), axis=0)) 11 | with torch.no_grad(): 12 | img_res = deepfill.forward(video_comp[:, :, :, keyFrameInd] * 255., mask[:, :, keyFrameInd]) / 255. 13 | video_comp[mask[:, :, keyFrameInd], :, keyFrameInd] = img_res[mask[:, :, keyFrameInd], :] 14 | mask[:, :, keyFrameInd] = False 15 | 16 | return mask, video_comp 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2020 Virginia Tech Vision and Learning Lab 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | --------------------------- LICENSE FOR EdgeConnect -------------------------------- 25 | 26 | Attribution-NonCommercial 4.0 International 27 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | Q: val集里的video_0046里有全白的mask?
2 | A: In the last few frames of video_0046 in the val set, the target moves out of frame, so the mask was mistakenly generated as all white. You can simply ignore or delete those frames.
3 |
4 | Q: Is online updating allowed?
5 | A: Yes, it is allowed.
6 |
7 | Q: Can you provide masks?
8 | A: No masks are provided for the training set; during training you can carve out random masks yourself, an endless supply of them (see the sketch below).
9 |
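The following is a minimal, purely illustrative sketch of such on-the-fly mask generation; the `make_random_mask` helper, its shape mix and its size ranges are assumptions, not part of this repository. Each training frame could then be hollowed out with `frame * (1 - mask[..., None])` before being fed to the inpainting network.

```python
import numpy as np
import cv2


def make_random_mask(height, width, max_rects=3, max_strokes=5):
    """Hypothetical helper: carve a random hole mask (1 = hole, 0 = keep) for training."""
    mask = np.zeros((height, width), dtype=np.uint8)
    # a few random rectangles
    for _ in range(np.random.randint(1, max_rects + 1)):
        h = np.random.randint(height // 8, height // 3)
        w = np.random.randint(width // 8, width // 3)
        y = np.random.randint(0, height - h)
        x = np.random.randint(0, width - w)
        mask[y:y + h, x:x + w] = 1
    # a few random free-form strokes
    for _ in range(np.random.randint(1, max_strokes + 1)):
        p1 = (int(np.random.randint(0, width)), int(np.random.randint(0, height)))
        p2 = (int(np.random.randint(0, width)), int(np.random.randint(0, height)))
        cv2.line(mask, p1, p2, color=1, thickness=int(np.random.randint(5, 20)))
    return mask
```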
10 | Q: Does the 8 GB memory/GPU-memory limit apply to training, or only to inference?
11 | A: The 8 GB limit refers to inference. There is no requirement for training, as long as it is not so large that the results cannot be reproduced.
12 |
13 | Q: Can the test-set masks also be used for training?
14 | A: Yes, the test-set masks may be used; val and test_a share the same set of watermark and dancer templates, so there is no real difference. However, the test-set videos must not be used for training.
15 |
16 | Q: The first few frames of video 0095 in test_a are completely black. Is that normal?
17 | A: That is fine; the original frames are black as well, so it has no effect.
18 |
19 | Q: Can I use the main thread to measure memory and running time in real time, and a child thread to do the actual computation?
20 | A: Yes. Running the actual computation in a separate thread is fine (see the sketch below).
21 |
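For illustration only, a minimal sketch of that pattern, assuming PyTorch on a single GPU; `run_inference` stands in for the actual processing and is not a function from this repository.

```python
import threading
import time

import torch


def run_with_monitoring(run_inference, interval=0.5):
    """Run the real computation in a worker thread while the main thread tracks time and memory."""
    worker = threading.Thread(target=run_inference)
    start = time.time()
    worker.start()
    while worker.is_alive():
        # torch.cuda.max_memory_allocated() reports the peak GPU allocation since program start
        peak_mib = torch.cuda.max_memory_allocated() / 1024 ** 2
        print(f"elapsed {time.time() - start:6.1f} s, peak GPU memory {peak_mib:8.1f} MiB")
        time.sleep(interval)
    worker.join()
    print(f"total runtime: {time.time() - start:.1f} s")
```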
22 | Q: How small does the reproduction error have to be?
23 | A: We will reproduce the results from the submitted code. If the reproduction error is below 0.5%, the ranking stays unchanged; otherwise the ranking will be corrected. If the error is too large and the team cannot give a reasonable explanation, the score will be cancelled.
24 |
25 | Q: Can I just tell you directly which method I use for which video ids?
26 | A: No. You may not hard-code a special method for specific ids, but you may use an algorithm to decide which method to apply.
27 |
28 | Q: I trained several different models with the same method. Can I pick the best combination of their results? For example, one model works better on one type of mask on the validation set and another works better on a different type; may I submit a combination of the two?
29 | A: Same as above. You may not hard-code a special model for specific ids, but you may use an algorithm to decide which model to apply (see the sketch below).
30 |
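As an illustration of the distinction, a minimal sketch of an allowed, mask-driven selection rule; the two model handles and the threshold are hypothetical assumptions, not part of this repository.

```python
import numpy as np


def pick_model(mask, model_a, model_b, area_threshold=0.05):
    """Allowed: choose a model from a property of the mask itself, not from the video id."""
    hole_ratio = float(np.count_nonzero(mask)) / mask.size
    # e.g. small scattered holes -> watermark-style model, large holes -> dancer-template model
    return model_a if hole_ratio < area_threshold else model_b
```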
31 | Q: Does "single GPU" mean running on the server or locally? A V100 and a 3090 may run different models at different speeds.
32 | A: On the server, with either a V100 or a 3090. The 15-hour limit is deliberately generous, so contestants should not plan to use all of it. We do not compare speed either; staying within the limit is enough.
33 |
34 | Q: Does 8G mean 8000M or 8192M? Does the 8 GB memory limit mean it must never be exceeded at any single moment during inference?
35 | A: 8192M. During the test phase we will provide contestants with a server that has slightly more than 8 GB.
36 |
37 | Q: Do the pretrained models have to be from before June? Is there a GitHub-star requirement?
38 | A: Pretrained models must have been publicly released before June 1, 2021; there is no star requirement.
39 |
40 | 41 | 42 | -------------------------------------------------------------------------------- /edgeconnect/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EdgeAccuracy(nn.Module): 6 | """ 7 | Measures the accuracy of the edge map 8 | """ 9 | def __init__(self, threshold=0.5): 10 | super(EdgeAccuracy, self).__init__() 11 | self.threshold = threshold 12 | 13 | def __call__(self, inputs, outputs): 14 | labels = (inputs > self.threshold) 15 | outputs = (outputs > self.threshold) 16 | 17 | relevant = torch.sum(labels.float()) 18 | selected = torch.sum(outputs.float()) 19 | 20 | if relevant == 0 and selected == 0: 21 | return torch.tensor(1), torch.tensor(1) 22 | 23 | true_positive = ((outputs == labels) * labels).float() 24 | recall = torch.sum(true_positive) / (relevant + 1e-8) 25 | precision = torch.sum(true_positive) / (selected + 1e-8) 26 | 27 | return precision, recall 28 | 29 | 30 | class PSNR(nn.Module): 31 | def __init__(self, max_val): 32 | super(PSNR, self).__init__() 33 | 34 | base10 = torch.log(torch.tensor(10.0)) 35 | max_val = torch.tensor(max_val).float() 36 | 37 | self.register_buffer('base10', base10) 38 | self.register_buffer('max_val', 20 * torch.log(max_val) / base10) 39 | 40 | def __call__(self, a, b): 41 | mse = torch.mean((a.float() - b.float()) ** 2) 42 | 43 | if mse == 0: 44 | return torch.tensor(0) 45 | 46 | return self.max_val - 10 * torch.log(mse) / self.base10 47 | -------------------------------------------------------------------------------- /RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import cv2 5 | import glob 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | from .raft import RAFT 11 | from .utils import flow_viz 12 | from .utils.utils import InputPadder 13 | 14 | 15 | 16 | DEVICE = 'cuda' 17 | 18 | def load_image(imfile): 19 | img = np.array(Image.open(imfile)).astype(np.uint8) 20 | img = torch.from_numpy(img).permute(2, 0, 1).float() 21 | return img 22 | 23 | 24 | def load_image_list(image_files): 25 | images = [] 26 | for imfile in sorted(image_files): 27 | images.append(load_image(imfile)) 28 | 29 | images = torch.stack(images, dim=0) 30 | images = images.to(DEVICE) 31 | 32 | padder = InputPadder(images.shape) 33 | return padder.pad(images)[0] 34 | 35 | 36 | def viz(img, flo): 37 | img = img[0].permute(1,2,0).cpu().numpy() 38 | flo = flo[0].permute(1,2,0).cpu().numpy() 39 | 40 | # map flow to rgb image 41 | flo = flow_viz.flow_to_image(flo) 42 | # img_flo = np.concatenate([img, flo], axis=0) 43 | img_flo = flo 44 | 45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) 46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 47 | # cv2.waitKey() 48 | 49 | 50 | def demo(args): 51 | model = torch.nn.DataParallel(RAFT(args)) 52 | model.load_state_dict(torch.load(args.model)) 53 | 54 | model = model.module 55 | model.to(DEVICE) 56 | model.eval() 57 | 58 | with torch.no_grad(): 59 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 60 | glob.glob(os.path.join(args.path, '*.jpg')) 61 | 62 | images = load_image_list(images) 63 | for i in range(images.shape[0]-1): 64 | image1 = images[i,None] 65 | image2 = images[i+1,None] 66 | 67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 68 | viz(image1, flow_up) 69 | 70 | 71 | def RAFT_infer(args): 72 | model = torch.nn.DataParallel(RAFT(args)) 
73 | model.load_state_dict(torch.load(args.model)) 74 | 75 | model = model.module 76 | model.to(DEVICE) 77 | model.eval() 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /RAFT/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /edgeconnect/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | class Config(dict): 5 | def __init__(self, config_path): 6 | with open(config_path, 'r') as f: 7 | self._yaml = f.read() 8 | self._dict = yaml.load(self._yaml) 9 | self._dict['PATH'] = os.path.dirname(config_path) 10 | 11 | def __getattr__(self, name): 12 | if self._dict.get(name) is not None: 13 | return self._dict[name] 14 | 15 | if DEFAULT_CONFIG.get(name) is not None: 16 | return DEFAULT_CONFIG[name] 17 | 18 | return None 19 | 20 | 
def print(self): 21 | print('Model configurations:') 22 | print('---------------------------------') 23 | print(self._yaml) 24 | print('') 25 | print('---------------------------------') 26 | print('') 27 | 28 | 29 | DEFAULT_CONFIG = { 30 | 'MODE': 1, # 1: train, 2: test, 3: eval 31 | 'MODEL': 1, # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model 32 | 'MASK': 3, # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half) 33 | 'EDGE': 1, # 1: canny, 2: external 34 | 'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny 35 | 'SEED': 10, # random seed 36 | 'GPU': [0], # list of gpu ids 37 | 'DEBUG': 0, # turns on debugging mode 38 | 'VERBOSE': 0, # turns on verbose mode in the output console 39 | 40 | 'LR': 0.0001, # learning rate 41 | 'D2G_LR': 0.1, # discriminator/generator learning rate ratio 42 | 'BETA1': 0.0, # adam optimizer beta1 43 | 'BETA2': 0.9, # adam optimizer beta2 44 | 'BATCH_SIZE': 8, # input batch size for training 45 | 'INPUT_SIZE': 256, # input image size for training 0 for original size 46 | 'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) 47 | 'MAX_ITERS': 2e6, # maximum number of iterations to train the model 48 | 49 | 'EDGE_THRESHOLD': 0.5, # edge detection threshold 50 | 'L1_LOSS_WEIGHT': 1, # l1 loss weight 51 | 'FM_LOSS_WEIGHT': 10, # feature-matching loss weight 52 | 'STYLE_LOSS_WEIGHT': 1, # style loss weight 53 | 'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight 54 | 'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight 55 | 56 | 'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge 57 | 'GAN_POOL_SIZE': 0, # fake images pool size 58 | 59 | 'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never) 60 | 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never) 61 | 'SAMPLE_SIZE': 12, # number of images to sample 62 | 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never) 63 | 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never) 64 | } 65 | -------------------------------------------------------------------------------- /models/DeepFill_Models/DeepFill.py: -------------------------------------------------------------------------------- 1 | from .ops import * 2 | 3 | 4 | class Generator(nn.Module): 5 | def __init__(self, first_dim=32, isCheck=False, device=None): 6 | super(Generator, self).__init__() 7 | self.isCheck = isCheck 8 | self.device = device 9 | self.stage_1 = CoarseNet(5, first_dim, device=device) 10 | self.stage_2 = RefinementNet(5, first_dim, device=device) 11 | 12 | def forward(self, masked_img, mask, small_mask): # mask : 1 x 1 x H x W 13 | 14 | # border, maybe 15 | mask = mask.expand(masked_img.size(0),1,masked_img.size(2),masked_img.size(3)) 16 | small_mask = small_mask.expand(masked_img.size(0), 1, masked_img.size(2) // 8, masked_img.size(3) // 8) 17 | if self.device: 18 | ones = to_var(torch.ones(mask.size()), device=self.device) 19 | else: 20 | ones = to_var(torch.ones(mask.size())) 21 | # stage1 22 | stage1_input = torch.cat([masked_img, ones, ones*mask], dim=1) 23 | stage1_output, resized_mask = self.stage_1(stage1_input, mask) 24 | # stage2 25 | new_masked_img = stage1_output*mask.clone() + masked_img.clone()*(1.-mask.clone()) 26 | stage2_input = torch.cat([new_masked_img, ones.clone(), ones.clone()*mask.clone()], dim=1) 27 | stage2_output, 
offset_flow = self.stage_2(stage2_input, small_mask) 28 | 29 | return stage1_output, stage2_output, offset_flow 30 | 31 | 32 | class CoarseNet(nn.Module): 33 | ''' 34 | # input: B x 5 x W x H 35 | # after down: B x 128(32*4) x W/4 x H/4 36 | # after atrous: same with the output size of the down module 37 | # after up : same with the input size 38 | ''' 39 | def __init__(self, in_ch, out_ch, device=None): 40 | super(CoarseNet,self).__init__() 41 | self.down = Down_Module(in_ch, out_ch) 42 | self.atrous = Dilation_Module(out_ch*4, out_ch*4) 43 | self.up = Up_Module(out_ch*4, 3) 44 | self.device=device 45 | 46 | def forward(self, x, mask): 47 | x = self.down(x) 48 | resized_mask = down_sample(mask, scale_factor=0.25, mode='nearest', device=self.device) 49 | x = self.atrous(x) 50 | x = self.up(x) 51 | 52 | return x, resized_mask 53 | 54 | 55 | class RefinementNet(nn.Module): 56 | ''' 57 | # input: B x 5 x W x H 58 | # after down: B x 128(32*4) x W/4 x H/4 59 | # after atrous: same with the output size of the down module 60 | # after up : same with the input size 61 | ''' 62 | def __init__(self, in_ch, out_ch, device=None): 63 | super(RefinementNet,self).__init__() 64 | self.down_conv_branch = Down_Module(in_ch, out_ch, isRefine=True) 65 | self.down_attn_branch = Down_Module(in_ch, out_ch, activation=nn.ReLU(), isRefine=True, isAttn=True) 66 | self.atrous = Dilation_Module(out_ch*4, out_ch*4) 67 | self.CAttn = Contextual_Attention_Module(out_ch*4, out_ch*4, device=device) 68 | self.up = Up_Module(out_ch*8, 3, isRefine=True) 69 | 70 | def forward(self, x, resized_mask): 71 | # conv branch 72 | conv_x = self.down_conv_branch(x) 73 | conv_x = self.atrous(conv_x) 74 | 75 | # attention branch 76 | attn_x = self.down_attn_branch(x) 77 | 78 | attn_x, offset_flow = self.CAttn(attn_x, attn_x, mask=resized_mask) 79 | 80 | # concat two branches 81 | deconv_x = torch.cat([conv_x, attn_x], dim=1) # deconv_x => B x 256 x W/4 x H/4 82 | x = self.up(deconv_x) 83 | 84 | return x, offset_flow 85 | -------------------------------------------------------------------------------- /RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, 
-1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class CorrLayer(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, fmap1, fmap2, coords, r): 66 | fmap1 = fmap1.contiguous() 67 | fmap2 = fmap2.contiguous() 68 | coords = coords.contiguous() 69 | ctx.save_for_backward(fmap1, fmap2, coords) 70 | ctx.r = r 71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) 72 | return corr 73 | 74 | @staticmethod 75 | def backward(ctx, grad_corr): 76 | fmap1, fmap2, coords = ctx.saved_tensors 77 | grad_corr = grad_corr.contiguous() 78 | fmap1_grad, fmap2_grad, coords_grad = \ 79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) 80 | return fmap1_grad, fmap2_grad, coords_grad, None 81 | 82 | 83 | class AlternateCorrBlock: 84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 85 | self.num_levels = num_levels 86 | self.radius = radius 87 | 88 | self.pyramid = [(fmap1, fmap2)] 89 | for i in range(self.num_levels): 90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 92 | self.pyramid.append((fmap1, fmap2)) 93 | 94 | def __call__(self, coords): 95 | 96 | coords = coords.permute(0, 2, 3, 1) 97 | B, H, W, _ = coords.shape 98 | 99 | corr_list = [] 100 | for i in range(self.num_levels): 101 | r = self.radius 102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) 103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) 104 | 105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) 107 | corr_list.append(corr.squeeze(1)) 108 | 109 | corr = torch.stack(corr_list, dim=1) 110 | corr = corr.reshape(B, -1, H, W) 111 | return corr / 16.0 112 | -------------------------------------------------------------------------------- /RAFT/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /tool/frame_inpaint.py: 
-------------------------------------------------------------------------------- 1 | import sys, os, argparse 2 | sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..'))) 3 | import torch 4 | import numpy as np 5 | import cv2 6 | 7 | from models import DeepFill 8 | 9 | 10 | class DeepFillv1(object): 11 | def __init__(self, 12 | pretrained_model=None, 13 | image_shape=[512, 960], 14 | res_shape=None, 15 | device=torch.device('cuda:0')): 16 | self.image_shape = image_shape 17 | self.res_shape = res_shape 18 | self.device = device 19 | 20 | self.deepfill = DeepFill.Generator().to(device) 21 | model_weight = torch.load(pretrained_model) 22 | self.deepfill.load_state_dict(model_weight, strict=True) 23 | self.deepfill.eval() 24 | print('Load Deepfill Model from', pretrained_model) 25 | 26 | def forward(self, img, mask): 27 | 28 | img, mask, small_mask = self.data_preprocess(img, mask, size=self.image_shape) 29 | 30 | image = torch.stack([img]) 31 | mask = torch.stack([mask]) 32 | small_mask = torch.stack([small_mask]) 33 | 34 | with torch.no_grad(): 35 | _, inpaint_res, _ = self.deepfill(image.to(self.device), mask.to(self.device), small_mask.to(self.device)) 36 | 37 | res_complete = self.data_proprocess(image, mask, inpaint_res) 38 | 39 | return res_complete 40 | 41 | def data_preprocess(self, img, mask, enlarge_kernel=0, size=[512, 960]): 42 | img = img / 127.5 - 1 43 | mask = (mask > 0).astype(np.int) 44 | img = cv2.resize(img, (size[1], size[0])) 45 | if enlarge_kernel > 0: 46 | kernel = np.ones((enlarge_kernel, enlarge_kernel), np.uint8) 47 | mask = cv2.dilate(mask, kernel, iterations=1) 48 | mask = (mask > 0).astype(np.uint8) 49 | 50 | small_mask = cv2.resize(mask, (size[1] // 8, size[0] // 8), interpolation=cv2.INTER_NEAREST) 51 | mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST) 52 | 53 | if len(mask.shape) == 3: 54 | mask = mask[:, :, 0:1] 55 | else: 56 | mask = np.expand_dims(mask, axis=2) 57 | 58 | if len(small_mask.shape) == 3: 59 | small_mask = small_mask[:, :, 0:1] 60 | else: 61 | small_mask = np.expand_dims(small_mask, axis=2) 62 | 63 | img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float() 64 | mask = torch.from_numpy(mask).permute(2, 0, 1).contiguous().float() 65 | small_mask = torch.from_numpy(small_mask).permute(2, 0, 1).contiguous().float() 66 | 67 | return img*(1-mask), mask, small_mask 68 | 69 | def data_proprocess(self, img, mask, res): 70 | img = img.cpu().data.numpy()[0] 71 | mask = mask.data.numpy()[0] 72 | res = res.cpu().data.numpy()[0] 73 | 74 | res_complete = res * mask + img * (1. 
- mask) 75 | res_complete = (res_complete + 1) * 127.5 76 | res_complete = res_complete.transpose(1, 2, 0) 77 | if self.res_shape is not None: 78 | res_complete = cv2.resize(res_complete, 79 | (self.res_shape[1], self.res_shape[0])) 80 | 81 | return res_complete 82 | 83 | 84 | def parse_arges(): 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--image_shape', type=int, nargs='+', 87 | default=[512, 960]) 88 | parser.add_argument('--res_shape', type=int, nargs='+', 89 | default=None) 90 | parser.add_argument('--pretrained_model', type=str, 91 | default='/home/chengao/Weight/imagenet_deepfill.pth') 92 | parser.add_argument('--test_img', type=str, 93 | default='/work/cascades/chengao/DAVIS-540/bear_540p/00000.png') 94 | parser.add_argument('--test_mask', type=str, 95 | default='/work/cascades/chengao/DAVIS-540-baseline/mask_540p.png') 96 | parser.add_argument('--output_path', type=str, 97 | default='/home/chengao/res_00000.png') 98 | 99 | args = parser.parse_args() 100 | 101 | return args 102 | 103 | 104 | def main(): 105 | 106 | args = parse_arges() 107 | 108 | deepfill = DeepFillv1(pretrained_model=args.pretrained_model, 109 | image_shape=args.image_shape, 110 | res_shape=args.res_shape) 111 | 112 | test_image = cv2.imread(args.test_img) 113 | mask = cv2.imread(args.test_mask, cv2.IMREAD_UNCHANGED) 114 | 115 | with torch.no_grad(): 116 | img_res = deepfill.forward(test_image, mask) 117 | 118 | cv2.imwrite(args.output_path, img_res) 119 | print('Result Saved') 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /RAFT/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /utils/region_fill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy import sparse 4 | from scipy.sparse.linalg import spsolve 5 | 6 | 7 | def regionfill(I, mask, factor=1.0): 8 | if np.count_nonzero(mask) == 0: 9 | return I.copy() 10 | resize_mask = cv2.resize( 11 | mask.astype(float), (0, 0), fx=factor, fy=factor) > 0 12 | resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor) 13 | maskPerimeter = findBoundaryPixels(resize_mask) 14 | regionfillLaplace(resize_I, resize_mask, maskPerimeter) 15 | resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0])) 16 | resize_I[mask == 0] = I[mask == 0] 17 | return resize_I 18 | 19 | 20 | def findBoundaryPixels(mask): 21 | kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) 22 | maskDilated = cv2.dilate(mask.astype(float), kernel) 23 | return (maskDilated > 0) & (mask == 0) 24 | 25 | 26 | def regionfillLaplace(I, mask, maskPerimeter): 27 | height, width = I.shape 28 | rightSide = formRightSide(I, maskPerimeter) 29 | 30 | # Location of mask pixels 31 | maskIdx = np.where(mask) 32 | 33 | # Only keep values for pixels that are in the mask 34 | rightSide = rightSide[maskIdx] 35 | 36 | # Number the mask pixels in a grid matrix 37 | grid = -np.ones((height, width)) 38 | grid[maskIdx] = range(0, maskIdx[0].size) 39 | # Pad with zeros to avoid "index out of bounds" errors in the for loop 40 | grid = padMatrix(grid) 41 | gridIdx = np.where(grid >= 0) 42 | 43 | # Form the connectivity matrix D=sparse(i,j,s) 44 | # Connect each mask pixel to itself 45 | i = np.arange(0, maskIdx[0].size) 46 | j = np.arange(0, maskIdx[0].size) 47 | # The coefficient is the number of neighbors over which we average 48 | numNeighbors = computeNumberOfNeighbors(height, width) 49 | s = numNeighbors[maskIdx] 50 | # Now connect the N,E,S,W neighbors if they exist 51 | for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)): 52 | # Possible neighbors in the current direction 53 | neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]] 54 | # ConDnect mask points to neighbors with -1's 55 | index = (neighbors >= 0) 56 | i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]])) 57 | j = np.concatenate((j, neighbors[index])) 58 | s = np.concatenate((s, -np.ones(np.count_nonzero(index)))) 59 | 60 | D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr() 61 | sol = spsolve(D, rightSide) 62 | I[maskIdx] = sol 63 | return I 64 | 65 | 66 | def formRightSide(I, maskPerimeter): 67 | height, width = I.shape 68 | perimeterValues = np.zeros((height, width)) 69 | perimeterValues[maskPerimeter] = I[maskPerimeter] 70 | rightSide = np.zeros((height, width)) 71 | 72 | rightSide[1:height - 1, 1:width - 1] = ( 73 | perimeterValues[0:height - 2, 1:width - 1] + 74 | perimeterValues[2:height, 1:width - 1] + 75 | 
perimeterValues[1:height - 1, 0:width - 2] + 76 | perimeterValues[1:height - 1, 2:width]) 77 | 78 | rightSide[1:height - 1, 0] = ( 79 | perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] + 80 | perimeterValues[1:height - 1, 1]) 81 | 82 | rightSide[1:height - 1, width - 1] = ( 83 | perimeterValues[0:height - 2, width - 1] + 84 | perimeterValues[2:height, width - 1] + 85 | perimeterValues[1:height - 1, width - 2]) 86 | 87 | rightSide[0, 1:width - 1] = ( 88 | perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] + 89 | perimeterValues[0, 2:width]) 90 | 91 | rightSide[height - 1, 1:width - 1] = ( 92 | perimeterValues[height - 2, 1:width - 1] + 93 | perimeterValues[height - 1, 0:width - 2] + 94 | perimeterValues[height - 1, 2:width]) 95 | 96 | rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0] 97 | rightSide[0, width - 1] = ( 98 | perimeterValues[0, width - 2] + perimeterValues[1, width - 1]) 99 | rightSide[height - 1, 0] = ( 100 | perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1]) 101 | rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] + 102 | perimeterValues[height - 1, width - 2]) 103 | return rightSide 104 | 105 | 106 | def computeNumberOfNeighbors(height, width): 107 | # Initialize 108 | numNeighbors = np.zeros((height, width)) 109 | # Interior pixels have 4 neighbors 110 | numNeighbors[1:height - 1, 1:width - 1] = 4 111 | # Border pixels have 3 neighbors 112 | numNeighbors[1:height - 1, (0, width - 1)] = 3 113 | numNeighbors[(0, height - 1), 1:width - 1] = 3 114 | # Corner pixels have 2 neighbors 115 | numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0, 116 | width - 1)] = 2 117 | return numNeighbors 118 | 119 | 120 | def padMatrix(grid): 121 | height, width = grid.shape 122 | gridPadded = -np.ones((height + 2, width + 2)) 123 | gridPadded[1:height + 1, 1:width + 1] = grid 124 | gridPadded = gridPadded.astype(grid.dtype) 125 | return gridPadded 126 | -------------------------------------------------------------------------------- /RAFT/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in args._get_kwargs(): 42 | args.dropout = 0 43 | 44 | if 'alternate_corr' not in args._get_kwargs(): 45 | args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', 
dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | 59 | def freeze_bn(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | m.eval() 63 | 64 | def initialize_flow(self, img): 65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 66 | N, C, H, W = img.shape 67 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 68 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 69 | 70 | # optical flow computed as difference: flow = coords1 - coords0 71 | return coords0, coords1 72 | 73 | def upsample_flow(self, flow, mask): 74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 75 | N, _, H, W = flow.shape 76 | mask = mask.view(N, 1, 9, 8, 8, H, W) 77 | mask = torch.softmax(mask, dim=2) 78 | 79 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 81 | 82 | up_flow = torch.sum(mask * up_flow, dim=2) 83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 84 | return up_flow.reshape(N, 2, 8*H, 8*W) 85 | 86 | 87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 88 | """ Estimate optical flow between pair of frames """ 89 | 90 | image1 = 2 * (image1 / 255.0) - 1.0 91 | image2 = 2 * (image2 / 255.0) - 1.0 92 | 93 | image1 = image1.contiguous() 94 | image2 = image2.contiguous() 95 | 96 | hdim = self.hidden_dim 97 | cdim = self.context_dim 98 | 99 | # run the feature network 100 | with autocast(enabled=self.args.mixed_precision): 101 | fmap1, fmap2 = self.fnet([image1, image2]) 102 | 103 | fmap1 = fmap1.float() 104 | fmap2 = fmap2.float() 105 | if self.args.alternate_corr: 106 | corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius) 107 | else: 108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 109 | 110 | # run the context network 111 | with autocast(enabled=self.args.mixed_precision): 112 | cnet = self.cnet(image1) 113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 114 | net = torch.tanh(net) 115 | inp = torch.relu(inp) 116 | 117 | coords0, coords1 = self.initialize_flow(image1) 118 | 119 | if flow_init is not None: 120 | coords1 = coords1 + flow_init 121 | 122 | flow_predictions = [] 123 | for itr in range(iters): 124 | coords1 = coords1.detach() 125 | corr = corr_fn(coords1) # index correlation volume 126 | 127 | flow = coords1 - coords0 128 | with autocast(enabled=self.args.mixed_precision): 129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 130 | 131 | # F(t+1) = F(t) + \Delta(t) 132 | coords1 = coords1 + delta_flow 133 | 134 | # upsample predictions 135 | if up_mask is None: 136 | flow_up = upflow8(coords1 - coords0) 137 | else: 138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 139 | 140 | flow_predictions.append(flow_up) 141 | 142 | if test_mode: 143 | return coords1 - coords0, flow_up 144 | 145 | return flow_predictions 146 | -------------------------------------------------------------------------------- /edgeconnect/region_fill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy import sparse 4 | from scipy.sparse.linalg import spsolve 5 | 6 
| 7 | def regionfill(I, mask, factor=1.0): 8 | if np.count_nonzero(mask) == 0: 9 | return I.copy() 10 | resize_mask = cv2.resize( 11 | mask.astype(float), (0, 0), fx=factor, fy=factor) > 0 12 | resize_I = cv2.resize(I.astype(float), (0, 0), fx=factor, fy=factor) 13 | maskPerimeter = findBoundaryPixels(resize_mask) 14 | regionfillLaplace(resize_I, resize_mask, maskPerimeter) 15 | resize_I = cv2.resize(resize_I, (I.shape[1], I.shape[0])) 16 | resize_I[mask == 0] = I[mask == 0] 17 | return resize_I 18 | 19 | 20 | def findBoundaryPixels(mask): 21 | kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) 22 | maskDilated = cv2.dilate(mask.astype(float), kernel) 23 | return (maskDilated > 0) & (mask == 0) 24 | 25 | 26 | def regionfillLaplace(I, mask, maskPerimeter): 27 | height, width = I.shape 28 | rightSide = formRightSide(I, maskPerimeter) 29 | 30 | # Location of mask pixels 31 | maskIdx = np.where(mask) 32 | 33 | # Only keep values for pixels that are in the mask 34 | rightSide = rightSide[maskIdx] 35 | 36 | # Number the mask pixels in a grid matrix 37 | grid = -np.ones((height, width)) 38 | grid[maskIdx] = range(0, maskIdx[0].size) 39 | # Pad with zeros to avoid "index out of bounds" errors in the for loop 40 | grid = padMatrix(grid) 41 | gridIdx = np.where(grid >= 0) 42 | 43 | # Form the connectivity matrix D=sparse(i,j,s) 44 | # Connect each mask pixel to itself 45 | i = np.arange(0, maskIdx[0].size) 46 | j = np.arange(0, maskIdx[0].size) 47 | # The coefficient is the number of neighbors over which we average 48 | numNeighbors = computeNumberOfNeighbors(height, width) 49 | s = numNeighbors[maskIdx] 50 | # Now connect the N,E,S,W neighbors if they exist 51 | for direction in ((-1, 0), (0, 1), (1, 0), (0, -1)): 52 | # Possible neighbors in the current direction 53 | neighbors = grid[gridIdx[0] + direction[0], gridIdx[1] + direction[1]] 54 | # ConDnect mask points to neighbors with -1's 55 | index = (neighbors >= 0) 56 | i = np.concatenate((i, grid[gridIdx[0][index], gridIdx[1][index]])) 57 | j = np.concatenate((j, neighbors[index])) 58 | s = np.concatenate((s, -np.ones(np.count_nonzero(index)))) 59 | 60 | D = sparse.coo_matrix((s, (i.astype(int), j.astype(int)))).tocsr() 61 | sol = spsolve(D, rightSide) 62 | I[maskIdx] = sol 63 | return I 64 | 65 | 66 | def formRightSide(I, maskPerimeter): 67 | height, width = I.shape 68 | perimeterValues = np.zeros((height, width)) 69 | perimeterValues[maskPerimeter] = I[maskPerimeter] 70 | rightSide = np.zeros((height, width)) 71 | 72 | rightSide[1:height - 1, 1:width - 1] = ( 73 | perimeterValues[0:height - 2, 1:width - 1] + 74 | perimeterValues[2:height, 1:width - 1] + 75 | perimeterValues[1:height - 1, 0:width - 2] + 76 | perimeterValues[1:height - 1, 2:width]) 77 | 78 | rightSide[1:height - 1, 0] = ( 79 | perimeterValues[0:height - 2, 0] + perimeterValues[2:height, 0] + 80 | perimeterValues[1:height - 1, 1]) 81 | 82 | rightSide[1:height - 1, width - 1] = ( 83 | perimeterValues[0:height - 2, width - 1] + 84 | perimeterValues[2:height, width - 1] + 85 | perimeterValues[1:height - 1, width - 2]) 86 | 87 | rightSide[0, 1:width - 1] = ( 88 | perimeterValues[1, 1:width - 1] + perimeterValues[0, 0:width - 2] + 89 | perimeterValues[0, 2:width]) 90 | 91 | rightSide[height - 1, 1:width - 1] = ( 92 | perimeterValues[height - 2, 1:width - 1] + 93 | perimeterValues[height - 1, 0:width - 2] + 94 | perimeterValues[height - 1, 2:width]) 95 | 96 | rightSide[0, 0] = perimeterValues[0, 1] + perimeterValues[1, 0] 97 | rightSide[0, width - 1] = ( 98 | 
perimeterValues[0, width - 2] + perimeterValues[1, width - 1]) 99 | rightSide[height - 1, 0] = ( 100 | perimeterValues[height - 2, 0] + perimeterValues[height - 1, 1]) 101 | rightSide[height - 1, width - 1] = (perimeterValues[height - 2, width - 1] + 102 | perimeterValues[height - 1, width - 2]) 103 | return rightSide 104 | 105 | 106 | def computeNumberOfNeighbors(height, width): 107 | # Initialize 108 | numNeighbors = np.zeros((height, width)) 109 | # Interior pixels have 4 neighbors 110 | numNeighbors[1:height - 1, 1:width - 1] = 4 111 | # Border pixels have 3 neighbors 112 | numNeighbors[1:height - 1, (0, width - 1)] = 3 113 | numNeighbors[(0, height - 1), 1:width - 1] = 3 114 | # Corner pixels have 2 neighbors 115 | numNeighbors[(0, 0, height - 1, height - 1), (0, width - 1, 0, 116 | width - 1)] = 2 117 | return numNeighbors 118 | 119 | 120 | def padMatrix(grid): 121 | height, width = grid.shape 122 | gridPadded = -np.ones((height + 2, width + 2)) 123 | gridPadded[1:height + 1, 1:width + 1] = grid 124 | gridPadded = gridPadded.astype(grid.dtype) 125 | return gridPadded 126 | 127 | 128 | if __name__ == '__main__': 129 | import time 130 | x = np.linspace(0, 255, 500) 131 | xv, _ = np.meshgrid(x, x) 132 | image = ((xv + np.transpose(xv)) / 2.0).astype(int) 133 | mask = np.zeros((500, 500)) 134 | mask[100:259, 100:259] = 1 135 | mask = (mask > 0) 136 | image[mask] = 0 137 | st = time.time() 138 | inpaint = regionfill(image, mask, 0.5).astype(np.uint8) 139 | print(time.time() - st) 140 | cv2.imshow('img', np.concatenate((image.astype(np.uint8), inpaint))) 141 | cv2.waitKey() 142 | -------------------------------------------------------------------------------- /RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 
| # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /evaluate.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | 线上测试使用的是GPU版本(调的是piq库)。这里面我们也提供了CPU版的测试代码(调的是skimage库)。 3 | CPU版本和GPU版本分值有差别:此baseline的cpu测评结果为68.6891, gpu测评结果为68.7054 4 | 5 | error_code: 6 | -1 error: video number unmatch 7 | -2 error: image not found 8 | -3 error: image size unmatch 9 | ''' 10 | import os 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import json 14 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 15 | from skimage.metrics import structural_similarity as ssim 16 | import numpy as np 17 | import argparse 18 | import glob 19 | from PIL import Image 20 | import torch 21 | import cv2 22 | from piq import ssim, SSIMLoss 23 | from piq import psnr 24 | 25 | # CPU版本 26 | def PSNR(ximg,yimg): 27 | return compare_psnr(ximg,yimg,data_range=255) 28 | 29 | def SSIM(y,t,value_range=255): 30 | try: 31 | ssim_value = ssim(y, t, gaussian_weights=True, data_range=value_range, multichannel=True) 32 | except ValueError: 33 | #WinSize too small 34 | ssim_value = ssim(y, t, gaussian_weights=True, data_range=value_range, multichannel=True, win_size=3) 35 | return ssim_value 36 | 37 | # GPU版本。 38 | # def PSNR(ximg,yimg): 39 | # gt_tensor = torch.from_numpy(yimg.transpose([2,0,1])).unsqueeze(0).cuda() 40 | # img_tensor = torch.from_numpy(ximg.transpose([2,0,1])).unsqueeze(0).cuda() 41 | # psnr_v = psnr(img_tensor, gt_tensor, data_range=255.).item() 42 | # #print(psnr_v) 43 | # return psnr_v 44 | 45 | # def SSIM(y,t,value_range=255): 46 | # gt_ss_tensor = torch.from_numpy(y.transpose([2,0,1])).unsqueeze(0).cuda() 47 | # img_ss_tensor = torch.from_numpy(t.transpose([2,0,1])).unsqueeze(0).cuda() 48 | # loss = SSIMLoss(data_range=255.).cuda() 49 | # loss_v = loss(gt_ss_tensor, img_ss_tensor).item() 50 | # return 1-loss_v 51 | 52 | 53 | def Evaluate(files_gt, files_pred, methods = [PSNR,SSIM]): 54 | score = {} 55 | for meth in methods: 56 | name = meth.__name__ 57 | results = [] 58 | res=0 59 | frame_num=len(files_gt) 60 | for frame in range(0,frame_num): 61 | # ignore some tiny crops 62 | if files_gt[frame].shape[0]*files_gt[frame].shape[1]<150: 63 | continue 64 | 65 | res = meth(files_pred[frame],files_gt[frame]) 66 | results.append(res) 67 | 68 | mres = np.mean(results) 69 | stdres=np.std(results) 70 | # print(name+": "+str(mres)+" Std: "+str(stdres)) 71 | score['mean']=mres 72 | score['std']=stdres 73 | return score 74 | 75 | 76 | def evaluate(args): 77 | error_code=0 78 | error_flag='successful.' 79 | final_result=[] 80 | 81 | # load video folder 82 | grountruth_folder_list = sorted(glob.glob(os.path.join(args.groundtruth_folder, 'video_0*'))) 83 | prediction_folder_list = sorted(glob.glob(os.path.join(args.prediction_folder,'video_0*'))) 84 | 85 | if len(grountruth_folder_list) != len(prediction_folder_list): 86 | error_code=-1 87 | error_flag='folder number unmatch.' 88 | return error_code, error_flag, 0 89 | 90 | for i in range(0,len(grountruth_folder_list)): 91 | # load video 92 | video_gt=[] 93 | video_predict=[] 94 | image_list = sorted(glob.glob(os.path.join(grountruth_folder_list[i],'gt_crop/*.png'))) 95 | for image_gt in image_list: 96 | video_gt.append(np.array(Image.open(image_gt)).astype(np.uint8)) 97 | 98 | try: 99 | image_predict=prediction_folder_list[i]+'/crop_'+image_gt[-10:] 100 | video_predict.append(np.array(Image.open(image_predict)).astype(np.uint8)) 101 | except: 102 | error_code=-2 103 | error_flag= 'read ' + image_predict +' failed.' 
104 | return error_code, error_flag, 0 105 | 106 | # check image size 107 | for j in range(0,len(image_list)): 108 | if video_gt[j].shape!=video_predict[j].shape: 109 | error_code=-3 110 | error_flag= 'image size unmatch. please check video_' + str(i).zfill(4)+'/crop_'+str(j).zfill(4) 111 | return error_code, error_flag, 0 112 | 113 | # sent in whole video 114 | psnr_res = Evaluate(video_gt,video_predict, methods=[PSNR]) 115 | ssim_res = Evaluate(video_gt,video_predict, methods=[SSIM]) 116 | 117 | psnr_res_norm=min(80,psnr_res['mean']) 118 | ssim_res_norm=ssim_res['mean']*100 119 | 120 | result=psnr_res_norm+ssim_res_norm*0.5 121 | print(i,psnr_res_norm,ssim_res_norm,result) 122 | 123 | final_result.append(result) 124 | return error_code, error_flag, np.mean(final_result) 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--groundtruth_folder',default='./') 130 | parser.add_argument('--prediction_folder',default='./') 131 | # usage: python scoring.py --groundtruth_folder ./val --prediction_folder ./result 132 | 133 | # groundtruth文件夹结构如下(跟解压后的一样): 134 | # 135 | # |-- 000000.png 136 | # |-- video_0000 -- gt_crops --|-- 000001.png 137 | # | |-- ... 138 | # val -|-- video_0001 139 | # | 140 | # |-- ... 141 | 142 | # 选手们上传的文件夹结构如下: 143 | # 144 | # |-- crop_000000.png 145 | # |-- video_0000 --|-- crop_000001.png 146 | # | |-- ... 147 | # result -|-- video_0001 148 | # | 149 | # |-- ... 150 | 151 | args = parser.parse_args() 152 | error_code, error_flag, final_result = evaluate(args) 153 | print(final_result) 154 | 155 | 156 | -------------------------------------------------------------------------------- /utils/Poisson_blend.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import scipy.ndimage 4 | from scipy.sparse.linalg import spsolve 5 | from scipy import sparse 6 | import scipy.io as sio 7 | import numpy as np 8 | from PIL import Image 9 | import copy 10 | import cv2 11 | import os 12 | import argparse 13 | 14 | 15 | def sub2ind(pi, pj, imgH, imgW): 16 | return pj + pi * imgW 17 | 18 | 19 | def Poisson_blend(imgTrg, imgSrc_gx, imgSrc_gy, holeMask, edge=None): 20 | 21 | imgH, imgW, nCh = imgTrg.shape 22 | 23 | if not isinstance(edge, np.ndarray): 24 | edge = np.zeros((imgH, imgW), dtype=np.float32) 25 | 26 | # Initialize the reconstructed image 27 | imgRecon = np.zeros((imgH, imgW, nCh), dtype=np.float32) 28 | 29 | # prepare discrete Poisson equation 30 | A, b = solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, edge) 31 | 32 | # Independently process each channel 33 | for ch in range(nCh): 34 | 35 | # solve Poisson equation 36 | x = scipy.sparse.linalg.lsqr(A, b[:, ch, None])[0] 37 | imgRecon[:, :, ch] = x.reshape(imgH, imgW) 38 | 39 | # Combined with the known region in the target 40 | holeMaskC = np.tile(np.expand_dims(holeMask, axis=2), (1, 1, nCh)) 41 | imgBlend = holeMaskC * imgRecon + (1 - holeMaskC) * imgTrg 42 | 43 | # Fill in edge pixel 44 | pi = np.expand_dims(np.where((holeMask * edge) == 1)[0], axis=1) # y, i 45 | pj = np.expand_dims(np.where((holeMask * edge) == 1)[1], axis=1) # x, j 46 | 47 | for k in range(len(pi)): 48 | if pi[k, 0] - 1 >= 0: 49 | if edge[pi[k, 0] - 1, pj[k, 0]] == 0: 50 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] - 1, pj[k, 0], :] 51 | continue 52 | if pi[k, 0] + 1 <= imgH - 1: 53 | if edge[pi[k, 0] + 1, pj[k, 0]] == 0: 54 | imgBlend[pi[k, 0], pj[k, 0], :] 
= imgBlend[pi[k, 0] + 1, pj[k, 0], :] 55 | continue 56 | if pj[k, 0] - 1 >= 0: 57 | if edge[pi[k, 0], pj[k, 0] - 1] == 0: 58 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] - 1, :] 59 | continue 60 | if pj[k, 0] + 1 <= imgW - 1: 61 | if edge[pi[k, 0], pj[k, 0] + 1] == 0: 62 | imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] + 1, :] 63 | 64 | return imgBlend 65 | 66 | def solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, edge): 67 | 68 | # Prepare the linear system of equations for Poisson blending 69 | imgH, imgW = holeMask.shape 70 | N = imgH * imgW 71 | 72 | # Number of unknown variables 73 | numUnknownPix = holeMask.sum() 74 | 75 | # 4-neighbors: dx and dy 76 | dx = [1, 0, -1, 0] 77 | dy = [0, 1, 0, -1] 78 | 79 | # 3 80 | # | 81 | # 2 -- * -- 0 82 | # | 83 | # 1 84 | # 85 | 86 | # Initialize (I, J, S), for sparse matrix A where A(I(k), J(k)) = S(k) 87 | I = np.empty((0, 1), dtype=np.float32) 88 | J = np.empty((0, 1), dtype=np.float32) 89 | S = np.empty((0, 1), dtype=np.float32) 90 | 91 | # Initialize b 92 | b = np.empty((0, 2), dtype=np.float32) 93 | 94 | # Precompute unkonwn pixel position 95 | pi = np.expand_dims(np.where(holeMask == 1)[0], axis=1) # y, i 96 | pj = np.expand_dims(np.where(holeMask == 1)[1], axis=1) # x, j 97 | pind = sub2ind(pi, pj, imgH, imgW) 98 | 99 | # |--------------------| 100 | # | y (i) | 101 | # | x (j) * | 102 | # | | 103 | # |--------------------| 104 | 105 | qi = np.concatenate((pi + dy[0], 106 | pi + dy[1], 107 | pi + dy[2], 108 | pi + dy[3]), axis=1) 109 | 110 | qj = np.concatenate((pj + dx[0], 111 | pj + dx[1], 112 | pj + dx[2], 113 | pj + dx[3]), axis=1) 114 | 115 | # Handling cases at image borders 116 | validN = (qi >= 0) & (qi <= imgH - 1) & (qj >= 0) & (qj <= imgW - 1) 117 | qind = np.zeros((validN.shape), dtype=np.float32) 118 | qind[validN] = sub2ind(qi[validN], qj[validN], imgH, imgW) 119 | 120 | e_start = 0 # equation counter start 121 | e_stop = 0 # equation stop 122 | 123 | # 4 neighbors 124 | I, J, S, b, e_start, e_stop = constructEquation(0, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 125 | I, J, S, b, e_start, e_stop = constructEquation(1, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 126 | I, J, S, b, e_start, e_stop = constructEquation(2, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 127 | I, J, S, b, e_start, e_stop = constructEquation(3, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 128 | 129 | nEqn = len(b) 130 | # Construct the sparse matrix A 131 | A = sparse.csr_matrix((S[:, 0], (I[:, 0], J[:, 0])), shape=(nEqn, N)) 132 | 133 | return A, b 134 | 135 | 136 | def constructEquation(n, validN, holeMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop): 137 | 138 | # Pixel that has valid neighbors 139 | validNeighbor = validN[:, n] 140 | 141 | # Change the out-of-boundary value to 0, in order to run edge[y,x] 142 | # in the next line. 
It won't affect anything as validNeighbor is saved already 143 | 144 | qi_tmp = copy.deepcopy(qi) 145 | qj_tmp = copy.deepcopy(qj) 146 | qi_tmp[np.invert(validNeighbor), n] = 0 147 | qj_tmp[np.invert(validNeighbor), n] = 0 148 | 149 | # Not edge 150 | NotEdge = (edge[pi[:, 0], pj[:, 0]] == 0) * (edge[qi_tmp[:, n], qj_tmp[:, n]] == 0) 151 | 152 | # Boundary constraint 153 | Boundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 0 154 | valid = validNeighbor * NotEdge * Boundary 155 | J_tmp = pind[valid, :] 156 | 157 | # num of equations: len(J_tmp) 158 | e_stop = e_start + len(J_tmp) 159 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) 160 | e_start = e_stop 161 | 162 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32) 163 | 164 | if n == 0: 165 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] 166 | elif n == 2: 167 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + imgTrg[qi[valid, n], qj[valid, n], :] 168 | elif n == 1: 169 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] 170 | elif n == 3: 171 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] 172 | 173 | I = np.concatenate((I, I_tmp)) 174 | J = np.concatenate((J, J_tmp)) 175 | S = np.concatenate((S, S_tmp)) 176 | b = np.concatenate((b, b_tmp)) 177 | 178 | 179 | # Non-boundary constraint 180 | NonBoundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 1 181 | valid = validNeighbor * NotEdge * NonBoundary 182 | 183 | J_tmp = pind[valid, :] 184 | 185 | # num of equations: len(J_tmp) 186 | e_stop = e_start + len(J_tmp) 187 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) 188 | e_start = e_stop 189 | 190 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32) 191 | 192 | if n == 0: 193 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] 194 | elif n == 2: 195 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] 196 | elif n == 1: 197 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] 198 | elif n == 3: 199 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] 200 | 201 | I = np.concatenate((I, I_tmp)) 202 | J = np.concatenate((J, J_tmp)) 203 | S = np.concatenate((S, S_tmp)) 204 | b = np.concatenate((b, b_tmp)) 205 | 206 | S_tmp = - np.ones(J_tmp.shape, dtype=np.float32) 207 | J_tmp = qind[valid, n, None] 208 | 209 | I = np.concatenate((I, I_tmp)) 210 | J = np.concatenate((J, J_tmp)) 211 | S = np.concatenate((S, S_tmp)) 212 | 213 | return I, J, S, b, e_start, e_stop 214 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Baseline 2 | 本代码的基本流程:计算稠密光流(RAFT)-> 计算边缘(Canny)-> 补全边缘(EdgeConnect)-> 补全光流(解Ax=b)-> 传播RGB值(能从时序上拿来补的像素就拿来补;没有的话就拿最空的一帧来进行图像补全(Deepfill_V1)后再传播)

3 | To speed things up, I modified the flow-completion step of the original code. The original solves Ax=b over the full frame, which is extremely slow; I crop around the hole first and solve Ax=b only on the crop, which is considerably faster.
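To make the crop-before-solve idea concrete, here is a minimal sketch (an illustration only, not the actual code in tool/video_completion_modified.py): take the padded bounding box of the hole mask, hand only that window to the Ax=b flow-completion solver, and paste the completed crop back. The `solver` callable and the `margin` value are placeholders.

```python
import numpy as np

def bbox_of_mask(mask, margin=16):
    """Padded bounding box (y0, y1, x0, x1) of a non-empty binary hole mask."""
    ys, xs = np.nonzero(mask)
    h, w = mask.shape
    return (max(ys.min() - margin, 0), min(ys.max() + 1 + margin, h),
            max(xs.min() - margin, 0), min(xs.max() + 1 + margin, w))

def complete_flow_on_crop(flow, mask, solver):
    """Solve Ax=b only on the crop around the hole, then paste the result back."""
    y0, y1, x0, x1 = bbox_of_mask(mask)
    flow_crop = flow[y0:y1, x0:x1].copy()        # (h', w', 2) window around the hole
    mask_crop = mask[y0:y1, x0:x1]
    completed = solver(flow_crop, mask_crop)     # any Poisson-style Ax=b flow solver
    out = flow.copy()
    out[y0:y1, x0:x1][mask_crop] = completed[mask_crop]  # only hole pixels are replaced
    return out
```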
4 | 5 | Usage: 6 | ```bash 7 | python ./tool/video_completion_modified.py --mode object_removal --path ../data/test_a/video_0000/frames_corr --path_mask ../data/test_a/video_0000/masks --outroot ../data/result_test_a/video_0000 --seamless --edge_guide 8 | ``` 9 |
10 | 11 | About this baseline:
12 | 1) On the test_a set, the final score of this baseline is about 68.7054.
13 | 2) It is slow. On a Gold5218@2.30GHz CPU and an NVIDIA V100 GPU, a single-threaded run took a bit more than one day.
14 | 3) Host memory usage exceeds 8 GB; GPU memory stays under 8 GB.
15 | 4) Quality is uneven. It works well on some videos (e.g. dancing people and moving objects, where temporal information is available to fill the hole) and poorly on others (e.g. watermarks and objects at fixed positions).
16 | 5) If you choose to tweak and retrain this baseline directly, both its speed and its memory usage will need to be optimized. 17 | 18 | Friendly reminders:
19 | 1) When submitting results on the competition website, a progress bar appears at the top and a confirmation message is shown once the upload succeeds. Do not close the page before you see the "submission successful" message.
20 | 2) If scoring fails with a "folder number unmatch" error, there are two likely causes: 1. the number of submitted video folders is wrong; there must be exactly 100 video_*** folders, so check whether an unrelated folder was added or some video folder is missing; 2. zip the folders directly, i.e. unzipping result.zip should yield result/video_**** rather than aaa/result/video_****
21 | 3) If scoring fails with an "image not found" error, check that every folder contains the full set of images and that they are named correctly.
22 | 4) If scoring fails with an "image size unmatch" error, check that the image sizes match those given in bbox.txt (a small local check such as the sketch below can catch these problems before uploading).
23 |
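Most of the scoring failures in reminders 2)-4) can be caught locally before uploading. A rough sanity-check sketch (the folder layout follows the submission rules further down; frame counts and crop sizes still have to be compared against the released data and bbox.txt):

```python
import glob
import os

def check_submission(result_root, expected_videos=100):
    """Check that result_root holds exactly 100 video_* folders containing only crop_*.png files."""
    videos = sorted(glob.glob(os.path.join(result_root, "video_*")))
    assert len(videos) == expected_videos, f"expected {expected_videos} video folders, found {len(videos)}"
    for v in videos:
        crops = sorted(glob.glob(os.path.join(v, "crop_*.png")))
        assert crops, f"no crop_*.png images in {v}"
        extra = set(os.listdir(v)) - {os.path.basename(c) for c in crops}
        assert not extra, f"unexpected files in {v}: {sorted(extra)}"

check_submission("./result")
```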
24 | 25 | I also wrote a short beginner tutorial (experienced users can skip it): https://zhuanlan.zhihu.com/p/381449269 26 |
27 | 28 | # The 2nd "Malanshan Cup" International Audio and Video Algorithm Competition - Video Completion Track 29 | 30 | ## Data description 31 | **Training set**: 2194 videos in mp4 format, 2s-8s long, mostly of size 720x1280 or 1080x1920.
32 | train1 (843 videos, 1.98 GB zipped) is identical to last year's point-tracking track research_1 video set (duplicate videos in last year's data have been removed).
33 | train2 (1010 videos, 2.18 GB zipped) is identical to last year's point-tracking track research_2 video set (duplicate videos in last year's data have been removed).
34 | train3 (341 videos, 646 MB zipped) consists of last year's point-tracking track val video set plus new videos for this track.
35 | 36 | **Validation set**: 50 videos stored as png frames of size 576x1024 (logo and subtitle regions were cropped out where possible), with at most 120 frames each. Each sample consists of the original video and one of four mask types (solid cut-out block, local watermark, random noise blocks, human portrait). 37 | 38 | **Test set a**: 100 hollowed-out videos stored as png frames of size 576x1024 (logo and subtitle regions were cropped out where possible), with at most 120 frames each. Each sample contains the hollowed-out video paired with its corresponding mask.
39 | **Test set b**: same format as test set a.
40 | (The proportion of each mask type in the test sets is similar to the validation set.) 41 | 42 | Third-party data must not be used at any point during the competition; open-source pretrained models are allowed, provided they were released before June 2021.
43 | The validation set may be added to training; using the test sets for training is **forbidden**.
44 | The data provided for this competition is copyrighted by Mango TV; participants must not leak it or use it for any purpose outside this competition.
45 | 46 | ## Data download links 47 | Baidu Cloud link: https://pan.baidu.com/s/1xUvLiAw3YGv6s1Ll780yUg
48 | Extraction code: mgtv
49 | 50 | Direct download links:
51 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/train_1.zip
52 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/train_2.zip
53 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/train_3.zip
54 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/val.zip
55 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/a/test_a.zip
56 | http://ad-implant.oss-cn-beijing.aliyuncs.com/challenge/res/8/b/test_b.zip
57 | 58 | MD5 checksums for verifying that the files were transferred intact:
59 | train_1.zip: 83ced2b4e80231105eb6dc8d2fae9e29
60 | train_2.zip: 792a21b974d83084c9e2bf81af0c9e10
61 | train_3.zip: e07011f6f1d149c6508a0ae50a76e18b
62 | val.zip: 6d16b879cd618358f941477eeed9a4bd
63 | test_a.zip: b146ed76a53f556fe36c9012da03bf94
64 | test_b.zip: f2f8b00df148621f25270ad1fd5e2362
65 |
66 | Linux command: md5sum test_a.zip
67 | Windows command: certutil -hashfile test_a.zip MD5
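If you would rather verify the downloads from Python than remember the per-platform commands, a small hashlib-based check works everywhere (the expected digest is the test_a.zip value listed above):

```python
import hashlib

def md5_of(path, chunk_size=1 << 20):
    """MD5 hex digest of a possibly large file, read in 1 MB chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

assert md5_of("test_a.zip") == "b146ed76a53f556fe36c9012da03bf94"
```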
68 | 69 | ## Evaluation metrics 70 | In both the preliminary and the final round, submissions are scored by evaluating the uploaded results with two metrics, PSNR and SSIM. For each uploaded result, the evaluation program computes PSNR and SSIM over the hollowed-out regions, frame by frame, and averages them. PSNR and SSIM are then combined with a weighted sum to give the final competition score. PSNR lies in [0, 80], with good results typically in [30, 50]; SSIM lies in [0, 1], with good results typically in [0.8, 1]. 71 | ```bash 72 | score = PSNR*2*0.5 + SSIM*100*0.5 73 | ``` 74 | See evaluate.py for the exact PSNR and SSIM computation. 75 | 76 | ## Compute requirements 77 | 1) Inference must use no more than 8 GB of host memory and no more than 8 GB of GPU memory. 78 | 2) On a CPU (dual-core, 2.30GHz) and a single GPU (NVIDIA V100 or NVIDIA 3090), running the whole test set sequentially (100 videos, no multithreading) must take no more than 15 hours. 79 |
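For reference, the scoring rule in the Evaluation metrics section above reduces to a few lines: evaluate.py clips each video's mean PSNR at 80, rescales its mean SSIM to [0, 100] and adds half of it, and the leaderboard score is the average of these per-video scores. A minimal restatement (mean_psnr and mean_ssim stand for the per-frame averages over the hollowed-out crops):

```python
def video_score(mean_psnr, mean_ssim):
    """Per-video score as in evaluate.py: clipped PSNR + 0.5 * (SSIM * 100)."""
    psnr_norm = min(80.0, mean_psnr)   # PSNR is capped at 80
    ssim_norm = mean_ssim * 100.0      # SSIM rescaled from [0, 1] to [0, 100]
    return psnr_norm + 0.5 * ssim_norm

# e.g. mean PSNR = 35 dB and mean SSIM = 0.90 give 35 + 45 = 80 points for that video
print(video_score(35.0, 0.90))
```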
80 | Teams that fail to meet the compute limits above may have their final score declared invalid by the organizers, with the next-ranked teams moving up accordingly. 81 | 82 | ## Submission requirements 83 | Results are submitted the same way in the preliminary and final rounds: upload the cropped result images. To keep the upload small, crop each result image with the [x,y,w,h] box given in the bbox.txt file of each video, then package and upload the crops. 84 | The x, y indices start from 0. For example: 85 | ```python 86 | import cv2 87 | img = cv2.imread("result_000000.png") 88 | crop_img = img[y:y+h, x:x+w, :] 89 | cv2.imwrite("crop_000000.png", crop_img) 90 | ``` 91 | Put the cropped images into their per-video folders (video_0***), then pack everything into a single *.zip for upload (normally no larger than 2 GB). Zip the folders directly, i.e. unzipping result.zip should yield result/video_0*** rather than aaa/result/video_0***. A sketch of cropping a whole video folder this way is given below.
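Extending the single-frame example above to a whole video folder could look like the sketch below; the result_*.png naming and the bbox.txt parsing (one x y w h line per video) are assumptions to adapt to your own output files and to the released bbox.txt format:

```python
import glob
import os
import cv2

def crop_video_folder(result_dir, out_dir):
    """Crop every result frame of one video with the [x, y, w, h] box from its bbox.txt."""
    with open(os.path.join(result_dir, "bbox.txt")) as f:
        x, y, w, h = [int(v) for v in f.read().replace(",", " ").split()[:4]]
    os.makedirs(out_dir, exist_ok=True)
    for path in sorted(glob.glob(os.path.join(result_dir, "result_*.png"))):
        img = cv2.imread(path)
        crop = img[y:y + h, x:x + w, :]
        name = os.path.basename(path).replace("result_", "crop_")
        cv2.imwrite(os.path.join(out_dir, name), crop)
```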
92 | The folder structure and naming convention are as follows (using test_a as an example; "result" can be named anything, e.g. result_0608 or aaa_123): 93 | ```bash 94 | |—— crop_000000.png 95 | |—— video_0000 |—— crop_000001.png 96 | | |—— crop_... 97 | result |—— video_... 98 | | |—— crop_000000.png 99 | |—— video_0099 |—— crop_000001.png 100 | |—— crop_... 101 | ``` 102 | ## Evaluation and ranking 103 | 1) Data is available for download in both the preliminary and final rounds; participants debug their algorithms locally and submit results on the competition page; 104 | 2) Both rounds use an A/B leaderboard: A-board scores are visible to teams during the competition, and the final ranking (for both the preliminary and the final round) uses the best B-board score; 105 | 3) The top-10 teams of the final round must provide their source code and a docker image so the organizing committee can verify the method and the results; a team's score is only considered valid if the reproduced metrics are within the allowed tolerance of its best submitted result; 106 | 4) The top-10 teams of the final round must also submit a **technical report** describing the method in detail (1-4 pages recommended); 107 | 5) Each team may submit at most 3 times per day; 108 | 6) The leaderboard is sorted by score from high to low and ranks each team by its best historical score; 109 | 110 |


111 | Below is the original authors' README.md 112 | # [ECCV 2020] Flow-edge Guided Video Completion 113 | 114 | ### [[Paper](https://arxiv.org/abs/2009.01835)] [[Project Website](http://chengao.vision/FGVC/)] [[Google Colab](https://colab.research.google.com/drive/1pb6FjWdwq_q445rG2NP0dubw7LKNUkqc?usp=sharing)] 115 | 116 |

117 | 118 |

119 | 120 | We present a new flow-based video completion algorithm. Previous flow completion methods are often unable to retain the sharpness of motion boundaries. Our method first extracts and completes motion edges, and then uses them to guide piecewise-smooth flow completion with sharp edges. Existing methods propagate colors among local flow connections between adjacent frames. However, not all missing regions in a video can be reached in this way because the motion boundaries form impenetrable barriers. Our method alleviates this problem by introducing non-local flow connections to temporally distant frames, enabling propagating video content over motion boundaries. We validate our approach on the DAVIS dataset. Both visual and quantitative results show that our method compares favorably against the state-of-the-art algorithms. 121 |
122 | 123 | **[ECCV 2020] Flow-edge Guided Video Completion** 124 |
125 | [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/), and [Johannes Kopf](https://johanneskopf.de/) 126 |
127 | In European Conference on Computer Vision (ECCV), 2020 128 | 129 | ## Prerequisites 130 | 131 | - Linux (tested on CentOS Linux release 7.4.1708) 132 | - Anaconda 133 | - Python 3.8 (tested on 3.8.5) 134 | - PyTorch 1.6.0 135 | 136 | and the Python dependencies listed in requirements.txt 137 | 138 | - To get started, please run the following commands: 139 | ``` 140 | conda create -n FGVC 141 | conda activate FGVC 142 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch 143 | conda install matplotlib scipy 144 | pip install -r requirements.txt 145 | ``` 146 | 147 | - Next, please download the model weight and demo data using the following command: 148 | ``` 149 | chmod +x download_data_weights.sh 150 | ./download_data_weights.sh 151 | ``` 152 | 153 | ## Quick start 154 | 155 | - Object removal: 156 | ```bash 157 | cd tool 158 | python video_completion.py \ 159 | --mode object_removal \ 160 | --path ../data/tennis \ 161 | --path_mask ../data/tennis_mask \ 162 | --outroot ../result/tennis_removal \ 163 | --seamless 164 | ``` 165 | 166 | - FOV extrapolation: 167 | ```bash 168 | cd tool 169 | python video_completion.py \ 170 | --mode video_extrapolation \ 171 | --path ../data/tennis \ 172 | --outroot ../result/tennis_extrapolation \ 173 | --H_scale 2 \ 174 | --W_scale 2 \ 175 | --seamless 176 | ``` 177 | 178 | You can remove the `--seamless` flag for a faster processing time. 179 | 180 | 181 | ## License 182 | This work is licensed under MIT License. See [LICENSE](LICENSE) for details. 183 | 184 | If you find this code useful for your research, please consider citing the following paper: 185 | 186 | @inproceedings{Gao-ECCV-FGVC, 187 | author = {Gao, Chen and Saraf, Ayush and Huang, Jia-Bin and Kopf, Johannes}, 188 | title = {Flow-edge Guided Video Completion}, 189 | booktitle = {European Conference on Computer Vision}, 190 | year = {2020} 191 | } 192 | 193 | ## Acknowledgments 194 | - Our flow edge completion network builds upon [EdgeConnect](https://github.com/knazeri/edge-connect). 195 | - Our image inpainting network is modified from [DFVI](https://github.com/nbei/Deep-Flow-Guided-Video-Inpainting). 
196 | -------------------------------------------------------------------------------- /edgeconnect/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | 9 | 10 | def create_dir(dir): 11 | if not os.path.exists(dir): 12 | os.makedirs(dir) 13 | 14 | 15 | def create_mask(width, height, mask_width, mask_height, x=None, y=None): 16 | mask = np.zeros((height, width)) 17 | mask_x = x if x is not None else random.randint(0, width - mask_width) 18 | mask_y = y if y is not None else random.randint(0, height - mask_height) 19 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1 20 | return mask 21 | 22 | 23 | def stitch_images(inputs, *outputs, img_per_row=2): 24 | gap = 5 25 | columns = len(outputs) + 1 26 | 27 | height, width = inputs[0][:, :, 0].shape 28 | img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row))) 29 | images = [inputs, *outputs] 30 | 31 | for ix in range(len(inputs)): 32 | xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap 33 | yoffset = int(ix / img_per_row) * height 34 | 35 | for cat in range(len(images)): 36 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze() 37 | im = Image.fromarray(im) 38 | img.paste(im, (xoffset + cat * width, yoffset)) 39 | 40 | return img 41 | 42 | 43 | def imshow(img, title=''): 44 | fig = plt.gcf() 45 | fig.canvas.set_window_title(title) 46 | plt.axis('off') 47 | plt.imshow(img, interpolation='none') 48 | plt.show() 49 | 50 | 51 | def imsave(img, path): 52 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze()) 53 | im.save(path) 54 | 55 | 56 | class Progbar(object): 57 | """Displays a progress bar. 58 | 59 | Arguments: 60 | target: Total number of steps expected, None if unknown. 61 | width: Progress bar width on screen. 62 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 63 | stateful_metrics: Iterable of string names of metrics that 64 | should *not* be averaged over time. Metrics in this list 65 | will be displayed as-is. All others will be averaged 66 | by the progbar before display. 67 | interval: Minimum visual progress update interval (in seconds). 68 | """ 69 | 70 | def __init__(self, target, width=25, verbose=1, interval=0.05, 71 | stateful_metrics=None): 72 | self.target = target 73 | self.width = width 74 | self.verbose = verbose 75 | self.interval = interval 76 | if stateful_metrics: 77 | self.stateful_metrics = set(stateful_metrics) 78 | else: 79 | self.stateful_metrics = set() 80 | 81 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 82 | sys.stdout.isatty()) or 83 | 'ipykernel' in sys.modules or 84 | 'posix' in sys.modules) 85 | self._total_width = 0 86 | self._seen_so_far = 0 87 | # We use a dict + list to avoid garbage collection 88 | # issues found in OrderedDict 89 | self._values = {} 90 | self._values_order = [] 91 | self._start = time.time() 92 | self._last_update = 0 93 | 94 | def update(self, current, values=None): 95 | """Updates the progress bar. 96 | 97 | Arguments: 98 | current: Index of current step. 99 | values: List of tuples: 100 | `(name, value_for_last_step)`. 101 | If `name` is in `stateful_metrics`, 102 | `value_for_last_step` will be displayed as-is. 103 | Else, an average of the metric over time will be displayed. 
104 | """ 105 | values = values or [] 106 | for k, v in values: 107 | if k not in self._values_order: 108 | self._values_order.append(k) 109 | if k not in self.stateful_metrics: 110 | if k not in self._values: 111 | self._values[k] = [v * (current - self._seen_so_far), 112 | current - self._seen_so_far] 113 | else: 114 | self._values[k][0] += v * (current - self._seen_so_far) 115 | self._values[k][1] += (current - self._seen_so_far) 116 | else: 117 | self._values[k] = v 118 | self._seen_so_far = current 119 | 120 | now = time.time() 121 | info = ' - %.0fs' % (now - self._start) 122 | if self.verbose == 1: 123 | if (now - self._last_update < self.interval and 124 | self.target is not None and current < self.target): 125 | return 126 | 127 | prev_total_width = self._total_width 128 | if self._dynamic_display: 129 | sys.stdout.write('\b' * prev_total_width) 130 | sys.stdout.write('\r') 131 | else: 132 | sys.stdout.write('\n') 133 | 134 | if self.target is not None: 135 | numdigits = int(np.floor(np.log10(self.target))) + 1 136 | barstr = '%%%dd/%d [' % (numdigits, self.target) 137 | bar = barstr % current 138 | prog = float(current) / self.target 139 | prog_width = int(self.width * prog) 140 | if prog_width > 0: 141 | bar += ('=' * (prog_width - 1)) 142 | if current < self.target: 143 | bar += '>' 144 | else: 145 | bar += '=' 146 | bar += ('.' * (self.width - prog_width)) 147 | bar += ']' 148 | else: 149 | bar = '%7d/Unknown' % current 150 | 151 | self._total_width = len(bar) 152 | sys.stdout.write(bar) 153 | 154 | if current: 155 | time_per_unit = (now - self._start) / current 156 | else: 157 | time_per_unit = 0 158 | if self.target is not None and current < self.target: 159 | eta = time_per_unit * (self.target - current) 160 | if eta > 3600: 161 | eta_format = '%d:%02d:%02d' % (eta // 3600, 162 | (eta % 3600) // 60, 163 | eta % 60) 164 | elif eta > 60: 165 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 166 | else: 167 | eta_format = '%ds' % eta 168 | 169 | info = ' - ETA: %s' % eta_format 170 | else: 171 | if time_per_unit >= 1: 172 | info += ' %.0fs/step' % time_per_unit 173 | elif time_per_unit >= 1e-3: 174 | info += ' %.0fms/step' % (time_per_unit * 1e3) 175 | else: 176 | info += ' %.0fus/step' % (time_per_unit * 1e6) 177 | 178 | for k in self._values_order: 179 | info += ' - %s:' % k 180 | if isinstance(self._values[k], list): 181 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 182 | if abs(avg) > 1e-3: 183 | info += ' %.4f' % avg 184 | else: 185 | info += ' %.4e' % avg 186 | else: 187 | info += ' %s' % self._values[k] 188 | 189 | self._total_width += len(info) 190 | if prev_total_width > self._total_width: 191 | info += (' ' * (prev_total_width - self._total_width)) 192 | 193 | if self.target is not None and current >= self.target: 194 | info += '\n' 195 | 196 | sys.stdout.write(info) 197 | sys.stdout.flush() 198 | 199 | elif self.verbose == 2: 200 | if self.target is None or current >= self.target: 201 | for k in self._values_order: 202 | info += ' - %s:' % k 203 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 204 | if avg > 1e-3: 205 | info += ' %.4f' % avg 206 | else: 207 | info += ' %.4e' % avg 208 | info += '\n' 209 | 210 | sys.stdout.write(info) 211 | sys.stdout.flush() 212 | 213 | self._last_update = now 214 | 215 | def add(self, n, values=None): 216 | self.update(self._seen_so_far + n, values) 217 | -------------------------------------------------------------------------------- /edgeconnect/loss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | class TotalVariationalLoss(nn.Module): 8 | def __init__(self): 9 | super(TotalVariationalLoss, self).__init__() 10 | 11 | def _tensor_size(self, x): 12 | return x.size()[1] * x.size()[2] * x.size()[3] 13 | 14 | def __call__(self, x): 15 | 16 | batch_size = x.size()[0] 17 | h_x = x.size()[2] 18 | w_x = x.size()[3] 19 | count_h = self._tensor_size(x[:, :, 1:, :]) 20 | count_w = self._tensor_size(x[:, :, :, 1:]) 21 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 22 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 23 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_size 24 | 25 | 26 | class AdversarialLoss(nn.Module): 27 | r""" 28 | Adversarial loss 29 | https://arxiv.org/abs/1711.10337 30 | """ 31 | 32 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): 33 | r""" 34 | type = nsgan | lsgan | hinge 35 | """ 36 | super(AdversarialLoss, self).__init__() 37 | 38 | self.type = type 39 | self.register_buffer('real_label', torch.tensor(target_real_label)) 40 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 41 | 42 | if type == 'nsgan': 43 | self.criterion = nn.BCELoss() 44 | 45 | elif type == 'lsgan': 46 | self.criterion = nn.MSELoss() 47 | 48 | elif type == 'hinge': 49 | self.criterion = nn.ReLU() 50 | 51 | def __call__(self, outputs, is_real, is_disc=None): 52 | if self.type == 'hinge': 53 | if is_disc: 54 | if is_real: 55 | outputs = -outputs 56 | return self.criterion(1 + outputs).mean() 57 | else: 58 | return (-outputs).mean() 59 | 60 | else: 61 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs) 62 | loss = self.criterion(outputs, labels) 63 | return loss 64 | 65 | 66 | class StyleLoss(nn.Module): 67 | r""" 68 | Perceptual loss, VGG-based 69 | https://arxiv.org/abs/1603.08155 70 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 71 | """ 72 | 73 | def __init__(self): 74 | super(StyleLoss, self).__init__() 75 | self.add_module('vgg', VGG19()) 76 | self.criterion = torch.nn.L1Loss() 77 | 78 | def compute_gram(self, x): 79 | b, ch, h, w = x.size() 80 | f = x.view(b, ch, w * h) 81 | f_T = f.transpose(1, 2) 82 | G = f.bmm(f_T) / (h * w * ch) 83 | 84 | return G 85 | 86 | def __call__(self, x, y): 87 | # Compute features 88 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 89 | 90 | # Compute loss 91 | style_loss = 0.0 92 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 93 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) 94 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) 95 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 96 | 97 | return style_loss 98 | 99 | 100 | 101 | class PerceptualLoss(nn.Module): 102 | r""" 103 | Perceptual loss, VGG-based 104 | https://arxiv.org/abs/1603.08155 105 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 106 | """ 107 | 108 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 109 | super(PerceptualLoss, self).__init__() 110 | self.add_module('vgg', VGG19()) 111 | self.criterion = torch.nn.L1Loss() 112 | self.weights = weights 113 | 114 | def __call__(self, x, y): 115 | # Compute features 116 | 
x_vgg, y_vgg = self.vgg(x), self.vgg(y) 117 | 118 | content_loss = 0.0 119 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 120 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 121 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 122 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 123 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 124 | 125 | 126 | return content_loss 127 | 128 | 129 | 130 | class VGG19(torch.nn.Module): 131 | def __init__(self): 132 | super(VGG19, self).__init__() 133 | features = models.vgg19(pretrained=True).features 134 | self.relu1_1 = torch.nn.Sequential() 135 | self.relu1_2 = torch.nn.Sequential() 136 | 137 | self.relu2_1 = torch.nn.Sequential() 138 | self.relu2_2 = torch.nn.Sequential() 139 | 140 | self.relu3_1 = torch.nn.Sequential() 141 | self.relu3_2 = torch.nn.Sequential() 142 | self.relu3_3 = torch.nn.Sequential() 143 | self.relu3_4 = torch.nn.Sequential() 144 | 145 | self.relu4_1 = torch.nn.Sequential() 146 | self.relu4_2 = torch.nn.Sequential() 147 | self.relu4_3 = torch.nn.Sequential() 148 | self.relu4_4 = torch.nn.Sequential() 149 | 150 | self.relu5_1 = torch.nn.Sequential() 151 | self.relu5_2 = torch.nn.Sequential() 152 | self.relu5_3 = torch.nn.Sequential() 153 | self.relu5_4 = torch.nn.Sequential() 154 | 155 | for x in range(2): 156 | self.relu1_1.add_module(str(x), features[x]) 157 | 158 | for x in range(2, 4): 159 | self.relu1_2.add_module(str(x), features[x]) 160 | 161 | for x in range(4, 7): 162 | self.relu2_1.add_module(str(x), features[x]) 163 | 164 | for x in range(7, 9): 165 | self.relu2_2.add_module(str(x), features[x]) 166 | 167 | for x in range(9, 12): 168 | self.relu3_1.add_module(str(x), features[x]) 169 | 170 | for x in range(12, 14): 171 | self.relu3_2.add_module(str(x), features[x]) 172 | 173 | for x in range(14, 16): 174 | self.relu3_3.add_module(str(x), features[x]) 175 | 176 | for x in range(16, 18): 177 | self.relu3_4.add_module(str(x), features[x]) 178 | 179 | for x in range(18, 21): 180 | self.relu4_1.add_module(str(x), features[x]) 181 | 182 | for x in range(21, 23): 183 | self.relu4_2.add_module(str(x), features[x]) 184 | 185 | for x in range(23, 25): 186 | self.relu4_3.add_module(str(x), features[x]) 187 | 188 | for x in range(25, 27): 189 | self.relu4_4.add_module(str(x), features[x]) 190 | 191 | for x in range(27, 30): 192 | self.relu5_1.add_module(str(x), features[x]) 193 | 194 | for x in range(30, 32): 195 | self.relu5_2.add_module(str(x), features[x]) 196 | 197 | for x in range(32, 34): 198 | self.relu5_3.add_module(str(x), features[x]) 199 | 200 | for x in range(34, 36): 201 | self.relu5_4.add_module(str(x), features[x]) 202 | 203 | # don't need the gradients, just want the features 204 | for param in self.parameters(): 205 | param.requires_grad = False 206 | 207 | def forward(self, x): 208 | relu1_1 = self.relu1_1(x) 209 | relu1_2 = self.relu1_2(relu1_1) 210 | 211 | relu2_1 = self.relu2_1(relu1_2) 212 | relu2_2 = self.relu2_2(relu2_1) 213 | 214 | relu3_1 = self.relu3_1(relu2_2) 215 | relu3_2 = self.relu3_2(relu3_1) 216 | relu3_3 = self.relu3_3(relu3_2) 217 | relu3_4 = self.relu3_4(relu3_3) 218 | 219 | relu4_1 = self.relu4_1(relu3_4) 220 | relu4_2 = self.relu4_2(relu4_1) 221 | relu4_3 = self.relu4_3(relu4_2) 222 | relu4_4 = self.relu4_4(relu4_3) 223 | 224 | relu5_1 = self.relu5_1(relu4_4) 225 | relu5_2 = self.relu5_2(relu5_1) 226 | 
relu5_3 = self.relu5_3(relu5_2) 227 | relu5_4 = self.relu5_4(relu5_3) 228 | 229 | out = { 230 | 'relu1_1': relu1_1, 231 | 'relu1_2': relu1_2, 232 | 233 | 'relu2_1': relu2_1, 234 | 'relu2_2': relu2_2, 235 | 236 | 'relu3_1': relu3_1, 237 | 'relu3_2': relu3_2, 238 | 'relu3_3': relu3_3, 239 | 'relu3_4': relu3_4, 240 | 241 | 'relu4_1': relu4_1, 242 | 'relu4_2': relu4_2, 243 | 'relu4_3': relu4_3, 244 | 'relu4_4': relu4_4, 245 | 246 | 'relu5_1': relu5_1, 247 | 'relu5_2': relu5_2, 248 | 'relu5_3': relu5_3, 249 | 'relu5_4': relu5_4, 250 | } 251 | return out 252 | -------------------------------------------------------------------------------- /edgeconnect/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseNetwork(nn.Module): 6 | def __init__(self): 7 | super(BaseNetwork, self).__init__() 8 | 9 | def init_weights(self, init_type='normal', gain=0.02): 10 | ''' 11 | initialize network's weights 12 | init_type: normal | xavier | kaiming | orthogonal 13 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 14 | ''' 15 | 16 | def init_func(m): 17 | classname = m.__class__.__name__ 18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 19 | if init_type == 'normal': 20 | nn.init.normal_(m.weight.data, 0.0, gain) 21 | elif init_type == 'xavier': 22 | nn.init.xavier_normal_(m.weight.data, gain=gain) 23 | elif init_type == 'kaiming': 24 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 25 | elif init_type == 'orthogonal': 26 | nn.init.orthogonal_(m.weight.data, gain=gain) 27 | 28 | if hasattr(m, 'bias') and m.bias is not None: 29 | nn.init.constant_(m.bias.data, 0.0) 30 | 31 | elif classname.find('BatchNorm2d') != -1: 32 | nn.init.normal_(m.weight.data, 1.0, gain) 33 | nn.init.constant_(m.bias.data, 0.0) 34 | 35 | self.apply(init_func) 36 | 37 | 38 | class InpaintGenerator(BaseNetwork): 39 | def __init__(self, config, residual_blocks=8, init_weights=True): 40 | super(InpaintGenerator, self).__init__() 41 | self.config = config 42 | if config.FLO == 1: 43 | if config.PASSMASK == 0: 44 | in_channels = 3 45 | elif config.PASSMASK == 1: 46 | in_channels = 4 47 | else: 48 | assert(0) 49 | out_channels = 2 50 | elif config.FLO == 0: 51 | in_channels = 4 52 | out_channels = 3 53 | else: 54 | assert(0) 55 | self.encoder = nn.Sequential( 56 | nn.ReflectionPad2d(3), 57 | nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), 58 | nn.InstanceNorm2d(64, track_running_stats=False), 59 | nn.ReLU(True), 60 | 61 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), 62 | nn.InstanceNorm2d(128, track_running_stats=False), 63 | nn.ReLU(True), 64 | 65 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), 66 | nn.InstanceNorm2d(256, track_running_stats=False), 67 | nn.ReLU(True) 68 | ) 69 | 70 | blocks = [] 71 | for _ in range(residual_blocks): 72 | block = ResnetBlock(256, 2) 73 | blocks.append(block) 74 | 75 | self.middle = nn.Sequential(*blocks) 76 | 77 | self.decoder = nn.Sequential( 78 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), 79 | nn.InstanceNorm2d(128, track_running_stats=False), 80 | nn.ReLU(True), 81 | 82 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), 83 | nn.InstanceNorm2d(64, track_running_stats=False), 84 | 
nn.ReLU(True), 85 | 86 | nn.ReflectionPad2d(3), 87 | nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=7, padding=0), 88 | ) 89 | 90 | if init_weights: 91 | self.init_weights() 92 | 93 | def forward(self, input): 94 | x = self.encoder(input) 95 | x = self.middle(x) 96 | x = self.decoder(x) 97 | 98 | if self.config.FLO == 0: 99 | x = (torch.tanh(x) + 1) / 2 100 | elif self.config.FLO == 1 and self.config.NORM == 1: 101 | if self.config.RESIDUAL == 1: 102 | assert(self.config.FILL == 1) 103 | x = torch.tanh(x + input[:, :2, :, :]) 104 | elif self.config.RESIDUAL == 0: 105 | x = torch.tanh(x) 106 | else: 107 | assert(0) 108 | return x 109 | 110 | 111 | class EdgeGenerator_(BaseNetwork): 112 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True): 113 | super(EdgeGenerator_, self).__init__() 114 | 115 | self.encoder = nn.Sequential( 116 | nn.ReflectionPad2d(3), 117 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm), 118 | nn.InstanceNorm2d(64, track_running_stats=False), 119 | nn.ReLU(True), 120 | 121 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 122 | nn.InstanceNorm2d(128, track_running_stats=False), 123 | nn.ReLU(True), 124 | 125 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm), 126 | nn.InstanceNorm2d(256, track_running_stats=False), 127 | nn.ReLU(True) 128 | ) 129 | 130 | blocks = [] 131 | for _ in range(residual_blocks): 132 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm) 133 | blocks.append(block) 134 | 135 | self.middle = nn.Sequential(*blocks) 136 | 137 | self.decoder = nn.Sequential( 138 | spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm), 139 | nn.InstanceNorm2d(128, track_running_stats=False), 140 | nn.ReLU(True), 141 | 142 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm), 143 | nn.InstanceNorm2d(64, track_running_stats=False), 144 | nn.ReLU(True), 145 | 146 | nn.ReflectionPad2d(3), 147 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0), 148 | ) 149 | 150 | if init_weights: 151 | self.init_weights() 152 | 153 | def forward(self, x): 154 | x = self.encoder(x) 155 | x = self.middle(x) 156 | x = self.decoder(x) 157 | x = torch.sigmoid(x) 158 | return x 159 | 160 | 161 | class Discriminator(BaseNetwork): 162 | def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True): 163 | super(Discriminator, self).__init__() 164 | self.use_sigmoid = use_sigmoid 165 | 166 | self.conv1 = self.features = nn.Sequential( 167 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 168 | nn.LeakyReLU(0.2, inplace=True), 169 | ) 170 | 171 | self.conv2 = nn.Sequential( 172 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 173 | nn.LeakyReLU(0.2, inplace=True), 174 | ) 175 | 176 | self.conv3 = nn.Sequential( 177 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 178 | nn.LeakyReLU(0.2, inplace=True), 179 | ) 180 | 181 | self.conv4 = nn.Sequential( 182 | 
spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 183 | nn.LeakyReLU(0.2, inplace=True), 184 | ) 185 | 186 | self.conv5 = nn.Sequential( 187 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 188 | ) 189 | 190 | if init_weights: 191 | self.init_weights() 192 | 193 | def forward(self, x): 194 | conv1 = self.conv1(x) 195 | conv2 = self.conv2(conv1) 196 | conv3 = self.conv3(conv2) 197 | conv4 = self.conv4(conv3) 198 | conv5 = self.conv5(conv4) 199 | 200 | outputs = conv5 201 | if self.use_sigmoid: 202 | outputs = torch.sigmoid(conv5) 203 | 204 | return outputs, [conv1, conv2, conv3, conv4, conv5] 205 | 206 | 207 | class ResnetBlock(nn.Module): 208 | def __init__(self, dim, dilation=1, use_spectral_norm=False): 209 | super(ResnetBlock, self).__init__() 210 | self.conv_block = nn.Sequential( 211 | nn.ReflectionPad2d(dilation), 212 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), 213 | nn.InstanceNorm2d(dim, track_running_stats=False), 214 | nn.ReLU(True), 215 | 216 | nn.ReflectionPad2d(1), 217 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), 218 | nn.InstanceNorm2d(dim, track_running_stats=False), 219 | ) 220 | 221 | def forward(self, x): 222 | out = x + self.conv_block(x) 223 | 224 | # Remove ReLU at the end of the residual block 225 | # http://torch.ch/blog/2016/02/04/resnets.html 226 | 227 | return out 228 | 229 | 230 | def spectral_norm(module, mode=True): 231 | if mode: 232 | return nn.utils.spectral_norm(module) 233 | 234 | return module 235 | -------------------------------------------------------------------------------- /RAFT/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, 
x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if 
m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /RAFT/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | 
cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], 
x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, 
interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /RAFT/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = 
frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ 
flows[i] ] 156 | elif direction == 'into_past': 157 | self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding training set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /utils/Poisson_blend_img.py:
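Before the implementation, a brief usage sketch may help orient the reader: Poisson_blend_img fills the masked region of a target frame so that its gradients match the supplied source gradients, assembling and solving one sparse least-squares system per colour channel. The sketch below is illustrative only; the synthetic inputs and their shapes are inferred from the function signature, and it assumes the repository root is on PYTHONPATH with the listed requirements installed.

# Hedged usage sketch (not part of the original file): blend a target frame
# against externally supplied gradients inside a square hole.
import numpy as np
from utils.Poisson_blend_img import Poisson_blend_img

H, W = 64, 64
imgTrg = np.random.rand(H, W, 3).astype(np.float32)   # target frame, float values
gx = np.zeros((H, W, 3), dtype=np.float32)            # desired x-gradients inside the hole
gy = np.zeros((H, W, 3), dtype=np.float32)            # desired y-gradients inside the hole
holeMask = np.zeros((H, W), dtype=np.float32)
holeMask[20:40, 20:40] = 1                            # region to be reconstructed

imgBlend, UnfilledMask = Poisson_blend_img(imgTrg, gx, gy, holeMask)
print(imgBlend.shape, UnfilledMask.sum())             # hole pixels reachable from known pixels are filled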
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import scipy.ndimage 4 | from scipy.sparse.linalg import spsolve 5 | from scipy import sparse 6 | import scipy.io as sio 7 | import numpy as np 8 | from PIL import Image 9 | import copy 10 | import cv2 11 | import os 12 | import argparse 13 | 14 | 15 | def sub2ind(pi, pj, imgH, imgW): 16 | return pj + pi * imgW 17 | 18 | 19 | def Poisson_blend_img(imgTrg, imgSrc_gx, imgSrc_gy, holeMask, gradientMask=None, edge=None): 20 | 21 | imgH, imgW, nCh = imgTrg.shape 22 | 23 | if not isinstance(gradientMask, np.ndarray): 24 | gradientMask = np.zeros((imgH, imgW), dtype=np.float32) 25 | 26 | if not isinstance(edge, np.ndarray): 27 | edge = np.zeros((imgH, imgW), dtype=np.float32) 28 | 29 | # Initialize the reconstructed image 30 | imgRecon = np.zeros((imgH, imgW, nCh), dtype=np.float32) 31 | 32 | # prepare discrete Poisson equation 33 | A, b, UnfilledMask = solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, 34 | gradientMask, edge) 35 | 36 | # Independently process each channel 37 | for ch in range(nCh): 38 | 39 | # solve Poisson equation 40 | x = scipy.sparse.linalg.lsqr(A, b[:, ch])[0] 41 | 42 | imgRecon[:, :, ch] = x.reshape(imgH, imgW) 43 | 44 | # Combined with the known region in the target 45 | holeMaskC = np.tile(np.expand_dims(holeMask, axis=2), (1, 1, nCh)) 46 | imgBlend = holeMaskC * imgRecon + (1 - holeMaskC) * imgTrg 47 | 48 | 49 | # while((UnfilledMask * edge).sum() != 0): 50 | # # Fill in edge pixel 51 | # pi = np.expand_dims(np.where((UnfilledMask * edge) == 1)[0], axis=1) # y, i 52 | # pj = np.expand_dims(np.where((UnfilledMask * edge) == 1)[1], axis=1) # x, j 53 | # 54 | # for k in range(len(pi)): 55 | # if pi[k, 0] - 1 >= 0: 56 | # if (UnfilledMask * edge)[pi[k, 0] - 1, pj[k, 0]] == 0: 57 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] - 1, pj[k, 0], :] 58 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 59 | # continue 60 | # if pi[k, 0] + 1 <= imgH - 1: 61 | # if (UnfilledMask * edge)[pi[k, 0] + 1, pj[k, 0]] == 0: 62 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0] + 1, pj[k, 0], :] 63 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 64 | # continue 65 | # if pj[k, 0] - 1 >= 0: 66 | # if (UnfilledMask * edge)[pi[k, 0], pj[k, 0] - 1] == 0: 67 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] - 1, :] 68 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 69 | # continue 70 | # if pj[k, 0] + 1 <= imgW - 1: 71 | # if (UnfilledMask * edge)[pi[k, 0], pj[k, 0] + 1] == 0: 72 | # imgBlend[pi[k, 0], pj[k, 0], :] = imgBlend[pi[k, 0], pj[k, 0] + 1, :] 73 | # UnfilledMask[pi[k, 0], pj[k, 0]] = 0 74 | 75 | return imgBlend, UnfilledMask 76 | 77 | def solvePoisson(holeMask, imgSrc_gx, imgSrc_gy, imgTrg, 78 | gradientMask, edge): 79 | 80 | # UnfilledMask indicates the region that is not completed 81 | UnfilledMask_topleft = copy.deepcopy(holeMask) 82 | UnfilledMask_bottomright = copy.deepcopy(holeMask) 83 | 84 | # Prepare the linear system of equations for Poisson blending 85 | imgH, imgW = holeMask.shape 86 | N = imgH * imgW 87 | 88 | # Number of unknown variables 89 | numUnknownPix = holeMask.sum() 90 | 91 | # 4-neighbors: dx and dy 92 | dx = [1, 0, -1, 0] 93 | dy = [0, 1, 0, -1] 94 | 95 | # 3 96 | # | 97 | # 2 -- * -- 0 98 | # | 99 | # 1 100 | # 101 | 102 | # Initialize (I, J, S), for sparse matrix A where A(I(k), J(k)) = S(k) 103 | I = np.empty((0, 1), dtype=np.float32) 104 | J = np.empty((0, 1), dtype=np.float32) 
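# I, J and S accumulate the linear system in coordinate (COO) form: I holds equation
# (row) indices, J holds flattened pixel (column) indices produced by sub2ind, and S
# holds the coefficients. For every hole pixel p and each of its four neighbors q,
# constructEquation appends either
#     x[p] - x[q] = +/- g                (q is also inside the hole), or
#     x[p]        = imgTrg[q] +/- g      (q is known, so its value moves to the right-hand side),
# where g is the corresponding source gradient. The triplets are assembled below with
# sparse.csr_matrix((S, (I, J)), shape=(nEqn, N)), and the resulting over-determined
# system is solved channel by channel with lsqr in Poisson_blend_img.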
105 | S = np.empty((0, 1), dtype=np.float32) 106 | 107 | # Initialize b 108 | b = np.empty((0, 3), dtype=np.float32) 109 | 110 | # Precompute unknown pixel position 111 | pi = np.expand_dims(np.where(holeMask == 1)[0], axis=1) # y, i 112 | pj = np.expand_dims(np.where(holeMask == 1)[1], axis=1) # x, j 113 | pind = sub2ind(pi, pj, imgH, imgW) 114 | 115 | # |--------------------| 116 | # | y (i) | 117 | # | x (j) * | 118 | # | | 119 | # |--------------------| 120 | # p[y, x] 121 | 122 | qi = np.concatenate((pi + dy[0], 123 | pi + dy[1], 124 | pi + dy[2], 125 | pi + dy[3]), axis=1) 126 | 127 | qj = np.concatenate((pj + dx[0], 128 | pj + dx[1], 129 | pj + dx[2], 130 | pj + dx[3]), axis=1) 131 | 132 | # Handling cases at image borders 133 | validN = (qi >= 0) & (qi <= imgH - 1) & (qj >= 0) & (qj <= imgW - 1) 134 | qind = np.zeros((validN.shape), dtype=np.float32) 135 | qind[validN] = sub2ind(qi[validN], qj[validN], imgH, imgW) 136 | 137 | e_start = 0 # equation counter start 138 | e_stop = 0 # equation counter stop 139 | 140 | # 4 neighbors 141 | I, J, S, b, e_start, e_stop = constructEquation(0, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 142 | I, J, S, b, e_start, e_stop = constructEquation(1, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 143 | I, J, S, b, e_start, e_stop = constructEquation(2, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 144 | I, J, S, b, e_start, e_stop = constructEquation(3, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop) 145 | 146 | nEqn = len(b) 147 | # Construct the sparse matrix A 148 | A = sparse.csr_matrix((S[:, 0], (I[:, 0], J[:, 0])), shape=(nEqn, N)) 149 | 150 | # Check connected pixels 151 | for ind in range(0, len(pi), 1): 152 | ii = pi[ind, 0] 153 | jj = pj[ind, 0] 154 | 155 | # check up (3) 156 | if ii - 1 >= 0: 157 | if UnfilledMask_topleft[ii - 1, jj] == 0 and gradientMask[ii - 1, jj] == 0: 158 | UnfilledMask_topleft[ii, jj] = 0 159 | # check left (2) 160 | if jj - 1 >= 0: 161 | if UnfilledMask_topleft[ii, jj - 1] == 0 and gradientMask[ii, jj - 1] == 0: 162 | UnfilledMask_topleft[ii, jj] = 0 163 | 164 | 165 | for ind in range(len(pi) - 1, -1, -1): 166 | ii = pi[ind, 0] 167 | jj = pj[ind, 0] 168 | 169 | # check bottom (1) 170 | if ii + 1 <= imgH - 1: 171 | if UnfilledMask_bottomright[ii + 1, jj] == 0 and gradientMask[ii, jj] == 0: 172 | UnfilledMask_bottomright[ii, jj] = 0 173 | # check right (0) 174 | if jj + 1 <= imgW - 1: 175 | if UnfilledMask_bottomright[ii, jj + 1] == 0 and gradientMask[ii, jj] == 0: 176 | UnfilledMask_bottomright[ii, jj] = 0 177 | 178 | UnfilledMask = UnfilledMask_topleft * UnfilledMask_bottomright 179 | 180 | return A, b, UnfilledMask 181 | 182 | 183 | def constructEquation(n, validN, holeMask, gradientMask, edge, imgSrc_gx, imgSrc_gy, imgTrg, pi, pj, pind, qi, qj, qind, I, J, S, b, e_start, e_stop): 184 | 185 | # Pixels that have a valid neighbor in direction n 186 | validNeighbor = validN[:, n] 187 | 188 | # Change the out-of-boundary value to 0, in order to run edge[y,x] 189 | # in the next line.
It won't affect anything as validNeighbor is saved already 190 | 191 | qi_tmp = copy.deepcopy(qi) 192 | qj_tmp = copy.deepcopy(qj) 193 | qi_tmp[np.invert(validNeighbor), n] = 0 194 | qj_tmp[np.invert(validNeighbor), n] = 0 195 | 196 | NotEdge = (edge[pi[:, 0], pj[:, 0]] == 0) * (edge[qi_tmp[:, n], qj_tmp[:, n]] == 0) 197 | 198 | # Have gradient 199 | if n == 0: 200 | HaveGrad = gradientMask[pi[:, 0], pj[:, 0]] == 0 201 | elif n == 2: 202 | HaveGrad = gradientMask[pi[:, 0], pj[:, 0] - 1] == 0 203 | elif n == 1: 204 | HaveGrad = gradientMask[pi[:, 0], pj[:, 0]] == 0 205 | elif n == 3: 206 | HaveGrad = gradientMask[pi[:, 0] - 1, pj[:, 0]] == 0 207 | 208 | # Boundary constraint 209 | Boundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 0 210 | 211 | valid = validNeighbor * NotEdge * HaveGrad * Boundary 212 | 213 | J_tmp = pind[valid, :] 214 | 215 | # num of equations: len(J_tmp) 216 | e_stop = e_start + len(J_tmp) 217 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) 218 | e_start = e_stop 219 | 220 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32) 221 | 222 | if n == 0: 223 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] 224 | elif n == 2: 225 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] + imgTrg[qi[valid, n], qj[valid, n], :] 226 | elif n == 1: 227 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] 228 | elif n == 3: 229 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] + imgTrg[qi[valid, n], qj[valid, n], :] 230 | 231 | I = np.concatenate((I, I_tmp)) 232 | J = np.concatenate((J, J_tmp)) 233 | S = np.concatenate((S, S_tmp)) 234 | b = np.concatenate((b, b_tmp)) 235 | 236 | # Non-boundary constraint 237 | NonBoundary = holeMask[qi_tmp[:, n], qj_tmp[:, n]] == 1 238 | valid = validNeighbor * NotEdge * HaveGrad * NonBoundary 239 | 240 | J_tmp = pind[valid, :] 241 | 242 | # num of equations: len(J_tmp) 243 | e_stop = e_start + len(J_tmp) 244 | I_tmp = np.arange(e_start, e_stop, dtype=np.float32).reshape(-1, 1) 245 | e_start = e_stop 246 | 247 | S_tmp = np.ones(J_tmp.shape, dtype=np.float32) 248 | 249 | if n == 0: 250 | b_tmp = - imgSrc_gx[pi[valid, 0], pj[valid, 0], :] 251 | elif n == 2: 252 | b_tmp = imgSrc_gx[pi[valid, 0], pj[valid, 0] - 1, :] 253 | elif n == 1: 254 | b_tmp = - imgSrc_gy[pi[valid, 0], pj[valid, 0], :] 255 | elif n == 3: 256 | b_tmp = imgSrc_gy[pi[valid, 0] - 1, pj[valid, 0], :] 257 | 258 | I = np.concatenate((I, I_tmp)) 259 | J = np.concatenate((J, J_tmp)) 260 | S = np.concatenate((S, S_tmp)) 261 | b = np.concatenate((b, b_tmp)) 262 | 263 | S_tmp = - np.ones(J_tmp.shape, dtype=np.float32) 264 | J_tmp = qind[valid, n, None] 265 | 266 | I = np.concatenate((I, I_tmp)) 267 | J = np.concatenate((J, J_tmp)) 268 | S = np.concatenate((S, S_tmp)) 269 | 270 | return I, J, S, b, e_start, e_stop 271 | -------------------------------------------------------------------------------- /edgeconnect/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from .networks import InpaintGenerator, EdgeGenerator, Discriminator 6 | from .loss import AdversarialLoss, PerceptualLoss, StyleLoss, TotalVariationalLoss 7 | 8 | 9 | class BaseModel(nn.Module): 10 | def __init__(self, name, config): 11 | super(BaseModel, self).__init__() 12 | 13 | self.name = name 14 | self.config = config 15 | self.iteration = 0 16 | 17 | self.gen_weights_path = 
os.path.join(config.PATH, name + '_gen.pth') 18 | self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth') 19 | 20 | def load(self): 21 | if os.path.exists(self.gen_weights_path): 22 | print('Loading %s generator...' % self.name) 23 | 24 | if torch.cuda.is_available(): 25 | data = torch.load(self.gen_weights_path) 26 | else: 27 | data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage) 28 | 29 | self.generator.load_state_dict(data['generator']) 30 | self.iteration = data['iteration'] 31 | 32 | # load discriminator only when training 33 | if self.config.MODE == 1 and os.path.exists(self.dis_weights_path): 34 | print('Loading %s discriminator...' % self.name) 35 | 36 | if torch.cuda.is_available(): 37 | data = torch.load(self.dis_weights_path) 38 | else: 39 | data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage) 40 | 41 | self.discriminator.load_state_dict(data['discriminator']) 42 | 43 | def save(self): 44 | print('\nsaving %s...\n' % self.name) 45 | torch.save({ 46 | 'iteration': self.iteration, 47 | 'generator': self.generator.state_dict() 48 | }, self.gen_weights_path) 49 | 50 | torch.save({ 51 | 'discriminator': self.discriminator.state_dict() 52 | }, self.dis_weights_path) 53 | 54 | 55 | class EdgeModel(BaseModel): 56 | def __init__(self, config): 57 | super(EdgeModel, self).__init__('EdgeModel', config) 58 | 59 | # generator input: [grayscale(1) + edge(1) + mask(1)] 60 | # discriminator input: (grayscale(1) + edge(1)) 61 | generator = EdgeGenerator(use_spectral_norm=True) 62 | discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge') 63 | if len(config.GPU) > 1: 64 | generator = nn.DataParallel(generator, config.GPU) 65 | discriminator = nn.DataParallel(discriminator, config.GPU) 66 | l1_loss = nn.L1Loss() 67 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) 68 | 69 | self.add_module('generator', generator) 70 | self.add_module('discriminator', discriminator) 71 | 72 | self.add_module('l1_loss', l1_loss) 73 | self.add_module('adversarial_loss', adversarial_loss) 74 | 75 | self.gen_optimizer = optim.Adam( 76 | params=generator.parameters(), 77 | lr=float(config.LR), 78 | betas=(config.BETA1, config.BETA2) 79 | ) 80 | 81 | self.dis_optimizer = optim.Adam( 82 | params=discriminator.parameters(), 83 | lr=float(config.LR) * float(config.D2G_LR), 84 | betas=(config.BETA1, config.BETA2) 85 | ) 86 | 87 | def process(self, images, edges, masks): 88 | self.iteration += 1 89 | 90 | 91 | # zero optimizers 92 | self.gen_optimizer.zero_grad() 93 | self.dis_optimizer.zero_grad() 94 | 95 | 96 | # process outputs 97 | outputs = self(images, edges, masks) 98 | gen_loss = 0 99 | dis_loss = 0 100 | 101 | 102 | # discriminator loss 103 | dis_input_real = torch.cat((images, edges), dim=1) 104 | dis_input_fake = torch.cat((images, outputs.detach()), dim=1) 105 | dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1)) 106 | dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1)) 107 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 108 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 109 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 110 | 111 | 112 | # generator adversarial loss 113 | gen_input_fake = torch.cat((images, outputs), dim=1) 114 | gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1)) 115 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) 116 | 
gen_loss += gen_gan_loss 117 | 118 | 119 | # generator feature matching loss 120 | gen_fm_loss = 0 121 | for i in range(len(dis_real_feat)): 122 | gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach()) 123 | gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT 124 | gen_loss += gen_fm_loss 125 | 126 | 127 | # create logs 128 | logs = [ 129 | ("l_d1", dis_loss.item()), 130 | ("l_g1", gen_gan_loss.item()), 131 | ("l_fm", gen_fm_loss.item()), 132 | ] 133 | 134 | return outputs, gen_loss, dis_loss, logs 135 | 136 | def forward(self, images, edges, masks): 137 | edges_masked = (edges * (1 - masks)) 138 | images_masked = (images * (1 - masks)) + masks 139 | inputs = torch.cat((images_masked, edges_masked, masks), dim=1) 140 | outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)] 141 | return outputs 142 | 143 | def backward(self, gen_loss=None, dis_loss=None): 144 | if dis_loss is not None: 145 | dis_loss.backward() 146 | self.dis_optimizer.step() 147 | 148 | if gen_loss is not None: 149 | gen_loss.backward() 150 | self.gen_optimizer.step() 151 | 152 | 153 | class InpaintingModel(BaseModel): 154 | def __init__(self, config): 155 | super(InpaintingModel, self).__init__('InpaintingModel', config) 156 | 157 | # generator input: [rgb(3) + edge(1)] 158 | # discriminator input: [rgb(3)] 159 | generator = InpaintGenerator(config) 160 | self.config = config 161 | if config.FLO == 1: 162 | in_channels = 2 163 | elif config.FLO == 0: 164 | in_channels = 3 165 | else: 166 | assert(0) 167 | discriminator = Discriminator(in_channels=in_channels, use_sigmoid=config.GAN_LOSS != 'hinge') 168 | if len(config.GPU) > 1: 169 | generator = nn.DataParallel(generator, config.GPU) 170 | discriminator = nn.DataParallel(discriminator , config.GPU) 171 | 172 | l1_loss = nn.L1Loss() 173 | tv_loss = TotalVariationalLoss() 174 | perceptual_loss = PerceptualLoss() 175 | style_loss = StyleLoss() 176 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) 177 | 178 | self.add_module('generator', generator) 179 | self.add_module('discriminator', discriminator) 180 | 181 | self.add_module('l1_loss', l1_loss) 182 | self.add_module('tv_loss', tv_loss) 183 | self.add_module('perceptual_loss', perceptual_loss) 184 | self.add_module('style_loss', style_loss) 185 | self.add_module('adversarial_loss', adversarial_loss) 186 | 187 | self.gen_optimizer = optim.Adam( 188 | params=generator.parameters(), 189 | lr=float(config.LR), 190 | betas=(config.BETA1, config.BETA2) 191 | ) 192 | 193 | self.dis_optimizer = optim.Adam( 194 | params=discriminator.parameters(), 195 | lr=float(config.LR) * float(config.D2G_LR), 196 | betas=(config.BETA1, config.BETA2) 197 | ) 198 | 199 | def process(self, images, images_filled, edges, masks): 200 | self.iteration += 1 201 | 202 | # zero optimizers 203 | self.gen_optimizer.zero_grad() 204 | self.dis_optimizer.zero_grad() 205 | 206 | # process outputs 207 | outputs = self(images, images_filled, edges, masks) 208 | 209 | gen_loss = 0 210 | dis_loss = 0 211 | gen_gan_loss = 0 212 | 213 | if self.config.GAN == 1: 214 | # discriminator loss 215 | dis_input_real = images 216 | dis_input_fake = outputs.detach() 217 | dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)] 218 | dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)] 219 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 220 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 221 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 222 | 223 | 224 | # generator adversarial 
loss 225 | gen_input_fake = outputs 226 | gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)] 227 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT 228 | gen_loss += gen_gan_loss 229 | 230 | 231 | # generator l1 loss 232 | gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks) 233 | gen_loss += gen_l1_loss 234 | 235 | if self.config.ENFORCE == 1: 236 | gen_l1_masked_loss = self.l1_loss(outputs * masks, images * masks) * 10 * self.config.L1_LOSS_WEIGHT 237 | gen_loss += gen_l1_masked_loss 238 | elif self.config.ENFORCE != 0: 239 | assert(0) 240 | 241 | if self.config.TV == 1: 242 | # generator tv loss 243 | gen_tv_loss = self.tv_loss(outputs) * self.config.TV_LOSS_WEIGHT 244 | gen_loss += gen_tv_loss 245 | 246 | if self.config.FLO != 1: 247 | # generator perceptual loss 248 | gen_content_loss = self.perceptual_loss(outputs, images) 249 | gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT 250 | gen_loss += gen_content_loss 251 | 252 | # generator style loss 253 | gen_style_loss = self.style_loss(outputs * masks, images * masks) 254 | gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT 255 | gen_loss += gen_style_loss 256 | 257 | # create logs 258 | logs = [ 259 | ("l_d2", dis_loss.item()), 260 | ("l_g2", gen_gan_loss.item()), 261 | ("l_l1", gen_l1_loss.item()), 262 | ("l_per", gen_content_loss.item()), 263 | ("l_sty", gen_style_loss.item()), 264 | ] 265 | else: 266 | logs = [] 267 | logs.append(("l_l1", gen_l1_loss.item())) 268 | logs.append(("l_gen", gen_loss.item())) 269 | 270 | if self.config.GAN == 1: 271 | logs.append(("l_d2", dis_loss.item())) 272 | logs.append(("l_g2", gen_gan_loss.item())) 273 | 274 | if self.config.TV == 1: 275 | logs.append(("l_tv", gen_tv_loss.item())) 276 | 277 | if self.config.ENFORCE == 1: 278 | logs.append(("l_masked_l1", gen_l1_masked_loss.item())) 279 | 280 | return outputs, gen_loss, dis_loss, logs 281 | 282 | def forward(self, images, images_filled, edges, masks): 283 | 284 | if self.config.FILL == 1: 285 | images_masked = images_filled 286 | elif self.config.FILL == 0: 287 | images_masked = (images * (1 - masks).float()) # + masks 288 | else: 289 | assert(0) 290 | 291 | if self.config.PASSMASK == 1: 292 | inputs = torch.cat((images_masked, edges, masks), dim=1) 293 | elif self.config.PASSMASK == 0: 294 | inputs = torch.cat((images_masked, edges), dim=1) 295 | else: 296 | assert(0) 297 | 298 | outputs = self.generator(inputs) 299 | # if self.config.RESIDUAL == 1: 300 | # assert(self.config.PASSMASK == 1) 301 | # outputs = self.generator(inputs) + images_filled 302 | # elif self.config.RESIDUAL == 0: 303 | # outputs = self.generator(inputs) 304 | # else: 305 | # assert(0) 306 | 307 | return outputs 308 | 309 | def backward(self, gen_loss=None, dis_loss=None): 310 | 311 | if self.config.GAN == 1: 312 | dis_loss.backward() 313 | self.dis_optimizer.step() 314 | 315 | gen_loss.backward() 316 | self.gen_optimizer.step() 317 | -------------------------------------------------------------------------------- /edgeconnect/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import scipy 5 | import torch 6 | import random 7 | import numpy as np 8 | import torchvision.transforms.functional as F 9 | from torch.utils.data import DataLoader 10 | from PIL import Image 11 | from scipy.misc import imread 12 | from skimage.feature import canny 13 | from 
skimage.color import rgb2gray, gray2rgb 14 | from .utils import create_mask 15 | import src.region_fill as rf 16 | 17 | class Dataset(torch.utils.data.Dataset): 18 | def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True): 19 | super(Dataset, self).__init__() 20 | 21 | self.augment = augment 22 | self.training = training 23 | self.flo = config.FLO 24 | self.norm = config.NORM 25 | self.data = self.load_flist(flist, self.flo) 26 | self.edge_data = self.load_flist(edge_flist, 0) 27 | self.mask_data = self.load_flist(mask_flist, 0) 28 | 29 | self.input_size = config.INPUT_SIZE 30 | self.sigma = config.SIGMA 31 | self.edge = config.EDGE 32 | self.mask = config.MASK 33 | self.nms = config.NMS 34 | 35 | 36 | 37 | # in test mode, there's a one-to-one relationship between mask and image 38 | # masks are loaded non random 39 | if config.MODE == 2: 40 | self.mask = 6 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def __getitem__(self, index): 46 | try: 47 | item = self.load_item(index) 48 | except: 49 | print('loading error: ' + self.data[index]) 50 | item = self.load_item(0) 51 | 52 | return item 53 | 54 | def load_name(self, index): 55 | name = self.data[index] 56 | return os.path.basename(name) 57 | 58 | def load_item(self, index): 59 | size = self.input_size 60 | factor = 1. 61 | if self.flo == 0: 62 | 63 | # load image 64 | img = imread(self.data[index]) 65 | 66 | # gray to rgb 67 | if len(img.shape) < 3: 68 | img = gray2rgb(img) 69 | 70 | # resize/crop if needed 71 | if size != 0: 72 | img = self.resize(img, size[0], size[1]) 73 | 74 | # create grayscale image 75 | img_gray = rgb2gray(img) 76 | 77 | # load mask 78 | mask = self.load_mask(img, index) 79 | 80 | edge = self.load_edge(img_gray, index, mask) 81 | 82 | img_filled = img 83 | 84 | else: 85 | 86 | img = self.readFlow(self.data[index]) 87 | 88 | # resize/crop if needed 89 | if size != 0: 90 | img = self.flow_tf(img, [size[0], size[1]]) 91 | 92 | img_gray = (img[:, :, 0] ** 2 + img[:, :, 1] ** 2) ** 0.5 93 | 94 | if self.norm == 1: 95 | # normalization 96 | # factor = (np.abs(img[:, :, 0]).max() ** 2 + np.abs(img[:, :, 1]).max() ** 2) ** 0.5 97 | factor = img_gray.max() 98 | img /= factor 99 | 100 | # load mask 101 | mask = self.load_mask(img, index) 102 | 103 | edge = self.load_edge(img_gray, index, mask) 104 | img_gray = img_gray / img_gray.max() 105 | 106 | img_filled = np.zeros(img.shape) 107 | img_filled[:, :, 0] = rf.regionfill(img[:, :, 0], mask) 108 | img_filled[:, :, 1] = rf.regionfill(img[:, :, 1], mask) 109 | 110 | 111 | # augment data 112 | if self.augment and np.random.binomial(1, 0.5) > 0: 113 | img = img[:, ::-1, ...].copy() 114 | img_filled = img_filled[:, ::-1, ...].copy() 115 | img_gray = img_gray[:, ::-1, ...] 116 | edge = edge[:, ::-1, ...] 117 | mask = mask[:, ::-1, ...] 
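# Note on the flip above: img, img_filled, img_gray, edge and mask are all mirrored
# together; unlike RAFT's FlowAugmentor, the horizontal (u) component of a flipped
# flow field is not negated here. The item returned next is
# (img, img_filled, img_gray, edge, mask, factor), where factor is the normalization
# constant (the maximum flow magnitude when FLO and NORM are enabled, 1.0 otherwise),
# so a caller can rescale the network output back to the original flow range if needed.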
118 | 119 | return self.to_tensor(img), self.to_tensor(img_filled), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask), factor 120 | 121 | def load_edge(self, img, index, mask): 122 | sigma = self.sigma 123 | 124 | # in test mode images are masked (with masked regions), 125 | # using 'mask' parameter prevents canny to detect edges for the masked regions 126 | mask = None if self.training else (1 - mask / 255).astype(np.bool) 127 | 128 | # canny 129 | if self.edge == 1: 130 | # no edge 131 | if sigma == -1: 132 | return np.zeros(img.shape).astype(np.float) 133 | 134 | # random sigma 135 | if sigma == 0: 136 | sigma = random.randint(1, 4) 137 | return canny(img, sigma=sigma, mask=mask).astype(np.float) 138 | 139 | # external 140 | else: 141 | imgh, imgw = img.shape[0:2] 142 | edge = imread(self.edge_data[index]) 143 | edge = self.resize(edge, imgh, imgw) 144 | 145 | # non-max suppression 146 | if self.nms == 1: 147 | edge = edge * canny(img, sigma=sigma, mask=mask) 148 | 149 | return edge 150 | 151 | def load_mask(self, img, index): 152 | imgh, imgw = img.shape[0:2] 153 | mask_type = self.mask 154 | 155 | # external + random block 156 | if mask_type == 4: 157 | mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3 158 | 159 | # external + random block + half 160 | elif mask_type == 5: 161 | mask_type = np.random.randint(1, 4) 162 | 163 | # random block 164 | if mask_type == 1: 165 | return create_mask(imgw, imgh, imgw // 2, imgh // 2) 166 | 167 | # half 168 | if mask_type == 2: 169 | # randomly choose right or left 170 | return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0) 171 | 172 | # external 173 | if mask_type == 3: 174 | mask_index = random.randint(0, len(self.mask_data) - 1) 175 | mask = imread(self.mask_data[mask_index]) 176 | mask = self.resize(mask, imgh, imgw, centerCrop=False) 177 | mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation 178 | return mask 179 | 180 | # test mode: load mask non random 181 | 182 | if mask_type == 6: 183 | mask = imread(self.mask_data[index]) 184 | mask = self.resize(mask, imgh, imgw, centerCrop=False) 185 | mask = rgb2gray(mask) 186 | mask = (mask > 0).astype(np.uint8) * 255 187 | return mask 188 | 189 | def to_tensor(self, img): 190 | if (len(img.shape) == 3 and img.shape[2] == 2): 191 | return F.to_tensor(img).float() 192 | img = Image.fromarray(img) 193 | img_t = F.to_tensor(img).float() 194 | return img_t 195 | 196 | def resize(self, img, height, width, centerCrop=True): 197 | imgh, imgw = img.shape[0:2] 198 | 199 | if centerCrop and imgh != imgw: 200 | # center crop 201 | side = np.minimum(imgh, imgw) 202 | j = (imgh - side) // 2 203 | i = (imgw - side) // 2 204 | img = img[j:j + side, i:i + side, ...] 
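# scipy.misc.imresize (used just below) was deprecated in SciPy 1.0 and removed in
# SciPy 1.3, so this code path only runs on older environments. A roughly equivalent
# replacement would be cv2.resize(img, (width, height)) or
# skimage.transform.resize(img, (height, width), preserve_range=True).astype(img.dtype).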
205 | 206 | img = scipy.misc.imresize(img, [height, width]) 207 | 208 | return img 209 | 210 | def load_flist(self, flist, flo=0): 211 | if isinstance(flist, list): 212 | return flist 213 | 214 | # flist: image file path, image directory path, text file flist path 215 | if flo == 0: 216 | if isinstance(flist, str): 217 | if os.path.isdir(flist): 218 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 219 | flist.sort() 220 | return flist 221 | 222 | if os.path.isfile(flist): 223 | try: 224 | return np.genfromtxt(flist, dtype=np.str, encoding='utf-8') 225 | except: 226 | return [flist] 227 | else: 228 | if isinstance(flist, str): 229 | if os.path.isdir(flist): 230 | flist = list(glob.glob(flist + '/*.flo')) 231 | flist.sort() 232 | return flist 233 | 234 | if os.path.isfile(flist): 235 | try: 236 | return np.genfromtxt(flist, dtype=np.str, encoding='utf-8') 237 | except: 238 | return [flist] 239 | 240 | return [] 241 | 242 | def create_iterator(self, batch_size): 243 | while True: 244 | sample_loader = DataLoader( 245 | dataset=self, 246 | batch_size=batch_size, 247 | drop_last=True 248 | ) 249 | 250 | for item in sample_loader: 251 | yield item 252 | 253 | def readFlow(self, fn): 254 | with open(fn, 'rb') as f: 255 | magic = np.fromfile(f, np.float32, count=1) 256 | if 202021.25 != magic: 257 | print('Magic number incorrect. Invalid .flo file') 258 | return None 259 | else: 260 | w = np.fromfile(f, np.int32, count=1) 261 | h = np.fromfile(f, np.int32, count=1) 262 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 263 | # Reshape data into 3D array (columns, rows, bands) 264 | # The reshape here is for visualization, the original code is (w,h,2) 265 | return np.resize(data, (int(h), int(w), 2)) 266 | 267 | def flow_to_image(self, flow): 268 | 269 | UNKNOWN_FLOW_THRESH = 1e7 270 | 271 | u = flow[:, :, 0] 272 | v = flow[:, :, 1] 273 | 274 | maxu = -999. 275 | maxv = -999. 276 | minu = 999. 277 | minv = 999. 278 | 279 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 280 | u[idxUnknow] = 0 281 | v[idxUnknow] = 0 282 | 283 | maxu = max(maxu, np.max(u)) 284 | minu = min(minu, np.min(u)) 285 | 286 | maxv = max(maxv, np.max(v)) 287 | minv = min(minv, np.min(v)) 288 | 289 | rad = np.sqrt(u ** 2 + v ** 2) 290 | maxrad = max(-1, np.max(rad)) 291 | 292 | # print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. 
%.3f" % (maxrad, minu,maxu, minv, maxv)) 293 | 294 | u = u/(maxrad + np.finfo(float).eps) 295 | v = v/(maxrad + np.finfo(float).eps) 296 | 297 | img = self.compute_color(u, v) 298 | 299 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 300 | img[idx] = 0 301 | 302 | return np.uint8(img) 303 | 304 | 305 | def compute_color(self, u, v): 306 | """ 307 | compute optical flow color map 308 | :param u: optical flow horizontal map 309 | :param v: optical flow vertical map 310 | :return: optical flow in color code 311 | """ 312 | [h, w] = u.shape 313 | img = np.zeros([h, w, 3]) 314 | nanIdx = np.isnan(u) | np.isnan(v) 315 | u[nanIdx] = 0 316 | v[nanIdx] = 0 317 | 318 | colorwheel = self.make_color_wheel() 319 | ncols = np.size(colorwheel, 0) 320 | 321 | rad = np.sqrt(u**2+v**2) 322 | 323 | a = np.arctan2(-v, -u) / np.pi 324 | 325 | fk = (a+1) / 2 * (ncols - 1) + 1 326 | 327 | k0 = np.floor(fk).astype(int) 328 | 329 | k1 = k0 + 1 330 | k1[k1 == ncols+1] = 1 331 | f = fk - k0 332 | 333 | for i in range(0, np.size(colorwheel,1)): 334 | tmp = colorwheel[:, i] 335 | col0 = tmp[k0-1] / 255 336 | col1 = tmp[k1-1] / 255 337 | col = (1-f) * col0 + f * col1 338 | 339 | idx = rad <= 1 340 | col[idx] = 1-rad[idx]*(1-col[idx]) 341 | notidx = np.logical_not(idx) 342 | 343 | col[notidx] *= 0.75 344 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 345 | 346 | return img 347 | 348 | 349 | def make_color_wheel(self): 350 | """ 351 | Generate color wheel according Middlebury color code 352 | :return: Color wheel 353 | """ 354 | RY = 15 355 | YG = 6 356 | GC = 4 357 | CB = 11 358 | BM = 13 359 | MR = 6 360 | 361 | ncols = RY + YG + GC + CB + BM + MR 362 | 363 | colorwheel = np.zeros([ncols, 3]) 364 | 365 | col = 0 366 | 367 | # RY 368 | colorwheel[0:RY, 0] = 255 369 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 370 | col += RY 371 | 372 | # YG 373 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 374 | colorwheel[col:col+YG, 1] = 255 375 | col += YG 376 | 377 | # GC 378 | colorwheel[col:col+GC, 1] = 255 379 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 380 | col += GC 381 | 382 | # CB 383 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 384 | colorwheel[col:col+CB, 2] = 255 385 | col += CB 386 | 387 | # BM 388 | colorwheel[col:col+BM, 2] = 255 389 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 390 | col += + BM 391 | 392 | # MR 393 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 394 | colorwheel[col:col+MR, 0] = 255 395 | 396 | return colorwheel 397 | 398 | def flow_tf(self, flow, size): 399 | flow_shape = flow.shape 400 | flow_resized = cv2.resize(flow, (size[1], size[0])) 401 | flow_resized[:, :, 0] *= (float(size[1]) / float(flow_shape[1])) 402 | flow_resized[:, :, 1] *= (float(size[0]) / float(flow_shape[0])) 403 | 404 | return flow_resized 405 | -------------------------------------------------------------------------------- /models/DeepFill_Models/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | def weights_init(init_type='gaussian'): 10 | def init_fun(m): 11 | classname = m.__class__.__name__ 12 | if (classname.find('Conv') == 0 or classname.find( 13 | 'Linear') == 0) and hasattr(m, 
'weight'): 14 | if init_type == 'gaussian': 15 | nn.init.normal_(m.weight, 0.0, 0.02) 16 | elif init_type == 'xavier': 17 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) 18 | elif init_type == 'kaiming': 19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | elif init_type == 'orthogonal': 21 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) 22 | elif init_type == 'default': 23 | pass 24 | else: 25 | assert 0, "Unsupported initialization: {}".format(init_type) 26 | if hasattr(m, 'bias') and m.bias is not None: 27 | nn.init.constant_(m.bias, 0.0) 28 | 29 | return init_fun 30 | 31 | 32 | class Conv(nn.Module): 33 | def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU(), isGated=False): 34 | super(Conv, self).__init__() 35 | if activation is not None: 36 | self.conv = nn.Sequential( 37 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D), 38 | activation 39 | ) 40 | else: 41 | self.conv = nn.Sequential( 42 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D) 43 | ) 44 | 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | m.apply(weights_init('kaiming')) 48 | 49 | def forward(self, x): 50 | x = self.conv(x) 51 | return x 52 | 53 | 54 | class Conv_Downsample(nn.Module): 55 | def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU()): 56 | super(Conv_Downsample, self).__init__() 57 | 58 | PaddingLayer = torch.nn.ZeroPad2d((0, (K-1)//2, 0, (K-1)//2)) 59 | 60 | if activation is not None: 61 | self.conv = nn.Sequential( 62 | PaddingLayer, 63 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=0, dilation=D), 64 | activation 65 | ) 66 | else: 67 | self.conv = nn.Sequential( 68 | PaddingLayer, 69 | nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=0, dilation=D) 70 | ) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | m.apply(weights_init('kaiming')) 75 | 76 | def forward(self, x): 77 | x = self.conv(x) 78 | return x 79 | 80 | 81 | class Down_Module(nn.Module): 82 | def __init__(self, in_ch, out_ch, activation=nn.ELU(), isRefine=False, 83 | isAttn=False, ): 84 | super(Down_Module, self).__init__() 85 | layers = [] 86 | layers.append(Conv(in_ch, out_ch, K=5, P=2)) 87 | # curr_dim = out_ch 88 | # layers.append(Conv_Downsample(curr_dim, curr_dim * 2, K=3, S=2, isGated=isGated)) 89 | 90 | curr_dim = out_ch 91 | if isRefine: 92 | if isAttn: 93 | layers.append(Conv_Downsample(curr_dim, curr_dim, K=3, S=2)) 94 | layers.append(Conv(curr_dim, 2*curr_dim, K=3, S=1)) 95 | layers.append(Conv_Downsample(2*curr_dim, 4*curr_dim, K=3, S=2)) 96 | layers.append(Conv(4 * curr_dim, 4 * curr_dim, K=3, S=1)) 97 | curr_dim *= 4 98 | else: 99 | for i in range(2): 100 | layers.append(Conv_Downsample(curr_dim, curr_dim, K=3, S=2)) 101 | layers.append(Conv(curr_dim, curr_dim*2)) 102 | curr_dim *= 2 103 | else: 104 | for i in range(2): 105 | layers.append(Conv_Downsample(curr_dim, curr_dim*2, K=3, S=2)) 106 | layers.append(Conv(curr_dim * 2, curr_dim * 2)) 107 | curr_dim *= 2 108 | 109 | layers.append(Conv(curr_dim, curr_dim, activation=activation)) 110 | 111 | self.out = nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | return self.out(x) 115 | 116 | 117 | class Dilation_Module(nn.Module): 118 | def __init__(self, in_ch, out_ch): 119 | super(Dilation_Module, self).__init__() 120 | layers = [] 121 | dilation = 1 122 | for i in range(4): 123 | dilation *= 2 124 | layers.append(Conv(in_ch, out_ch, D=dilation, P=dilation)) 125 | self.out = nn.Sequential(*layers) 126 | 
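# The loop above stacks four 3x3 convolutions with dilation rates 2, 4, 8 and 16;
# because the padding equals the dilation, the spatial size is unchanged while the
# receptive field grows roughly exponentially with each layer.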
127 | def forward(self, x): 128 | return self.out(x) 129 | 130 | 131 | class Up_Module(nn.Module): 132 | def __init__(self, in_ch, out_ch, isRefine=False): 133 | super(Up_Module, self).__init__() 134 | layers = [] 135 | curr_dim = in_ch 136 | if isRefine: 137 | layers.append(Conv(curr_dim, curr_dim//2)) 138 | curr_dim //= 2 139 | else: 140 | layers.append(Conv(curr_dim, curr_dim)) 141 | 142 | # conv 12~15 143 | for i in range(2): 144 | layers.append(Conv(curr_dim, curr_dim)) 145 | layers.append(nn.Upsample(scale_factor=2, mode='nearest')) 146 | layers.append(Conv(curr_dim, curr_dim//2)) 147 | curr_dim //= 2 148 | 149 | layers.append(Conv(curr_dim, curr_dim//2)) 150 | layers.append(Conv(curr_dim//2, out_ch, activation=None)) 151 | 152 | self.out = nn.Sequential(*layers) 153 | 154 | def forward(self, x): 155 | output = self.out(x) 156 | return torch.clamp(output, min=-1., max=1.) 157 | 158 | 159 | class Up_Module_CNet(nn.Module): 160 | def __init__(self, in_ch, out_ch, isRefine=False, isGated=False): 161 | super(Up_Module_CNet, self).__init__() 162 | layers = [] 163 | curr_dim = in_ch 164 | if isRefine: 165 | layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated)) 166 | curr_dim //= 2 167 | else: 168 | layers.append(Conv(curr_dim, curr_dim, isGated=isGated)) 169 | 170 | # conv 12~15 171 | for i in range(2): 172 | layers.append(Conv(curr_dim, curr_dim, isGated=isGated)) 173 | layers.append(nn.Upsample(scale_factor=2, mode='nearest')) 174 | layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated)) 175 | curr_dim //= 2 176 | 177 | layers.append(Conv(curr_dim, curr_dim//2, isGated=isGated)) 178 | layers.append(Conv(curr_dim//2, out_ch, activation=None, isGated=isGated)) 179 | 180 | self.out = nn.Sequential(*layers) 181 | 182 | def forward(self, x): 183 | output = self.out(x) 184 | return output 185 | 186 | 187 | class Flatten_Module(nn.Module): 188 | def __init__(self, in_ch, out_ch, isLocal=True): 189 | super(Flatten_Module, self).__init__() 190 | layers = [] 191 | layers.append(Conv(in_ch, out_ch, K=5, S=2, P=2, activation=nn.LeakyReLU())) 192 | curr_dim = out_ch 193 | 194 | for i in range(2): 195 | layers.append(Conv(curr_dim, curr_dim*2, K=5, S=2, P=2, activation=nn.LeakyReLU())) 196 | curr_dim *= 2 197 | 198 | if isLocal: 199 | layers.append(Conv(curr_dim, curr_dim*2, K=5, S=2, P=2, activation=nn.LeakyReLU())) 200 | else: 201 | layers.append(Conv(curr_dim, curr_dim, K=5, S=2, P=2, activation=nn.LeakyReLU())) 202 | 203 | self.out = nn.Sequential(*layers) 204 | 205 | def forward(self, x): 206 | x = self.out(x) 207 | return x.view(x.size(0),-1) # 2B x 256*(256 or 512); front 256:16*16 208 | 209 | 210 | class Contextual_Attention_Module(nn.Module): 211 | def __init__(self, in_ch, out_ch, rate=2, stride=1, isCheck=False, device=None): 212 | super(Contextual_Attention_Module, self).__init__() 213 | self.rate = rate 214 | self.padding = nn.ZeroPad2d(1) 215 | self.up_sample = nn.Upsample(scale_factor=self.rate, mode='nearest') 216 | layers = [] 217 | for i in range(2): 218 | layers.append(Conv(in_ch, out_ch)) 219 | self.out = nn.Sequential(*layers) 220 | self.isCheck = isCheck 221 | self.device = device 222 | 223 | def forward(self, f, b, mask=None, ksize=3, stride=1, 224 | fuse_k=3, softmax_scale=10., training=True, fuse=True): 225 | 226 | """ Contextual attention layer implementation. 227 | 228 | Contextual attention is first introduced in publication: 229 | Generative Image Inpainting with Contextual Attention, Yu et al. 230 | 231 | Args: 232 | f: Input feature to match (foreground). 
233 | b: Input feature for match (background). 234 | mask: Input mask for b, indicating patches not available. 235 | ksize: Kernel size for contextual attention. 236 | stride: Stride for extracting patches from b. 237 | rate: Dilation for matching. 238 | softmax_scale: Scaled softmax for attention. 239 | training: Indicating if current graph is training or inference. 240 | 241 | Returns: 242 | tf.Tensor: output 243 | 244 | """ 245 | 246 | # get shapes 247 | raw_fs = f.size() # B x 128 x 64 x 64 248 | raw_int_fs = list(f.size()) 249 | raw_int_bs = list(b.size()) 250 | 251 | # extract patches from background with stride and rate 252 | kernel = 2*self.rate 253 | raw_w = self.extract_patches(b, kernel=kernel, stride=self.rate) 254 | raw_w = raw_w.permute(0, 2, 3, 4, 5, 1) 255 | raw_w = raw_w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1) 256 | raw_w = raw_w.contiguous().view(raw_int_bs[0], -1, kernel, kernel, raw_int_bs[1]) 257 | raw_w = raw_w.permute(0, 1, 4, 2, 3) 258 | 259 | f = down_sample(f, scale_factor=1/self.rate, mode='nearest', device=self.device) 260 | b = down_sample(b, scale_factor=1/self.rate, mode='nearest', device=self.device) 261 | 262 | fs = f.size() # B x 128 x 32 x 32 263 | int_fs = list(f.size()) 264 | f_groups = torch.split(f, 1, dim=0) # Split tensors by batch dimension; tuple is returned 265 | 266 | # from b(B*H*W*C) to w(b*k*k*c*h*w) 267 | bs = b.size() # B x 128 x 32 x 32 268 | int_bs = list(b.size()) 269 | w = self.extract_patches(b) 270 | w = w.permute(0, 2, 3, 4, 5, 1) 271 | w = w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1) 272 | w = w.contiguous().view(raw_int_bs[0], -1, ksize, ksize, raw_int_bs[1]) 273 | w = w.permute(0, 1, 4, 2, 3) 274 | # process mask 275 | mask = mask.clone() 276 | if mask is not None: 277 | if mask.size(2) != b.size(2): 278 | mask = down_sample(mask, scale_factor=1./self.rate, mode='nearest', device=self.device) 279 | else: 280 | mask = torch.zeros([1, 1, bs[2], bs[3]]) 281 | 282 | m = self.extract_patches(mask) 283 | 284 | m = m.permute(0, 2, 3, 4, 5, 1) 285 | m = m.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1) 286 | m = m.contiguous().view(raw_int_bs[0], -1, ksize, ksize, 1) 287 | m = m.permute(0, 4, 1, 2, 3) 288 | 289 | m = m[0] # (1, 32*32, 3, 3) 290 | m = reduce_mean(m) # smoothing, maybe 291 | mm = m.eq(0.).float() # (1, 32*32, 1, 1) 292 | 293 | w_groups = torch.split(w, 1, dim=0) # Split tensors by batch dimension; tuple is returned 294 | raw_w_groups = torch.split(raw_w, 1, dim=0) # Split tensors by batch dimension; tuple is returned 295 | y = [] 296 | offsets = [] 297 | k = fuse_k 298 | scale = softmax_scale 299 | fuse_weight = Variable(torch.eye(k).view(1, 1, k, k)).cuda(self.device) # 1 x 1 x K x K 300 | y_test = [] 301 | for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups): 302 | ''' 303 | O => output channel as a conv filter 304 | I => input channel as a conv filter 305 | xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32) 306 | wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3) 307 | raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4) 308 | ''' 309 | # conv for compare 310 | wi = wi[0] 311 | escape_NaN = Variable(torch.FloatTensor([1e-4])).cuda(self.device) 312 | wi_normed = wi / torch.max(l2_norm(wi), escape_NaN) 313 | yi = F.conv2d(xi, wi_normed, stride=1, padding=1) # yi => 
(B=1, C=32*32, H=32, W=32) 314 | y_test.append(yi) 315 | # conv implementation for fuse scores to encourage large patches 316 | if fuse: 317 | yi = yi.permute(0, 2, 3, 1) 318 | yi = yi.contiguous().view(1, fs[2] * fs[3], bs[2] * bs[3], 1) 319 | yi = yi.permute(0, 3, 1, 2) # make all of depth to spatial resolution, (B=1, I=1, H=32*32, W=32*32) 320 | yi = F.conv2d(yi, fuse_weight, stride=1, padding=1) # (B=1, C=1, H=32*32, W=32*32) 321 | 322 | yi = yi.permute(0, 2, 3, 1) 323 | yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3]) 324 | # yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3]) # (B=1, 32, 32, 32, 32) 325 | yi = yi.permute(0, 2, 1, 4, 3) 326 | yi = yi.contiguous().view(1, fs[2] * fs[3], bs[2] * bs[3], 1) 327 | yi = yi.permute(0, 3, 1, 2) 328 | 329 | yi = F.conv2d(yi, fuse_weight, stride=1, padding=1) 330 | yi = yi.permute(0, 2, 3, 1) 331 | yi = yi.contiguous().view(1, fs[3], fs[2], bs[3], bs[2]) 332 | yi = yi.permute(0, 2, 1, 4, 3) 333 | yi = yi.contiguous().view(1, fs[2], fs[3], bs[2] * bs[3]) 334 | yi = yi.permute(0, 3, 1, 2) 335 | else: 336 | yi = yi.permute(0, 2, 3, 1) 337 | yi = yi.contiguous().view(1, fs[2], fs[3], bs[2] * bs[3]) 338 | yi = yi.permute(0, 3, 1, 2) # (B=1, C=32*32, H=32, W=32) 339 | # yi = yi.contiguous().view(1, bs[2] * bs[3], fs[2], fs[3]) 340 | 341 | # softmax to match 342 | yi = yi * mm # mm => (1, 32*32, 1, 1) 343 | yi = F.softmax(yi*scale, dim=1) 344 | yi = yi * mm # mask 345 | 346 | _, offset = torch.max(yi, dim=1) # argmax; index 347 | division = torch.true_divide(offset, fs[3]).long() 348 | offset = torch.stack([division, torch.true_divide(offset, fs[3])-division], dim=-1) 349 | 350 | wi_center = raw_wi[0] 351 | 352 | yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64) 353 | y.append(yi) 354 | offsets.append(offset) 355 | 356 | y = torch.cat(y, dim=0) # back to the mini-batch 357 | y.contiguous().view(raw_int_fs) 358 | # wi_patched = y 359 | offsets = torch.cat(offsets, dim=0) 360 | offsets = offsets.view([int_bs[0]] + [2] + int_bs[2:]) 361 | 362 | # case1: visualize optical flow: minus current position 363 | h_add = Variable(torch.arange(0,float(bs[2]))).cuda(self.device).view([1, 1, bs[2], 1]) 364 | h_add = h_add.expand(bs[0], 1, bs[2], bs[3]) 365 | w_add = Variable(torch.arange(0,float(bs[3]))).cuda(self.device).view([1, 1, 1, bs[3]]) 366 | w_add = w_add.expand(bs[0], 1, bs[2], bs[3]) 367 | 368 | offsets = offsets - torch.cat([h_add, w_add], dim=1).long() 369 | 370 | # # case2: visualize which pixels are attended 371 | # flow = torch.from_numpy(highlight_flow((offsets * mask.int()).numpy())) 372 | y = self.out(y) 373 | 374 | return y, offsets 375 | 376 | def extract_patches(self, x, kernel=3, stride=1): 377 | x = self.padding(x) 378 | all_patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride) 379 | 380 | return all_patches 381 | 382 | 383 | def reduce_mean(x): 384 | for i in range(4): 385 | if i==1: continue 386 | x = torch.mean(x, dim=i, keepdim=True) 387 | return x 388 | 389 | 390 | def l2_norm(x): 391 | def reduce_sum(x): 392 | for i in range(4): 393 | if i==0: continue 394 | x = torch.sum(x, dim=i, keepdim=True) 395 | return x 396 | 397 | x = x**2 398 | x = reduce_sum(x) 399 | return torch.sqrt(x) 400 | 401 | 402 | def down_sample(x, size=None, scale_factor=None, mode='nearest', device=None): 403 | # define size if user has specified scale_factor 404 | if size is None: size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3))) 405 | # create coordinates 406 | # size_origin 
= [x.size[2], x.size[3]] 407 | h = torch.true_divide(torch.arange(0, size[0]), (size[0])) * 2 - 1 408 | w = torch.true_divide(torch.arange(0, size[1]), (size[1])) * 2 - 1 409 | # create grid 410 | grid = torch.zeros(size[0],size[1],2) 411 | grid[:,:,0] = w.unsqueeze(0).repeat(size[0],1) 412 | grid[:,:,1] = h.unsqueeze(0).repeat(size[1],1).transpose(0,1) 413 | # expand to match batch size 414 | grid = grid.unsqueeze(0).repeat(x.size(0),1,1,1) 415 | if x.is_cuda: 416 | if device: 417 | grid = Variable(grid).cuda(device) 418 | else: 419 | grid = Variable(grid).cuda() 420 | # do sampling 421 | 422 | return F.grid_sample(x, grid, mode=mode) 423 | 424 | 425 | def to_var(x, volatile=False, device=None): 426 | if torch.cuda.is_available(): 427 | if device: 428 | x = x.cuda(device) 429 | else: 430 | x = x.cuda() 431 | return Variable(x, volatile=volatile) 432 | -------------------------------------------------------------------------------- /tool/video_completion_modified.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..'))) 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import glob 9 | import copy 10 | import numpy as np 11 | import cupy as cp 12 | import torch 13 | import imageio 14 | from PIL import Image 15 | import scipy.ndimage 16 | from skimage.feature import canny 17 | from skimage.transform import integral 18 | import torchvision.transforms.functional as F 19 | 20 | from RAFT import utils 21 | from RAFT import RAFT 22 | 23 | import utils.region_fill as rf 24 | from utils.Poisson_blend import Poisson_blend 25 | from utils.Poisson_blend_img import Poisson_blend_img 26 | from get_flowNN import get_flowNN 27 | from get_flowNN_gradient import get_flowNN_gradient 28 | from utils.common_utils import flow_edge 29 | from spatial_inpaint import spatial_inpaint 30 | from frame_inpaint import DeepFillv1 31 | from edgeconnect.networks import EdgeGenerator_ 32 | 33 | import time 34 | 35 | def find_minbbox(masks): 36 | # find the minimum bounding box of the holdmask 37 | minbbox_tl=[] # top left point of the minimum bounding box 38 | minbbox_br=[] # bottom right point of the minimum bounding box 39 | for i in range (0,len(masks)): 40 | non_zeros=cv2.findNonZero(np.array(masks[i]*255)) 41 | min_rect=cv2.boundingRect(non_zeros) 42 | 43 | # expand 10 pixels 44 | x1=max(0,min_rect[0]-10) 45 | y1=max(0,min_rect[1]-10) 46 | x2=min(masks[i].shape[1],min_rect[0]+min_rect[2]+10) 47 | y2=min(masks[i].shape[0],min_rect[1]+min_rect[3]+10) 48 | 49 | minbbox_tl.append([x1,y1]) 50 | minbbox_br.append([x2,y2]) 51 | return minbbox_tl,minbbox_br 52 | 53 | 54 | def to_tensor(img): 55 | img = Image.fromarray(img) 56 | img_t = F.to_tensor(img).float() 57 | return img_t 58 | 59 | 60 | def infer(args, EdgeGenerator, device, flow_img_gray, edge, mask): 61 | 62 | # Add a pytorch dataloader 63 | flow_img_gray_tensor = to_tensor(flow_img_gray)[None, :, :].float().to(device) 64 | edge_tensor = to_tensor(edge)[None, :, :].float().to(device) 65 | mask_tensor = torch.from_numpy(mask.astype(np.float64))[None, None, :, :].float().to(device) 66 | 67 | # Complete the edges 68 | edges_masked = (edge_tensor * (1 - mask_tensor)) 69 | images_masked = (flow_img_gray_tensor * (1 - mask_tensor)) + mask_tensor 70 | inputs = torch.cat((images_masked, edges_masked, mask_tensor), dim=1) 71 | with torch.no_grad(): 72 | edges_completed = EdgeGenerator(inputs) # in: [grayscale(1) + edge(1) + mask(1)] 73 | 
edges_completed = edges_completed * mask_tensor + edge_tensor * (1 - mask_tensor) 74 | edge_completed = edges_completed[0, 0].data.cpu().numpy() 75 | edge_completed[edge_completed < 0.5] = 0 76 | edge_completed[edge_completed >= 0.5] = 1 77 | 78 | return edge_completed 79 | 80 | 81 | def gradient_mask(mask): 82 | 83 | gradient_mask = np.logical_or.reduce((mask, 84 | np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=np.bool)), axis=0), 85 | np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=np.bool)), axis=1))) 86 | 87 | return gradient_mask 88 | 89 | 90 | def create_dir(dir): 91 | """Creates a directory if not exist. 92 | """ 93 | if not os.path.exists(dir): 94 | os.makedirs(dir) 95 | 96 | 97 | def initialize_RAFT(args): 98 | """Initializes the RAFT model. 99 | """ 100 | model = torch.nn.DataParallel(RAFT(args)) 101 | model.load_state_dict(torch.load(args.model)) 102 | 103 | model = model.module 104 | model.to('cuda') 105 | model.eval() 106 | 107 | return model 108 | 109 | 110 | def calculate_flow(args, model, video, mode): 111 | """Calculates optical flow. 112 | """ 113 | if mode not in ['forward', 'backward']: 114 | raise NotImplementedError 115 | 116 | nFrame, _, imgH, imgW = video.shape 117 | Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) 118 | 119 | # if os.path.isdir(os.path.join(args.outroot, 'flow', mode + '_flo')): 120 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow', mode + '_flo', '*.flo'))): 121 | # print("Loading {0}".format(flow_name), '\r', end='') 122 | # flow = utils.frame_utils.readFlow(flow_name) 123 | # Flow = np.concatenate((Flow, flow[..., None]), axis=-1) 124 | # return Flow 125 | 126 | # create_dir(os.path.join(args.outroot, 'flow', mode + '_flo')) 127 | # create_dir(os.path.join(args.outroot, 'flow', mode + '_png')) 128 | 129 | with torch.no_grad(): 130 | for i in range(video.shape[0] - 1): 131 | print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='') 132 | if mode == 'forward': 133 | # Flow i -> i + 1 134 | image1 = video[i, None] 135 | image2 = video[i + 1, None] 136 | elif mode == 'backward': 137 | # Flow i + 1 -> i 138 | image1 = video[i + 1, None] 139 | image2 = video[i, None] 140 | else: 141 | raise NotImplementedError 142 | 143 | _, flow = model(image1, image2, iters=20, test_mode=True) 144 | flow = flow[0].permute(1, 2, 0).cpu().numpy() 145 | Flow = np.concatenate((Flow, flow[..., None]), axis=-1) 146 | 147 | # # Flow visualization. 148 | # flow_img = utils.flow_viz.flow_to_image(flow) 149 | # flow_img = Image.fromarray(flow_img) 150 | 151 | # # Saves the flow and flow_img. 152 | # flow_img.save(os.path.join(args.outroot, 'flow', mode + '_png', '%05d.png'%i)) 153 | # utils.frame_utils.writeFlow(os.path.join(args.outroot, 'flow', mode + '_flo', '%05d.flo'%i), flow) 154 | 155 | return Flow 156 | 157 | 158 | 159 | def complete_flow(args, corrFlow, flow_mask, mode, minbbox_tl, minbbox_br, edge=None): 160 | """Completes flow. 
161 | """ 162 | if mode not in ['forward', 'backward']: 163 | raise NotImplementedError 164 | 165 | imgH, imgW, _, nFrame = corrFlow.shape 166 | 167 | # if os.path.isdir(os.path.join(args.outroot, 'flow_comp', mode + '_flo')): 168 | # compFlow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) 169 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow_comp', mode + '_flo', '*.flo'))): 170 | # print("Loading {0}".format(flow_name), '\r', end='') 171 | # flow = utils.frame_utils.readFlow(flow_name) 172 | # compFlow = np.concatenate((compFlow, flow[..., None]), axis=-1) 173 | # return compFlow 174 | 175 | # create_dir(os.path.join(args.outroot, 'flow_comp', mode + '_flo')) 176 | # create_dir(os.path.join(args.outroot, 'flow_comp', mode + '_png')) 177 | 178 | compFlow=corrFlow.copy() 179 | for i in range(nFrame): 180 | print("Completing {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='') 181 | flow = corrFlow[:, :, :, i] 182 | if mode == 'forward': 183 | flow_crop=flow[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0],:] 184 | flow_mask_img = flow_mask[:, :, i] 185 | flow_mask_img_crop=flow_mask_img[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0]] 186 | else: 187 | flow_crop=flow[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0],:] 188 | flow_mask_img = flow_mask[:, :, i+1] 189 | flow_mask_img_crop=flow_mask_img[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0]] 190 | 191 | # cv2.imwrite("./flow_mask_img_crop.png",flow_mask_img_crop*255) 192 | flow_mask_gradient_img = gradient_mask(flow_mask_img) 193 | flow_mask_gradient_img_crop = gradient_mask(flow_mask_img_crop) 194 | 195 | if edge is not None: 196 | # imgH x (imgW - 1 + 1) x 2 197 | gradient_x = np.concatenate((np.diff(flow_crop, axis=1), np.zeros((flow_crop.shape[0], 1, 2), dtype=np.float32)), axis=1) 198 | # (imgH - 1 + 1) x imgW x 2 199 | gradient_y = np.concatenate((np.diff(flow_crop, axis=0), np.zeros((1, flow_crop.shape[1], 2), dtype=np.float32)), axis=0) 200 | 201 | # concatenate gradient_x and gradient_y 202 | gradient = np.concatenate((gradient_x, gradient_y), axis=2) 203 | 204 | # We can trust the gradient outside of flow_mask_gradient_img 205 | # We assume the gradient within flow_mask_gradient_img is 0. 206 | gradient[flow_mask_gradient_img_crop, :] = 0 207 | 208 | # Complete the flow 209 | imgSrc_gy = gradient[:, :, 2 : 4] 210 | imgSrc_gy = imgSrc_gy[0 : flow_crop.shape[0] - 1, :, :] 211 | imgSrc_gx = gradient[:, :, 0 : 2] 212 | imgSrc_gx = imgSrc_gx[:, 0 : flow_crop.shape[1] - 1, :] 213 | if mode == 'forward': 214 | edge_crop=edge[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0],i] 215 | else: 216 | edge_crop=edge[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0],i] 217 | compFlow_crop = Poisson_blend(flow_crop, imgSrc_gx, imgSrc_gy, flow_mask_img_crop, edge_crop) 218 | 219 | #return original size 220 | if mode == 'forward': 221 | compFlow[minbbox_tl[i][1]:minbbox_br[i][1],minbbox_tl[i][0]:minbbox_br[i][0], :, i] = compFlow_crop 222 | else: 223 | compFlow[minbbox_tl[i+1][1]:minbbox_br[i+1][1],minbbox_tl[i+1][0]:minbbox_br[i+1][0], :, i] = compFlow_crop 224 | 225 | else: 226 | flow[:, :, 0] = rf.regionfill(flow[:, :, 0], flow_mask_img) 227 | flow[:, :, 1] = rf.regionfill(flow[:, :, 1], flow_mask_img) 228 | compFlow[:, :, :, i] = flow 229 | 230 | ## Flow visualization. 
231 | # flow_img = utils.flow_viz.flow_to_image(compFlow[:, :, :, i]) 232 | # flow_img = Image.fromarray(flow_img) 233 | 234 | ## Saves the flow and flow_img. 235 | # flow_img.save(os.path.join(args.outroot, 'flow_comp', mode + '_png', '%05d.png'%i)) 236 | # utils.frame_utils.writeFlow(os.path.join(args.outroot, 'flow_comp', mode + '_flo', '%05d.flo'%i), compFlow[:, :, :, i]) 237 | 238 | return compFlow 239 | 240 | 241 | def edge_completion(args, EdgeGenerator, corrFlow, flow_mask, mode): 242 | """Calculate flow edge and complete it. 243 | """ 244 | 245 | if mode not in ['forward', 'backward']: 246 | raise NotImplementedError 247 | 248 | imgH, imgW, _, nFrame = corrFlow.shape 249 | Edge = np.empty(((imgH, imgW, 0)), dtype=np.float32) 250 | 251 | for i in range(nFrame): 252 | print("Completing {0} flow edge {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='') 253 | flow_mask_img = flow_mask[:, :, i] if mode == 'forward' else flow_mask[:, :, i + 1] 254 | 255 | flow_img_gray = (corrFlow[:, :, 0, i] ** 2 + corrFlow[:, :, 1, i] ** 2) ** 0.5 256 | flow_img_gray = flow_img_gray / flow_img_gray.max() 257 | 258 | edge_corr = canny(flow_img_gray, sigma=2, mask=(1 - flow_mask_img).astype(np.bool)) 259 | edge_completed = infer(args, EdgeGenerator, torch.device('cuda:1'), flow_img_gray, edge_corr, flow_mask_img) 260 | Edge = np.concatenate((Edge, edge_completed[..., None]), axis=-1) 261 | 262 | return Edge 263 | 264 | 265 | def video_completion_seamless(args): 266 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 267 | 268 | # Flow model. 269 | RAFT_model = initialize_RAFT(args) 270 | 271 | # Loads frames. 272 | filename_list = glob.glob(os.path.join(args.path, '*.png')) + \ 273 | glob.glob(os.path.join(args.path, '*.jpg')) 274 | 275 | # Obtains imgH, imgW and nFrame. 276 | imgH, imgW = np.array(Image.open(filename_list[0])).shape[:2] 277 | nFrame = len(filename_list) 278 | 279 | # Loads video. 280 | video = [] 281 | for filename in sorted(filename_list): 282 | video.append(torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8)).permute(2, 0, 1).float()) 283 | video = torch.stack(video, dim=0) 284 | video = video.to('cuda') 285 | 286 | # Loads masks. 287 | filename_list = glob.glob(os.path.join(args.path_mask, '*.png')) + \ 288 | glob.glob(os.path.join(args.path_mask, '*.jpg')) 289 | mask = [] 290 | mask_dilated = [] 291 | flow_mask = [] 292 | for filename in sorted(filename_list): 293 | mask_img = np.array(Image.open(filename).convert('L')) 294 | flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=3) 295 | # Close the small holes inside the foreground objects 296 | flow_mask_img = cv2.morphologyEx(flow_mask_img.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((11, 11),np.uint8)).astype(np.bool) 297 | flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.bool) 298 | flow_mask.append(flow_mask_img) 299 | 300 | # Dilate a little bit 301 | mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=3) 302 | mask_img = scipy.ndimage.binary_fill_holes(mask_img).astype(np.bool) 303 | mask.append(mask_img) 304 | mask_dilated.append(gradient_mask(mask_img)) 305 | 306 | minbbox_tl,minbbox_br=find_minbbox(flow_mask) 307 | 308 | # Image inpainting model. 309 | deepfill = DeepFillv1(pretrained_model=args.deepfill_model, image_shape=[imgH, imgW]) 310 | 311 | # timer 312 | time_start=time.time() 313 | 314 | # Calcutes the corrupted flow. 
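    # RAFT is run on every adjacent frame pair: 'forward' gives flow i -> i + 1 and
    # 'backward' gives flow i + 1 -> i; each result is stacked into an (imgH, imgW, 2, nFrame - 1) array.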
315 | corrFlowF = calculate_flow(args, RAFT_model, video, 'forward') 316 | corrFlowB = calculate_flow(args, RAFT_model, video, 'backward') 317 | print('\nFinish flow prediction.') 318 | 319 | # Makes sure video is in BGR (opencv) format. 320 | video = video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255. 321 | 322 | # mask indicating the missing region in the video. 323 | mask = np.stack(mask, -1).astype(np.bool) 324 | mask_dilated = np.stack(mask_dilated, -1).astype(np.bool) 325 | flow_mask = np.stack(flow_mask, -1).astype(np.bool) 326 | 327 | if args.edge_guide: 328 | # Edge completion model. 329 | EdgeGenerator = EdgeGenerator_() 330 | EdgeComp_ckpt = torch.load(args.edge_completion_model) 331 | EdgeGenerator.load_state_dict(EdgeComp_ckpt['generator']) 332 | EdgeGenerator.to(torch.device('cuda:1')) 333 | EdgeGenerator.eval() 334 | 335 | # Edge completion. 336 | FlowF_edge = edge_completion(args, EdgeGenerator, corrFlowF, flow_mask, 'forward') 337 | FlowB_edge = edge_completion(args, EdgeGenerator, corrFlowB, flow_mask, 'backward') 338 | print('\nFinish edge completion.') 339 | else: 340 | FlowF_edge, FlowB_edge = None, None 341 | 342 | 343 | # Completes the flow. 344 | videoFlowF = complete_flow(args, corrFlowF, flow_mask, 'forward', minbbox_tl, minbbox_br, FlowF_edge) 345 | videoFlowB = complete_flow(args, corrFlowB, flow_mask, 'backward', minbbox_tl, minbbox_br, FlowB_edge) 346 | print('\nFinish flow completion.') 347 | 348 | # Prepare gradients 349 | gradient_x = np.empty(((imgH, imgW, 3, 0)), dtype=np.float32) 350 | gradient_y = np.empty(((imgH, imgW, 3, 0)), dtype=np.float32) 351 | 352 | for indFrame in range(nFrame): 353 | img = video[:, :, :, indFrame] 354 | img[mask[:, :, indFrame], :] = 0 355 | 356 | img = cv2.inpaint((img * 255).astype(np.uint8), mask[:, :, indFrame].astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255. 357 | 358 | gradient_x_ = np.concatenate((np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1) 359 | gradient_y_ = np.concatenate((np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0) 360 | gradient_x = np.concatenate((gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1) 361 | gradient_y = np.concatenate((gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1) 362 | 363 | gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0 364 | gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0 365 | 366 | 367 | iter = 0 368 | mask_tofill = mask 369 | gradient_x_filled = gradient_x # corrupted gradient_x, mask_gradient indicates the missing gradient region 370 | gradient_y_filled = gradient_y # corrupted gradient_y, mask_gradient indicates the missing gradient region 371 | mask_gradient = mask_dilated 372 | video_comp = video 373 | 374 | # We iteratively complete the video. 375 | while(np.sum(mask) > 0): 376 | # create_dir(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter))) 377 | 378 | # Gradient propagation. 379 | gradient_x_filled, gradient_y_filled, mask_gradient = \ 380 | get_flowNN_gradient(args, 381 | gradient_x_filled, 382 | gradient_y_filled, 383 | mask, 384 | mask_gradient, 385 | videoFlowF, 386 | videoFlowB, 387 | None, 388 | None) 389 | 390 | # if there exist holes in mask, Poisson blending will fail. So I did this trick. I sacrifice some value. Another solution is to modify Poisson blending. 
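        # Concretely: binary_fill_holes below closes any interior holes in mask_gradient so the
        # masked region handed to Poisson blending has no enclosed pockets; pixels inside those
        # holes get re-solved even though their values were already known (the "sacrificed" values).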
391 | for indFrame in range(nFrame): 392 | mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(mask_gradient[:, :, indFrame]).astype(np.bool) 393 | 394 | # After one gradient propagation iteration 395 | # gradient --> RGB 396 | for indFrame in range(nFrame): 397 | print("Poisson blending frame {0:3d}".format(indFrame)) 398 | 399 | if mask[:, :, indFrame].sum() > 0: 400 | try: 401 | video_comp_crop=video_comp[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], :, indFrame] 402 | gradient_x_filled_crop=gradient_x_filled[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0]-1, :, indFrame] 403 | gradient_y_filled_crop=gradient_y_filled[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1]-1,minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], :, indFrame] 404 | mask_crop=mask[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], indFrame] 405 | mask_gradient_crop=mask_gradient[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0], indFrame]; 406 | frameBlend_crop, UnfilledMask_crop = Poisson_blend_img(video_comp_crop, gradient_x_filled_crop, gradient_y_filled_crop, mask_crop, mask_gradient_crop) 407 | 408 | frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame] 409 | frameBlend[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0],:]=frameBlend_crop 410 | UnfilledMask[minbbox_tl[indFrame][1]:minbbox_br[indFrame][1],minbbox_tl[indFrame][0]:minbbox_br[indFrame][0]]=UnfilledMask_crop 411 | # UnfilledMask = scipy.ndimage.binary_fill_holes(UnfilledMask).astype(np.bool) 412 | 413 | # frameBlend, UnfilledMask = Poisson_blend_img(video_comp[:, :, :, indFrame], gradient_x_filled[:, 0 : imgW - 1, :, indFrame], gradient_y_filled[0 : imgH - 1, :, :, indFrame], mask[:, :, indFrame], mask_gradient[:, :, indFrame]) 414 | # UnfilledMask = scipy.ndimage.binary_fill_holes(UnfilledMask).astype(np.bool) 415 | except: 416 | frameBlend, UnfilledMask = video_comp[:, :, :, indFrame], mask[:, :, indFrame] 417 | 418 | frameBlend = np.clip(frameBlend, 0, 1.0) 419 | tmp = cv2.inpaint((frameBlend * 255).astype(np.uint8), UnfilledMask.astype(np.uint8), 3, cv2.INPAINT_TELEA).astype(np.float32) / 255. 420 | frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :] 421 | 422 | video_comp[:, :, :, indFrame] = frameBlend 423 | mask[:, :, indFrame] = UnfilledMask 424 | 425 | frameBlend_ = copy.deepcopy(frameBlend) 426 | # Green indicates the regions that are not filled yet. 427 | frameBlend_[mask[:, :, indFrame], :] = [0, 1., 0] 428 | else: 429 | frameBlend_ = video_comp[:, :, :, indFrame] 430 | 431 | # cv2.imwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter), '%05d.png'%indFrame), frameBlend_ * 255.) 
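        # The heavy minbbox_tl / minbbox_br slicing above is a crop-and-paste-back pattern:
        # each frame, its gradients and its masks are cropped to the padded hole bounding box,
        # Poisson_blend_img solves only that small region, and the blended patch is written back
        # into the full-resolution frame, e.g.
        #     crop = frame[tl[1]:br[1], tl[0]:br[0]]
        #     frame[tl[1]:br[1], tl[0]:br[0]] = blended_crop
        # (tl/br are the (x, y) corners returned by find_minbbox; keeping the Poisson system
        # small is presumably the motivation for the crop.)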
432 | 
433 |         # video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1]
434 |         # imageio.mimwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter), 'intermediate_{0}.mp4'.format(str(iter))), video_comp_, fps=12, quality=8, macro_block_size=1)
435 |         # imageio.mimsave(os.path.join(args.outroot, 'frame_seamless_comp_' + str(iter), 'intermediate_{0}.gif'.format(str(iter))), video_comp_, format='gif', fps=12)
436 | 
437 |         mask, video_comp = spatial_inpaint(deepfill, mask, video_comp)
438 |         iter += 1
439 | 
440 |         # Re-calculate gradient_x/y_filled and mask_gradient
441 |         for indFrame in range(nFrame):
442 |             mask_gradient[:, :, indFrame] = gradient_mask(mask[:, :, indFrame])
443 | 
444 |             gradient_x_filled[:, :, :, indFrame] = np.concatenate((np.diff(video_comp[:, :, :, indFrame], axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1)
445 |             gradient_y_filled[:, :, :, indFrame] = np.concatenate((np.diff(video_comp[:, :, :, indFrame], axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0)
446 | 
447 |             gradient_x_filled[mask_gradient[:, :, indFrame], :, indFrame] = 0
448 |             gradient_y_filled[mask_gradient[:, :, indFrame], :, indFrame] = 0
449 | 
450 |     video_comp_ = (video_comp * 255).astype(np.uint8).transpose(3, 0, 1, 2)[:, :, :, ::-1]
451 | 
452 |     time_end = time.time()
453 |     print('time cost', time_end - time_start, 's')
454 | 
455 |     # write out
456 |     create_dir(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final'))
457 |     for i in range(nFrame):
458 |         img = video_comp[:, :, :, i] * 255
459 |         cv2.imwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final', '%06d.png'%i), img)
460 |     # imageio.mimwrite(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final', 'final.mp4'), video_comp_, fps=12, quality=8, macro_block_size=1)
461 |     # imageio.mimsave(os.path.join(args.outroot, 'frame_seamless_comp_' + 'final', 'final.gif'), video_comp_, format='gif', fps=12)
462 | 
463 | 
464 | def main(args):
465 | 
466 |     assert args.mode in ('object_removal', 'video_extrapolation'), (
467 |         "Accepted modes: 'object_removal', 'video_extrapolation', but input is %s"
468 |     ) % args.mode
469 | 
470 |     if args.seamless:
471 |         video_completion_seamless(args)
472 |     else:
473 |         video_completion(args)
474 | 
475 | 
476 | if __name__ == '__main__':
477 |     parser = argparse.ArgumentParser()
478 | 
479 |     # video completion
480 |     parser.add_argument('--seamless', action='store_true', help='Whether to operate in the gradient domain')
481 |     parser.add_argument('--edge_guide', action='store_true', help='Whether to use edge as guidance to complete flow')
482 |     parser.add_argument('--mode', default='object_removal', help="modes: object_removal / video_extrapolation")
483 |     parser.add_argument('--path', default='../data/frames_corr', help="dataset for evaluation")
484 |     parser.add_argument('--path_mask', default='../data/masks', help="mask for object removal")
485 |     parser.add_argument('--outroot', default='../result/', help="output directory")
486 |     parser.add_argument('--consistencyThres', dest='consistencyThres', default=np.inf, type=float, help='flow consistency error threshold')
487 |     parser.add_argument('--alpha', dest='alpha', default=0.1, type=float)
488 |     parser.add_argument('--Nonlocal', dest='Nonlocal', default=False, type=bool)
489 | 
490 |     # RAFT
491 |     parser.add_argument('--model', default='./weight/raft-things.pth', help="restore checkpoint")
492 |     parser.add_argument('--small', action='store_true', help='use small model')
493 |     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
494 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
495 | 
496 |     # Deepfill
497 |     parser.add_argument('--deepfill_model', default='./weight/imagenet_deepfill.pth', help="restore checkpoint")
498 | 
499 |     # Edge completion
500 |     parser.add_argument('--edge_completion_model', default='./weight/edge_completion.pth', help="restore checkpoint")
501 | 
502 |     # extrapolation
503 |     parser.add_argument('--H_scale', dest='H_scale', default=2, type=float, help='H extrapolation scale')
504 |     parser.add_argument('--W_scale', dest='W_scale', default=2, type=float, help='W extrapolation scale')
505 | 
506 |     args = parser.parse_args()
507 | 
508 |     main(args)
509 | 
510 | 
--------------------------------------------------------------------------------
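Example invocation of /tool/video_completion_modified.py (a minimal sketch, not a tested command: the relative data/weight paths mirror the argparse defaults above, and the pretrained checkpoints plus the FGVC helper modules the script imports are assumed to already be in place; adjust paths to the local layout):

    cd tool
    python video_completion_modified.py --mode object_removal --seamless --edge_guide \
        --path ../data/frames_corr --path_mask ../data/masks --outroot ../result \
        --model ../weight/raft-things.pth --deepfill_model ../weight/imagenet_deepfill.pth \
        --edge_completion_model ../weight/edge_completion.pth

Dropping --edge_guide skips the EdgeConnect-based flow-edge completion and falls back to the regionfill path in complete_flow; dropping --seamless would require the non-gradient-domain video_completion routine, which is not defined in this modified script.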