├── src
│   ├── __init__.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── DIP.py
│   │   ├── DIP_Vid.py
│   │   ├── DIP_Vid_Flow.py
│   │   ├── DIP_Vid_3DCN.py
│   │   └── base.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── perceptual.py
│   │   ├── common.py
│   │   ├── encoder_decoder_2d.py
│   │   ├── encoder_decoder_3d.py
│   │   └── pwc_net.py
│   ├── correlation
│   │   ├── __init__.py
│   │   ├── README.md
│   │   └── correlation.py
│   ├── flow_estimator.py
│   ├── utils.py
│   ├── inpainting_dataset.py
│   └── inpainting_test.py
├── .gitignore
├── img
│   └── rollerblade.gif
├── data
│   └── README.md
├── requirements.txt
├── pretrained_models
│   └── README.md
├── Dockerfile
├── train.py
├── README.md
└── demo.ipynb
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/correlation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .git
2 | data
3 | result
4 | pretrained_models
5 | *.pyc
6 | *.ipynb_checkpoints/
7 | *__pycache__
--------------------------------------------------------------------------------
/img/rollerblade.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Haotianz94/IL_video_inpainting/HEAD/img/rollerblade.gif
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Please download the demo data from [here](https://drive.google.com/open?id=1MJDCjj1aIUbW0OK9UnewhXlkKX9zllQd).
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.0
2 | torchvision==0.2.0
3 | ipykernel==5.1.2
4 | jupyter==1.0.0
5 | opencv-contrib-python==4.1.0
6 | matplotlib==3.0.3
7 | scipy==1.3.1
8 | cupy==6.3.0
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | Please download the pre-trained model weights for PWC-Net from [here](https://drive.google.com/open?id=1vyoQFBz--DEkUq-0gucbWYrVbfYTOZfz).
2 | (The model was originally trained by Simon Niklaus; see [here](https://github.com/sniklaus/pytorch-pwc).)
--------------------------------------------------------------------------------
/src/correlation/README.md:
--------------------------------------------------------------------------------
1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be using or modifying this particular implementation, please acknowledge it appropriately.
--------------------------------------------------------------------------------
/src/configs/DIP.py:
--------------------------------------------------------------------------------
1 | from configs.base import cfg
2 | cfg = cfg.copy()
3 |
4 | # Dataset
5 | cfg['batch_size'] = 1
6 | cfg['batch_stride'] = [1]
7 | cfg['batch_mode'] = 'seq'
8 | cfg['resize'] = (192, 384)
9 |
10 | # Model
11 | cfg['net_type_G'] = '2d'
12 |
13 | # Loss
14 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0}
15 |
16 | # Optimize
17 | cfg['train_mode'] = 'DIP'
18 | cfg['num_iter'] = 5000
19 | cfg['num_pass'] = 1
20 | cfg['fine_tune'] = False
21 | cfg['param_noise'] = True
--------------------------------------------------------------------------------
/src/configs/DIP_Vid.py:
--------------------------------------------------------------------------------
1 | from configs.base import cfg
2 | cfg = cfg.copy()
3 |
4 | # Dataset
5 | cfg['batch_size'] = 5
6 | cfg['batch_stride'] = [1]
7 | cfg['batch_mode'] = 'random'
8 | cfg['resize'] = (192, 384)
9 |
10 | # Model
11 | cfg['net_type_G'] = '2d'
12 |
13 | # Loss
14 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0}
15 |
16 | # Optimize
17 | cfg['train_mode'] = 'DIP-Vid'
18 | cfg['num_iter'] = 100
19 | cfg['num_pass'] = 20
20 | cfg['fine_tune'] = True
21 | cfg['param_noise'] = False
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-cudnn7-devel
2 | MAINTAINER Haotian Zhang
3 |
4 | RUN apt-get update && apt-get install -y rsync openssh-server vim wget unzip htop tmux
5 | RUN apt-get install -y libsm6 libxrender1 libgtk2.0-dev
6 |
7 | RUN apt-get install python3-pip -y
8 | RUN pip3 install --upgrade pip
9 |
10 | RUN pip3 install torch==1.0.0 torchvision==0.2.0
11 | RUN pip3 install ipykernel jupyter
12 | RUN pip3 install opencv-contrib-python matplotlib
13 | RUN pip3 install cupy scipy
14 |
15 | EXPOSE 8888
--------------------------------------------------------------------------------
/src/configs/DIP_Vid_Flow.py:
--------------------------------------------------------------------------------
1 | from configs.base import cfg
2 | cfg = cfg.copy()
3 |
4 | # Dataset
5 | cfg['batch_size'] = 5
6 | cfg['batch_stride'] = [1, 3, 5]
7 | cfg['batch_mode'] = 'random'
8 | cfg['resize'] = (192, 384)
9 |
10 | # Model
11 | cfg['net_type_G'] = '2d'
12 |
13 | # Loss
14 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0.1, 'consistency': 1, 'perceptual': 0.01}
15 |
16 | # Optimize
17 | cfg['train_mode'] = 'DIP-Vid-Flow'
18 | cfg['num_iter'] = 100
19 | cfg['num_pass'] = 20
20 | cfg['fine_tune'] = True
21 | cfg['param_noise'] = False
22 |
23 | # Result
24 | cfg['save_every_batch'] = 50
25 |
--------------------------------------------------------------------------------
/src/configs/DIP_Vid_3DCN.py:
--------------------------------------------------------------------------------
1 | from configs.base import cfg
2 | cfg = cfg.copy()
3 |
4 | # Dataset
5 | cfg['batch_size'] = 5
6 | cfg['batch_stride'] = [1]
7 | cfg['batch_mode'] = 'random'
8 | cfg['resize'] = (192, 384)
9 |
10 | # Model
11 | cfg['net_type_G'] = '3d'
12 | cfg['filter_size_down'] = (3, 5, 5)
13 | cfg['filter_size_up'] = (3, 3, 3)
14 | cfg['filter_size_skip'] = (1, 1, 1)
15 |
16 | # Loss
17 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0}
18 |
19 | # Optimize
20 | cfg['train_mode'] = 'DIP-Vid-3DCN'
21 | cfg['num_iter'] = 100
22 | cfg['num_pass'] = 20
23 | cfg['fine_tune'] = True
24 | cfg['param_noise'] = False
--------------------------------------------------------------------------------
/src/models/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LossNetwork(nn.Module):
6 | """
7 |     Extract selected feature maps from a pretrained VGG model, used for computing the perceptual loss
8 | """
9 | def __init__(self, vgg_model=None, output_layer=['3', '8', '15']):
10 | super(LossNetwork, self).__init__()
11 | if vgg_model is None:
12 | # prepare fixed VGG16
13 | conv = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False)
14 | conv.weight.data.fill_(1)
15 | pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
16 | relu = torch.nn.ReLU()
17 | model = []
18 | model += ([conv, relu] * 2 + [pool]) * 2
19 | model += ([conv, relu] * 4 + [pool]) * 2
20 | model += [conv, relu] * 4
21 | self.vgg_layers = model
22 | else:
23 | self.vgg_layers = vgg_model.features
24 | self.output_layer = output_layer
25 | self.layer_name_mapping = {
26 | '3': "relu1_2",
27 | '8': "relu2_2",
28 | '15': "relu3_3",
29 | '22': "relu4_3"
30 | }
31 |
32 | def forward(self, x):
33 | feature_list = []
34 | if type(self.vgg_layers) == list:
35 | for layer, module in enumerate(self.vgg_layers):
36 | x = module(x)
37 | if str(layer) in self.output_layer:
38 | feature_list.append(x)
39 | else:
40 | for name, module in self.vgg_layers._modules.items():
41 | x = module(x)
42 | if name in self.output_layer:
43 | feature_list.append(x)
44 | return feature_list
45 |
--------------------------------------------------------------------------------
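The `LossNetwork` above only extracts feature maps; the perceptual loss itself is formed by the training code, which is not shown in this dump. A minimal sketch of the intended usage, assuming torchvision's pretrained VGG16 and an L2 distance between corresponding feature maps (all tensors below are placeholders; run from the repo root with `src/` on the path):

```
# Hedged sketch: perceptual loss from LossNetwork features.
import sys
sys.path.append('src')
import torch
import torchvision
from models.perceptual import LossNetwork

vgg16 = torchvision.models.vgg16(pretrained=True).cuda().eval()
loss_network = LossNetwork(vgg16, output_layer=['3', '8', '15'])  # relu1_2, relu2_2, relu3_3

pred = torch.rand(1, 3, 192, 384).cuda()    # generator output (placeholder)
target = torch.rand(1, 3, 192, 384).cuda()  # ground-truth frame (placeholder)

perceptual_loss = sum(
    torch.nn.functional.mse_loss(fp, ft)
    for fp, ft in zip(loss_network(pred), loss_network(target)))
```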
/src/configs/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import logging
4 |
5 | cfg = {}
6 | # Dataset
7 | cfg['video_path'] = 'data/bmx-trees.avi'
8 | cfg['mask_path'] = 'data/bmx-trees_mask.avi'
9 | cfg['batch_size'] = 5
10 | cfg['batch_stride'] = [1]
11 | cfg['batch_mode'] = 'random'
12 | cfg['traverse_step'] = 1
13 | cfg['frame_sum'] = 100
14 | cfg['resize'] = None
15 | cfg['interpolation'] = cv2.INTER_AREA
16 | cfg['input_type'] = 'noise' # 'mesh_grid'
17 | cfg['input_ratio'] = 0.1
18 | cfg['dilation_iter'] = 0
19 | cfg['input_noise_std'] = 0
20 |
21 | # Model
22 | cfg['net_type_G'] = '2d' # 3d
23 | cfg['net_type_L'] = 'VGG16'
24 | cfg['net_depth'] = 6
25 | cfg['input_channel'] = 1
26 | cfg['output_channel_img'] = 3
27 | cfg['num_channels_down'] = [16, 32, 64, 128, 128, 128]
28 | cfg['num_channels_up'] = [16, 32, 64, 128, 128, 128]
29 | cfg['num_channels_skip'] = [4, 4, 4, 4, 4, 4]
30 | cfg['filter_size_down'] = 5
31 | cfg['filter_size_up'] = 3
32 | cfg['filter_size_skip'] = 1
33 | cfg['use_skip'] = True
34 | cfg['dtype'] = torch.cuda.FloatTensor
35 |
36 | # Loss
37 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0}
38 | cfg['loss_recon'] = 'L2'
39 | cfg['perceptual_layers'] = ['3', '8', '15']
40 |
41 | # Optimize
42 | cfg['train_mode'] = 'DIP-Vid-Flow'
43 | cfg['baseline'] = True
44 | cfg['LR'] = 1e-2
45 | cfg['optimizer_G'] = 'Adam'
46 | cfg['fine_tune'] = True
47 | cfg['param_noise'] = True
48 | cfg['num_iter'] = 100
49 | cfg['num_pass'] = 20
50 |
51 | # Result
52 | cfg['save_every_iter'] = 100
53 | cfg['save_every_pass'] = 1
54 | cfg['plot'] = False
55 | cfg['save'] = True
56 | cfg['save_batch'] = False
57 | cfg['res_dir'] = None
58 |
59 | # Log
60 | cfg['logging_level'] = logging.INFO # logging.DEBUG
--------------------------------------------------------------------------------
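Each mode-specific config (`DIP.py`, `DIP_Vid.py`, `DIP_Vid_Flow.py`, `DIP_Vid_3DCN.py`) copies this base dictionary and overrides a few keys, so a custom experiment can be defined the same way. A minimal sketch (paths and values are illustrative only):

```
# Hedged sketch: deriving a custom config the same way DIP.py / DIP_Vid.py do.
import sys
sys.path.append('src')
from configs.base import cfg

cfg = cfg.copy()                              # note: dict.copy() is shallow, so replace
cfg['video_path'] = 'data/bmx-trees.avi'      # nested dicts such as 'loss_weight' wholesale
cfg['mask_path'] = 'data/bmx-trees_mask.avi'
cfg['resize'] = (192, 384)                    # (height, width)
cfg['batch_size'] = 5
cfg['train_mode'] = 'DIP-Vid'
cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0}
```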
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | sys.path.append('src')
4 |
5 | from inpainting_test import InpaintingTest
6 | from configs.DIP import cfg as DIP_cfg
7 | from configs.DIP_Vid import cfg as DIP_Vid_cfg
8 | from configs.DIP_Vid_3DCN import cfg as DIP_Vid_3DCN_cfg
9 | from configs.DIP_Vid_Flow import cfg as DIP_Vid_Flow_cfg
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 |
15 |     parser.add_argument('--train_mode', type=str, default='DIP-Vid-Flow', help='mode of the experiment: (DIP|DIP-Vid|DIP-Vid-3DCN|DIP-Vid-Flow)', metavar='')
16 | parser.add_argument('--resize', nargs='+', type=int, default=None, help='height and width of the output', metavar='')
17 | parser.add_argument('--video_path', type=str, default='data/bmx-trees.avi', help='path of the input video', metavar='')
18 | parser.add_argument('--mask_path', type=str, default='data/bmx-trees_mask.avi', help='path of the input mask', metavar='')
19 | parser.add_argument('--res_dir', type=str, default='result', help='path to save the result', metavar='')
20 | parser.add_argument('--frame_sum', type=int, default=100, help='number of frames to load', metavar='')
21 | parser.add_argument('--dilation_iter', type=int, default=0, help='number of steps to dilate the mask', metavar='')
22 |
23 | if len(sys.argv) == 1:
24 | parser.print_help()
25 | sys.exit(1)
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def main(args):
31 | if args.train_mode == 'DIP':
32 | cfg = DIP_cfg
33 | elif args.train_mode == 'DIP-Vid':
34 | cfg = DIP_Vid_cfg
35 | elif args.train_mode == 'DIP-Vid-3DCN':
36 | cfg = DIP_Vid_3DCN_cfg
37 | elif args.train_mode == 'DIP-Vid-Flow':
38 | cfg = DIP_Vid_Flow_cfg
39 | else:
40 | raise Exception("Train mode {} not implemented!".format(args.train_mode))
41 |
42 |     if args.resize is not None: cfg['resize'] = tuple(args.resize)
43 |     if cfg['resize'] is not None and len(cfg['resize']) != 2: raise Exception("Resize must be a tuple of length 2!")
44 | cfg['video_path'] = args.video_path
45 | cfg['mask_path'] = args.mask_path
46 | cfg['res_dir'] = args.res_dir
47 | cfg['frame_sum'] = args.frame_sum
48 | cfg['dilation_iter'] = args.dilation_iter
49 |
50 | test = InpaintingTest(cfg)
51 | test.create_data_loader()
52 | test.visualize_single_batch()
53 | test.create_model()
54 | test.create_optimizer()
55 | test.create_loss_function()
56 | test.train()
57 |
58 |
59 | if __name__ == '__main__':
60 | args = parse_args()
61 | main(args)
--------------------------------------------------------------------------------
/src/flow_estimator.py:
--------------------------------------------------------------------------------
1 | ############################################################
2 | # Code modified from https://github.com/sniklaus/pytorch-pwc
3 | ############################################################
4 |
5 | import math
6 | import torch
7 |
8 | from models.pwc_net import PWC_Net
9 |
10 |
11 | class FlowEstimator(object):
12 |
13 | def __init__(self):
14 | self.model_pwc = PWC_Net().type(torch.cuda.FloatTensor)
15 | self.model_pwc.load_state_dict(torch.load('pretrained_models/pwc_net.tar'))
16 |
17 |
18 | def estimate_flow_pair(self, tensorInputFirst, tensorInputSecond):
19 | ### tensor format
20 | # C x H x W
21 | # BGR
22 | # 0-1
23 | # FloatTensor.cuda
24 | ###
25 |
26 | moduleNetwork = self.model_pwc
27 | tensorOutput = torch.FloatTensor().cuda()
28 |
29 | assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
30 | assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))
31 |
32 | intWidth = tensorInputFirst.size(2)
33 | intHeight = tensorInputFirst.size(1)
34 |
35 | # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
36 | # assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
37 |
38 | if True:
39 | tensorPreprocessedFirst = tensorInputFirst.view(1, 3, intHeight, intWidth)
40 | tensorPreprocessedSecond = tensorInputSecond.view(1, 3, intHeight, intWidth)
41 |
42 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
43 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
44 |
45 | tensorPreprocessedFirst = torch.nn.functional.upsample(input=tensorPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
46 | tensorPreprocessedSecond = torch.nn.functional.upsample(input=tensorPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
47 |
48 | tensorFlow = 20.0 * torch.nn.functional.upsample(input=moduleNetwork(tensorPreprocessedFirst, tensorPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
49 |
50 | tensorFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
51 | tensorFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
52 |
53 | tensorOutput.resize_(2, intHeight, intWidth).copy_(tensorFlow[0, :, :, :])
54 | # end
55 |
56 | return tensorOutput # C x H x W
57 |
58 |
59 | def estimate_flow_batch(self, out_tensor):
60 | N, C, H, W = out_tensor.size()
61 | flow_tensor = torch.FloatTensor(N-1, 2, H, W)
62 | for i in range(N-1):
63 | first = out_tensor[i]
64 | second = out_tensor[i+1]
65 |             flow_tensor[i] = self.estimate_flow_pair(first, second)
66 | return flow_tensor # N-1 x 2 x H x W
67 |
--------------------------------------------------------------------------------
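`FlowEstimator` expects each frame as a `C x H x W` CUDA tensor in BGR order with values in `[0, 1]` and returns a `2 x H x W` flow field. A minimal usage sketch, assuming the PWC-Net weights already sit in `pretrained_models/pwc_net.tar` and that CUDA plus cupy are available (the frames below are random placeholders):

```
# Hedged sketch of FlowEstimator usage; frames are random placeholders.
import sys
sys.path.append('src')
import torch
from flow_estimator import FlowEstimator

estimator = FlowEstimator()               # loads pretrained_models/pwc_net.tar

frame0 = torch.rand(3, 192, 384).cuda()   # BGR, [0, 1], C x H x W
frame1 = torch.rand(3, 192, 384).cuda()

with torch.no_grad():
    flow = estimator.estimate_flow_pair(frame0, frame1)
print(flow.shape)                         # torch.Size([2, 192, 384])
```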
/src/models/common.py:
--------------------------------------------------------------------------------
1 | ######################################################################
2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior
3 | ######################################################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 | import numpy as np
9 |
10 |
11 | def add_module(self, module):
12 | self.add_module(str(len(self) + 1), module)
13 | torch.nn.Module.add = add_module
14 |
15 |
16 | class Concat(nn.Module):
17 | """
18 | Concatenate the output of multiple nn modules
19 | """
20 | def __init__(self, dim, *args):
21 | super(Concat, self).__init__()
22 | self.dim = dim
23 |
24 | for idx, module in enumerate(args):
25 | self.add_module(str(idx), module)
26 |
27 | def forward(self, input):
28 | inputs = []
29 | for module in self._modules.values():
30 | inputs.append(module(input))
31 |
32 | inputs_shapes2 = [x.shape[2] for x in inputs]
33 | inputs_shapes3 = [x.shape[3] for x in inputs]
34 |
35 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
36 | inputs_ = inputs
37 | else:
38 | target_shape2 = min(inputs_shapes2)
39 | target_shape3 = min(inputs_shapes3)
40 |
41 | inputs_ = []
42 | for inp in inputs:
43 | diff2 = (inp.size(2) - target_shape2) // 2
44 | diff3 = (inp.size(3) - target_shape3) // 2
45 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])
46 |
47 | return torch.cat(inputs_, dim=self.dim)
48 |
49 | def __len__(self):
50 | return len(self._modules)
51 |
52 |
53 | def act(act_fun = 'LeakyReLU'):
54 | """
55 | Return an activation function or module (e.g. nn.ReLU)
56 | """
57 | if isinstance(act_fun, str):
58 | if act_fun == 'LeakyReLU':
59 | return nn.LeakyReLU(0.2, inplace=True)
60 | elif act_fun == 'Swish':
61 | return Swish()
62 | elif act_fun == 'ReLU':
63 | return nn.ReLU()
64 | elif act_fun == 'none':
65 | return nn.Sequential()
66 | else:
67 | assert False
68 | else:
69 | return act_fun()
70 |
71 |
72 | def init_net(net, init_type='normal', gain=0.02):
73 | """
74 | Initialize the network parameters
75 | """
76 | def init_func(m):
77 | classname = m.__class__.__name__
78 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
79 | if init_type == 'normal':
80 | init.normal_(m.weight.data, 0.0, gain)
81 | elif init_type == 'xavier':
82 | init.xavier_normal_(m.weight.data, gain=gain)
83 | elif init_type == 'kaiming':
84 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
85 | elif init_type == 'orthogonal':
86 | init.orthogonal_(m.weight.data, gain=gain)
87 | else:
88 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
89 | if hasattr(m, 'bias') and m.bias is not None:
90 | init.constant_(m.bias.data, 0.0)
91 | elif classname.find('BatchNorm2d') != -1:
92 | init.normal_(m.weight.data, 1.0, gain)
93 | init.constant_(m.bias.data, 0.0)
94 |
95 | print('initialize network with %s' % init_type)
96 | net.apply(init_func)
97 |
--------------------------------------------------------------------------------
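These helpers are shared by both encoder-decoder definitions. A minimal sketch of how they compose in isolation (module choices and shapes below are illustrative only):

```
# Hedged sketch: Concat, act and init_net in isolation.
import sys
sys.path.append('src')
import torch
import torch.nn as nn
from models.common import Concat, act, init_net

branch_a = nn.Conv2d(3, 8, 3, padding=1)
branch_b = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), act('LeakyReLU'))

merge = Concat(1, branch_a, branch_b)   # concatenate branch outputs along channels
init_net(merge, init_type='kaiming')    # re-initialize all conv weights

y = merge(torch.rand(1, 3, 64, 64))
print(y.shape)                          # torch.Size([1, 16, 64, 64])
```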
/README.md:
--------------------------------------------------------------------------------
1 | # An Internal Learning Approach to Video Inpainting
2 | Haotian Zhang, Long Mai, Ning Xu, Zhaowen Wang, John Collomosse, Hailin Jin
3 |
4 | International Conference on Computer Vision (ICCV) 2019.
5 |
6 | [Paper](https://arxiv.org/abs/1909.07957) |
7 | [Project](https://cs.stanford.edu/~haotianz/research/video_inpainting/) |
8 | [Video](https://youtu.be/-MQZayP5tc0)
9 |
10 |
11 |
12 |
13 |
14 | ## Install
15 | The code has been tested on PyTorch 1.0.0 with Python 3.5 and CUDA 9.0. Please refer to [requirements.txt](https://github.com/Haotianz94/IL_video_inpainting/blob/master/requirements.txt) for details. Alternatively, you can build a Docker image using the provided [Dockerfile](https://github.com/Haotianz94/IL_video_inpainting/blob/master/Dockerfile).
16 |
17 | Warning! We have noticed that the optimization may not converge on some GPUs when using pytorch==0.4.0; we have observed this issue on the Titan V and Tesla V100. Therefore, we highly recommend upgrading your PyTorch version to 1.0.0 or above if you are training on these GPUs.
18 |
19 |
20 |
21 | ## Usage
22 | We provide two ways to test our video inpainting approach. Please first download the demo data from [here](https://drive.google.com/file/d/1MJDCjj1aIUbW0OK9UnewhXlkKX9zllQd/view?usp=sharing) into `data/`, and download the pretrained PWC-Net weights from [here](https://drive.google.com/file/d/1XPaqITtUV11WpOpX1PeCkS4zdjI5tKb8/view?usp=sharing) into `pretrained_models/` (no need to unzip the weights file). (The model was originally trained by Simon Niklaus; see [here](https://github.com/sniklaus/pytorch-pwc).)
23 |
24 | * To run our demo, please run the following command:
25 | ```
26 | python3 train.py --train_mode DIP-Vid-Flow --video_path data/bmx-trees.avi --mask_path data/bmx-trees_mask.avi --resize 192 384 --res_dir result/DIP_Vid_Flow
27 | ```
28 |
29 | * Alternatively, you can run through our demo step by step using the provided Jupyter notebook [demo.ipynb](https://github.com/Haotianz94/IL_video_inpainting/blob/master/demo.ipynb).
30 |
31 |
32 | ## Citation
33 | ```
34 | @inproceedings{zhang2019internal,
35 | title={An Internal Learning Approach to Video Inpainting},
36 | author={Zhang, Haotian and Mai, Long and Xu, Ning and Wang, Zhaowen and Collomosse, John and Jin, Hailin},
37 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
38 | pages={2720--2729},
39 | year={2019}
40 | }
41 | ```
42 |
43 |
44 | ## License
45 | © 2019 Adobe.
46 |
47 | Adobe holds the copyright for all the files found in this repository.
48 |
49 | IL_video_inpainting is a project by Adobe Research. It is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) and may only be used for non-commercial purposes. See the LICENSE file for more information.
50 |
51 | Note: We are not able to provide access to our Composed Dataset because we could not get copyright permission from the authors of the videos.
52 |
53 | ## Acknowledgement
54 | The implementation of our network architecture is mostly borrowed from the Deep Image Prior [repo](https://github.com/DmitryUlyanov/deep-image-prior). The implementation of the PWC-Net is borrowed from this [repo](https://github.com/sniklaus/pytorch-pwc). Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors.
55 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "scrolled": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import sys\n",
12 | "sys.path.append('src')\n",
13 | "from inpainting_test import InpaintingTest\n",
14 | "from configs.DIP import cfg as DIP_cfg\n",
15 | "from configs.DIP_Vid import cfg as DIP_Vid_cfg\n",
16 | "from configs.DIP_Vid_3DCN import cfg as DIP_Vid_3DCN_cfg\n",
17 | "from configs.DIP_Vid_Flow import cfg as DIP_Vid_Flow_cfg"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# Choose training mode (DIP|DIP-Vid|DIP-Vid-3DCN|DIP-Vid-Flow)\n",
27 | "train_mode = 'DIP-Vid-Flow'"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "# Set up config\n",
37 | "if train_mode == 'DIP':\n",
38 | " cfg = DIP_cfg\n",
39 | "elif train_mode == 'DIP-Vid':\n",
40 | " cfg = DIP_Vid_cfg\n",
41 | "elif train_mode == 'DIP-Vid-3DCN':\n",
42 | " cfg = DIP_Vid_3DCN_cfg\n",
43 | "elif train_mode == 'DIP-Vid-Flow':\n",
44 | " cfg = DIP_Vid_Flow_cfg\n",
45 | " \n",
46 | "cfg['video_path'] = 'data/bmx-trees.avi'\n",
47 | "cfg['mask_path'] = 'data/bmx-trees_mask.avi'\n",
48 | "cfg['save_every_iter'] = 100\n",
49 | "cfg['resize'] = (192, 384)\n",
50 | "cfg['plot'] = True\n",
51 | "cfg['res_dir'] = 'result/test/DIP_Vid_Flow'"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {
58 | "scrolled": true
59 | },
60 | "outputs": [],
61 | "source": [
62 | "test = InpaintingTest(cfg)"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {
69 | "scrolled": true
70 | },
71 | "outputs": [],
72 | "source": [
73 | "test.create_data_loader()"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "scrolled": true
81 | },
82 | "outputs": [],
83 | "source": [
84 | "test.visualize_single_batch()"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "test.create_model()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "test.create_optimizer()"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "test.create_loss_function()"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {
118 | "scrolled": true
119 | },
120 | "outputs": [],
121 | "source": [
122 | "test.train()"
123 | ]
124 | }
125 | ],
126 | "metadata": {
127 | "kernelspec": {
128 | "display_name": "Python 3",
129 | "language": "python",
130 | "name": "python3"
131 | },
132 | "language_info": {
133 | "codemirror_mode": {
134 | "name": "ipython",
135 | "version": 3
136 | },
137 | "file_extension": ".py",
138 | "mimetype": "text/x-python",
139 | "name": "python",
140 | "nbconvert_exporter": "python",
141 | "pygments_lexer": "ipython3",
142 | "version": "3.5.2"
143 | }
144 | },
145 | "nbformat": 4,
146 | "nbformat_minor": 2
147 | }
148 |
--------------------------------------------------------------------------------
/src/models/encoder_decoder_2d.py:
--------------------------------------------------------------------------------
1 | ######################################################################
2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior
3 | ######################################################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | from .common import *
8 |
9 |
10 | def EncoderDecoder2D(
11 | num_input_channels=2, num_output_channels=3,
12 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4],
13 | filter_size_down=3, filter_size_up=3, filter_size_skip=1,
14 | upsample_mode='nearest', downsample_mode='stride',
15 | need_sigmoid=True, need_bias=True, need1x1_up=True,
16 | pad='zero', act_fun='LeakyReLU'):
17 | """Assembles encoder-decoder with skip connections, using 2D convolutions.
18 |
19 | Arguments:
20 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
21 | pad (string): zero|reflection (default: 'zero')
22 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest')
23 | downsample_mode (string): 'stride|avg|max' (default: 'stride')
24 | """
25 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)
26 |
27 | n_scales = len(num_channels_down)
28 |
29 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) :
30 | upsample_mode = [upsample_mode]*n_scales
31 |
32 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)):
33 | downsample_mode = [downsample_mode]*n_scales
34 |
35 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) :
36 | filter_size_down = [filter_size_down]*n_scales
37 |
38 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) :
39 | filter_size_up = [filter_size_up]*n_scales
40 |
41 | last_scale = n_scales - 1
42 |
43 | cur_depth = None
44 |
45 | model = nn.Sequential()
46 | model_tmp = model
47 |
48 | input_depth = num_input_channels
49 | for i in range(len(num_channels_down)):
50 |
51 | deeper = nn.Sequential()
52 | skip = nn.Sequential()
53 |
54 | if num_channels_skip[i] != 0:
55 | model_tmp.add(Concat(1, skip, deeper))
56 | else:
57 | model_tmp.add(deeper)
58 |
59 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])))
60 |
61 | if num_channels_skip[i] != 0:
62 | skip.add(conv(input_depth, num_channels_skip[i], filter_size_skip, bias=need_bias, pad=pad))
63 | skip.add(bn(num_channels_skip[i]))
64 | skip.add(act(act_fun))
65 |
66 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], stride=2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]))
67 | deeper.add(bn(num_channels_down[i]))
68 | deeper.add(act(act_fun))
69 |
70 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
71 | deeper.add(bn(num_channels_down[i]))
72 | deeper.add(act(act_fun))
73 |
74 | deeper_main = nn.Sequential()
75 |
76 | if i == len(num_channels_down) - 1:
77 | # The deepest
78 | k = num_channels_down[i]
79 | else:
80 | deeper.add(deeper_main)
81 | k = num_channels_up[i + 1]
82 |
83 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))
84 |
85 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
86 | model_tmp.add(bn(num_channels_up[i]))
87 | model_tmp.add(act(act_fun))
88 |
89 | if need1x1_up:
90 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
91 | model_tmp.add(bn(num_channels_up[i]))
92 | model_tmp.add(act(act_fun))
93 |
94 | input_depth = num_channels_down[i]
95 | model_tmp = deeper_main
96 |
97 | model.add(FinalLayer(num_channels_up[0], num_output_channels, need_bias, pad, need_sigmoid))
98 | return model
99 |
100 |
101 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
102 | downsampler = None
103 | if stride != 1 and downsample_mode != 'stride':
104 |
105 | if downsample_mode == 'avg':
106 | downsampler = nn.AvgPool2d(stride, stride)
107 | elif downsample_mode == 'max':
108 | downsampler = nn.MaxPool2d(stride, stride)
109 | else:
110 | assert False
111 |
112 | stride = 1
113 |
114 | padder = None
115 | to_pad = int((kernel_size - 1) / 2)
116 | if pad == 'reflection':
117 | padder = nn.ReflectionPad2d(to_pad)
118 | to_pad = 0
119 |
120 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)
121 |
122 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
123 | return nn.Sequential(*layers)
124 |
125 |
126 | def bn(num_features):
127 | return nn.BatchNorm2d(num_features)
128 |
129 |
130 | class FinalLayer(nn.Module):
131 | """
132 | Split output into image and flow branch
133 | """
134 | def __init__(self, Cin, Cout, need_bias, pad, need_sigmoid):
135 | super(FinalLayer, self).__init__()
136 | self.conv_img = conv(Cin, 3, 1, bias=need_bias, pad=pad)
137 | if need_sigmoid:
138 | self.sigmoid = nn.Sigmoid()
139 | else:
140 | self.sigmoid = None
141 | if Cout > 3:
142 | self.conv_flow = conv(Cin, Cout-3, 1, bias=need_bias, pad=pad)
143 | else:
144 | self.conv_flow = None
145 |
146 | def forward(self, x):
147 | y_img = self.conv_img(x)
148 | if not self.sigmoid is None:
149 | y_img = self.sigmoid(y_img)
150 |
151 | if not self.conv_flow is None:
152 | y_flow = self.conv_flow(x)
153 | return torch.cat((y_img, y_flow), 1)
154 | else:
155 | return y_img
--------------------------------------------------------------------------------
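With the defaults from `configs/base.py`, the 2D generator maps a 1-channel noise map to a 3-channel frame at the same resolution. A minimal sketch (the noise scaling mirrors `cfg['input_ratio']`; the actual wiring to the dataset lives in code not shown in this dump):

```
# Hedged sketch: building the 2D generator with the channel/filter settings from configs/base.py.
import sys
sys.path.append('src')
import torch
from models.encoder_decoder_2d import EncoderDecoder2D

net = EncoderDecoder2D(
    num_input_channels=1, num_output_channels=3,
    num_channels_down=[16, 32, 64, 128, 128, 128],
    num_channels_up=[16, 32, 64, 128, 128, 128],
    num_channels_skip=[4, 4, 4, 4, 4, 4],
    filter_size_down=5, filter_size_up=3, filter_size_skip=1)

noise = torch.rand(1, 1, 192, 384) * 0.1   # 1-channel noise input, scaled like cfg['input_ratio']
frame = net(noise)
print(frame.shape)                          # torch.Size([1, 3, 192, 384])
```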
/src/models/encoder_decoder_3d.py:
--------------------------------------------------------------------------------
1 | ######################################################################
2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior
3 | ######################################################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | from .common import *
8 |
9 |
10 | def EncoderDecoder3D(
11 | num_input_channels=1, num_output_channels=3,
12 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4],
13 | filter_size_down=(3, 3, 3), filter_size_up=(3, 3, 3), filter_size_skip=(1, 1, 1),
14 | upsample_mode='nearest', downsample_mode='stride',
15 | need_sigmoid=True, need_bias=True, need1x1_up=True,
16 | pad='zero', act_fun='LeakyReLU',
17 | ):
18 |
19 | """Assembles encoder-decoder with skip connections, using 3D convolutions.
20 |
21 | Arguments:
22 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
23 | pad (string): zero|reflection (default: 'zero')
24 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest')
25 | downsample_mode (string): 'stride|avg|max' (default: 'stride')
26 | """
27 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)
28 |
29 | n_scales = len(num_channels_down)
30 |
31 | if not isinstance(upsample_mode, list):
32 | upsample_mode = [upsample_mode]*n_scales
33 |
34 | if not isinstance(downsample_mode, list):
35 | downsample_mode = [downsample_mode]*n_scales
36 |
37 | if not isinstance(filter_size_down, list):
38 | filter_size_down = [filter_size_down]*n_scales
39 |
40 | if not isinstance(filter_size_up, list):
41 | filter_size_up = [filter_size_up]*n_scales
42 |
43 | last_scale = n_scales - 1
44 |
45 | cur_depth = None
46 |
47 | model = nn.Sequential()
48 | model_tmp = model
49 |
50 | input_depth = num_input_channels
51 | for i in range(len(num_channels_down)):
52 |
53 | deeper = nn.Sequential()
54 | skip = nn.Sequential()
55 |
56 | if num_channels_skip[i] != 0:
57 | model_tmp.add(Concat(1, skip, deeper))
58 | else:
59 | model_tmp.add(deeper)
60 |
61 | bn_up = BatchNorm3D(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))
62 | model_tmp.add(bn_up)
63 |
64 | if num_channels_skip[i] != 0:
65 | conv_skip = conv3d(input_depth, num_channels_skip[i], filter_size_skip, bias=need_bias, pad=pad)
66 | bn_skip = BatchNorm3D(num_channels_skip[i])
67 | skip.add(conv_skip)
68 | skip.add(bn_skip)
69 | skip.add(act(act_fun))
70 |
71 | conv_down = conv3d(input_depth, num_channels_down[i], filter_size_down[i], stride=(1, 2, 2), bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])
72 | bn_down = BatchNorm3D(num_channels_down[i])
73 | deeper.add(conv_down)
74 | deeper.add(bn_down)
75 | deeper.add(act(act_fun))
76 |
77 | conv_down = conv3d(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)
78 | bn_down = BatchNorm3D(num_channels_down[i])
79 | deeper.add(conv_down)
80 | deeper.add(bn_down)
81 | deeper.add(act(act_fun))
82 |
83 | deeper_main = nn.Sequential()
84 |
85 | if i == len(num_channels_down) - 1:
86 | # The deepest
87 | k = num_channels_down[i]
88 | else:
89 | deeper.add(deeper_main)
90 | k = num_channels_up[i + 1]
91 |
92 | deeper.add(Upsample3D(scale_factor=2, mode=upsample_mode[i]))
93 |
94 | conv_up = conv3d(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], bias=need_bias, pad=pad)
95 | bn_up = BatchNorm3D(num_channels_up[i])
96 | model_tmp.add(conv_up)
97 | model_tmp.add(bn_up)
98 | model_tmp.add(act(act_fun))
99 |
100 | if need1x1_up:
101 | conv_up = conv3d(num_channels_up[i], num_channels_up[i], kernel_size=(1,1,1), bias=need_bias, pad=pad)
102 | bn_up = BatchNorm3D(num_channels_up[i])
103 | model_tmp.add(conv_up)
104 | model_tmp.add(bn_up)
105 | model_tmp.add(act(act_fun))
106 |
107 | input_depth = num_channels_down[i]
108 | model_tmp = deeper_main
109 |
110 | conv_final = conv3d(num_channels_up[0], num_output_channels, kernel_size=(1,1,1), bias=need_bias, pad=pad)
111 | model.add(conv_final)
112 | if need_sigmoid:
113 | model.add(nn.Sigmoid())
114 |
115 | return model
116 |
117 |
118 | def conv3d(in_channels, out_channels, kernel_size=(1,1,1), stride=(1,1,1), bias=True, pad='zero', downsample_mode='stride'):
119 | downsampler = None
120 | if stride != 1 and downsample_mode != 'stride':
121 |
122 | if downsample_mode == 'avg':
123 | downsampler = nn.AvgPool2d(stride, stride)
124 | elif downsample_mode == 'max':
125 | downsampler = nn.MaxPool2d(stride, stride)
126 | elif downsample_mode in ['lanczos2', 'lanczos3']:
127 | downsampler = Downsampler(n_planes=out_channels, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
128 | else:
129 | assert False
130 |
131 | stride = 1
132 |
133 | padder = None
134 | pad_D = int((kernel_size[0] - 1) / 2)
135 | pad_HW = int((kernel_size[1] - 1) / 2)
136 | to_pad = (pad_D, pad_HW, pad_HW)
137 | if pad == 'reflection':
138 | padder = ReflectionPad3D(pad_D, pad_HW)
139 | to_pad = (0, 0, 0)
140 |
141 | convolver = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=to_pad, bias=bias)
142 |
143 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
144 | return nn.Sequential(*layers)
145 |
146 |
147 | class BatchNorm3D(nn.Module):
148 | def __init__(self, num_features):
149 | super(BatchNorm3D, self).__init__()
150 | self.bn = nn.BatchNorm2d(num_features)
151 |
152 | def forward(self, x):
153 | assert(x.size(0) == 1) # 1 x C x D x H x W
154 | y = x.squeeze(0).transpose(0, 1).contiguous() # D x C x H x W
155 | y = self.bn(y)
156 | y = y.transpose(0, 1).unsqueeze(0)
157 | return y
158 |
159 |
160 | class Upsample3D(nn.Module):
161 | def __init__(self, scale_factor, mode):
162 | super(Upsample3D, self).__init__()
163 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)
164 |
165 | def forward(self, x):
166 | assert(x.size(0) == 1) # 1 x C x D x H x W
167 | y = x.squeeze(0).transpose(0, 1) # D x C x H x W
168 | y = self.upsample(y)
169 | return y.transpose(0, 1).unsqueeze(0)
170 |
171 |
172 | class ReflectionPad3D(nn.Module):
173 | def __init__(self, pad_D, pad_HW):
174 | super(ReflectionPad3D, self).__init__()
175 | self.padder_HW = nn.ReflectionPad2d(pad_HW)
176 | self.padder_D = nn.ReplicationPad3d((0, 0, 0, 0, pad_D, pad_D))
177 |
178 | def forward(self, x):
179 | assert(x.size(0) == 1) # 1 x C x D x H x W
180 | y = x.squeeze(0).transpose(0, 1) # D x C x H x W
181 | y = self.padder_HW(y)
182 | y = y.transpose(0, 1).unsqueeze(0) # 1 x C x D x H x W
183 | return self.padder_D(y)
--------------------------------------------------------------------------------
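Unlike the 2D generator, the 3D variant consumes a whole batch of frames at once and only downsamples spatially (stride `(1, 2, 2)`), so the temporal dimension is preserved. A minimal sketch with the settings from `configs/DIP_Vid_3DCN.py` (input values are placeholders):

```
# Hedged sketch: the 3D generator keeps the temporal dimension D intact.
import sys
sys.path.append('src')
import torch
from models.encoder_decoder_3d import EncoderDecoder3D

net = EncoderDecoder3D(
    num_input_channels=1, num_output_channels=3,
    num_channels_down=[16, 32, 64, 128, 128, 128],
    num_channels_up=[16, 32, 64, 128, 128, 128],
    num_channels_skip=[4, 4, 4, 4, 4, 4],
    filter_size_down=(3, 5, 5), filter_size_up=(3, 3, 3), filter_size_skip=(1, 1, 1))

noise = torch.rand(1, 1, 5, 192, 384) * 0.1   # 1 x C x D x H x W, D = 5 frames
frames = net(noise)
print(frames.shape)                            # torch.Size([1, 3, 5, 192, 384])
```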
/src/correlation/correlation.py:
--------------------------------------------------------------------------------
1 | ############################################################
2 | # Code modified from https://github.com/sniklaus/pytorch-pwc
3 | ############################################################
4 |
5 | import cupy
6 | import torch
7 |
8 | kernel_Correlation_rearrange = '''
9 | extern "C" __global__ void kernel_Correlation_rearrange(
10 | const int n,
11 | const float* input,
12 | float* output
13 | ) {
14 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
15 |
16 | if (intIndex >= n) {
17 | return;
18 | }
19 |
20 | int intSample = blockIdx.z;
21 | int intChannel = blockIdx.y;
22 |
23 | float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
24 |
25 | __syncthreads();
26 |
27 | int intPaddedY = (intIndex / SIZE_3(input)) + 4;
28 | int intPaddedX = (intIndex % SIZE_3(input)) + 4;
29 | int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
30 |
31 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
32 | }
33 | '''
34 |
35 | kernel_Correlation_updateOutput = '''
36 | extern "C" __global__ void kernel_Correlation_updateOutput(
37 | const int n,
38 | const float* rbot0,
39 | const float* rbot1,
40 | float* top
41 | ) {
42 | extern __shared__ char patch_data_char[];
43 |
44 | float *patch_data = (float *)patch_data_char;
45 |
46 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
47 | int x1 = blockIdx.x + 4;
48 | int y1 = blockIdx.y + 4;
49 | int item = blockIdx.z;
50 | int ch_off = threadIdx.x;
51 |
52 | // Load 3D patch into shared shared memory
53 | for (int j = 0; j < 1; j++) { // HEIGHT
54 | for (int i = 0; i < 1; i++) { // WIDTH
55 | int ji_off = ((j * 1) + i) * SIZE_3(rbot0);
56 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
57 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
58 | int idxPatchData = ji_off + ch;
59 | patch_data[idxPatchData] = rbot0[idx1];
60 | }
61 | }
62 | }
63 |
64 | __syncthreads();
65 |
66 | __shared__ float sum[32];
67 |
68 | // Compute correlation
69 | for(int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
70 | sum[ch_off] = 0;
71 |
72 | int s2o = (top_channel % 9) - 4;
73 | int s2p = (top_channel / 9) - 4;
74 |
75 | for (int j = 0; j < 1; j++) { // HEIGHT
76 | for (int i = 0; i < 1; i++) { // WIDTH
77 | int ji_off = ((j * 1) + i) * SIZE_3(rbot0);
78 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
79 | int x2 = x1 + s2o;
80 | int y2 = y1 + s2p;
81 |
82 | int idxPatchData = ji_off + ch;
83 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
84 |
85 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
86 | }
87 | }
88 | }
89 |
90 | __syncthreads();
91 |
92 | if (ch_off == 0) {
93 | float total_sum = 0;
94 | for (int idx = 0; idx < 32; idx++) {
95 | total_sum += sum[idx];
96 | }
97 | const int sumelems = SIZE_3(rbot0);
98 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
99 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
100 | }
101 | }
102 | }
103 | '''
104 |
105 | def cupy_kernel(strFunction, objectVariables):
106 | strKernel = globals()[strFunction]
107 |
108 | for strVariable in objectVariables:
109 | strKernel = strKernel.replace('SIZE_0(' + strVariable + ')', str(objectVariables[strVariable].size(0)))
110 | strKernel = strKernel.replace('SIZE_1(' + strVariable + ')', str(objectVariables[strVariable].size(1)))
111 | strKernel = strKernel.replace('SIZE_2(' + strVariable + ')', str(objectVariables[strVariable].size(2)))
112 | strKernel = strKernel.replace('SIZE_3(' + strVariable + ')', str(objectVariables[strVariable].size(3)))
113 | # end
114 |
115 | return strKernel
116 | # end
117 |
118 | @cupy.util.memoize(for_each_device=True)
119 | def cupy_launch(strFunction, strKernel):
120 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
121 | # end
122 |
123 | class FunctionCorrelation(torch.autograd.Function):
124 | @staticmethod
125 | def forward(ctx, first, second):
126 | ctx.save_for_backward(first, second)
127 |
128 | assert(first.is_contiguous() == True)
129 | assert(second.is_contiguous() == True)
130 |
131 | rbot0 = first.new(first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)).zero_()
132 | rbot1 = first.new(first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)).zero_()
133 |
134 | output = first.new(first.size(0), 81, first.size(2), first.size(3)).zero_()
135 |
136 | if first.is_cuda == True:
137 | class Stream:
138 | ptr = torch.cuda.current_stream().cuda_stream
139 | # end
140 |
141 | n = first.size(2) * first.size(3)
142 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
143 | 'input': first,
144 | 'output': rbot0
145 | }))(
146 | grid=tuple([ int((n + 16 - 1) / 16), first.size(1), first.size(0) ]),
147 | block=tuple([ 16, 1, 1 ]),
148 | args=[ n, first.data_ptr(), rbot0.data_ptr() ],
149 | stream=Stream
150 | )
151 |
152 | n = second.size(2) * second.size(3)
153 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
154 | 'input': second,
155 | 'output': rbot1
156 | }))(
157 | grid=tuple([ int((n + 16 - 1) / 16), second.size(1), second.size(0) ]),
158 | block=tuple([ 16, 1, 1 ]),
159 | args=[ n, second.data_ptr(), rbot1.data_ptr() ],
160 | stream=Stream
161 | )
162 |
163 | n = output.size(1) * output.size(2) * output.size(3)
164 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
165 | 'rbot0': rbot0,
166 | 'rbot1': rbot1,
167 | 'top': output
168 | }))(
169 | grid=tuple([ first.size(3), first.size(2), first.size(0) ]),
170 | block=tuple([ 32, 1, 1 ]),
171 | shared_mem=first.size(1) * 4,
172 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ],
173 | stream=Stream
174 | )
175 |
176 | elif first.is_cuda == False:
177 | raise NotImplementedError()
178 |
179 | # end
180 |
181 | return output
182 | # end
183 |
184 | @staticmethod
185 | def backward(ctx, gradOutput):
186 | first, second = ctx.saved_tensors
187 |
188 | assert(gradOutput.is_contiguous() == True)
189 |
190 | gradFirst = first.new(first.size()).zero_() if ctx.needs_input_grad[0] == True else None
191 | gradSecond = first.new(first.size()).zero_() if ctx.needs_input_grad[1] == True else None
192 |
193 | if first.is_cuda == True:
194 | raise NotImplementedError()
195 |
196 | elif first.is_cuda == False:
197 | raise NotImplementedError()
198 |
199 | # end
200 |
201 | return gradFirst, gradSecond
202 | # end
203 | # end
204 |
205 | class ModuleCorrelation(torch.nn.Module):
206 | def __init__(self):
207 | super(ModuleCorrelation, self).__init__()
208 | # end
209 |
210 | def forward(self, tensorFirst, tensorSecond):
211 | correlation = FunctionCorrelation.apply
212 | return correlation(tensorFirst, tensorSecond)
213 | # end
214 | # end
--------------------------------------------------------------------------------
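The cost-volume layer compares each location in the first feature map against a 9 x 9 neighbourhood in the second, producing 81 correlation channels. A minimal sketch of the forward pass, assuming a CUDA device and cupy are available (the backward pass above raises NotImplementedError, so the call is wrapped in `no_grad`; feature maps are placeholders):

```
# Hedged sketch: forward pass of the correlation (cost-volume) layer.
import sys
sys.path.append('src')
import torch
from correlation.correlation import ModuleCorrelation

corr = ModuleCorrelation()
feat0 = torch.rand(1, 32, 48, 96).cuda().contiguous()   # placeholder feature maps
feat1 = torch.rand(1, 32, 48, 96).cuda().contiguous()

with torch.no_grad():
    volume = corr(feat0, feat1)
print(volume.shape)                                      # torch.Size([1, 81, 48, 96])
```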
/src/utils.py:
--------------------------------------------------------------------------------
1 | ######################################################################
2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior
3 | ######################################################################
4 |
5 | import torch
6 | import torchvision
7 | import numpy as np
8 | import os
9 | import cv2
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | ######################################################################
14 | # Network input
15 | ######################################################################
16 |
17 | def fill_noise(x, noise_type):
18 | """
19 | Fill tensor `x` with noise of type `noise_type`.
20 | """
21 | if noise_type == 'u':
22 | x.uniform_()
23 | elif noise_type == 'n':
24 | x.normal_()
25 | else:
26 | assert False
27 |
28 |
29 | def get_noise(batch_size, input_depth, method, spatial_size, noise_type='u', var=1./10):
30 | """
31 |     Return a torch.Tensor of size (`batch_size` x `input_depth` x `spatial_size[0]` x `spatial_size[1]`)
32 | initialized in a specific way.
33 |
34 | Args:
35 | input_depth: number of channels in the tensor
36 |         method: `noise` for filling the tensor with noise; `meshgrid` for np.meshgrid
37 | spatial_size: spatial size of the tensor to initialize
38 | noise_type: 'u' for uniform; 'n' for normal
39 |         var: a factor the noise is multiplied by; effectively a standard-deviation scaler.
40 | """
41 | if isinstance(spatial_size, int):
42 | spatial_size = (spatial_size, spatial_size)
43 | if method == 'noise':
44 | shape = [batch_size, input_depth, spatial_size[0], spatial_size[1]]
45 | net_input = torch.zeros(shape)
46 |
47 | fill_noise(net_input, noise_type)
48 | net_input *= var
49 | elif method == 'meshgrid':
50 | assert batch_size == 1
51 | assert input_depth == 2
52 | X, Y = np.meshgrid(np.arange(0, spatial_size[1]) / float(spatial_size[1] - 1), np.arange(0, spatial_size[0]) / float(spatial_size[0] - 1))
53 | meshgrid = np.concatenate([X[None, :], Y[None, :]])
54 | net_input= np_to_torch(meshgrid)
55 | else:
56 | assert False
57 |
58 | return net_input
59 |
60 |
61 | ######################################################################
62 | # Flow related
63 | ######################################################################
64 |
65 | def warp_torch(x, flo):
66 | """
67 | Backward warp an image tensor (im2) to im1, according to the optical flow from im1 to im2
68 |
69 | x: [B, C, H, W] (im2)
70 | flo: [B, 2, H, W] flow
71 | """
72 | B, C, H, W = x.size()
73 | # Mesh grid
74 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
75 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
76 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
77 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
78 | grid = torch.cat((xx, yy), 1).float()
79 |
80 | if x.is_cuda:
81 | grid = grid.cuda()
82 | vgrid_ori = grid + flo
83 | vgrid = torch.zeros_like(vgrid_ori)
84 |
85 | # Scale grid to [-1,1]
86 | vgrid[:, 0, :, :] = 2.0 * vgrid_ori[:, 0, :, :] / max(W - 1, 1) - 1.0
87 | vgrid[:, 1, :, :] = 2.0 * vgrid_ori[:, 1, :, :] / max(H - 1, 1) - 1.0
88 |
89 | vgrid = vgrid.permute(0, 2, 3, 1)
90 | output = torch.nn.functional.grid_sample(x, vgrid)
91 | mask = torch.ones(x.size())
92 | if x.is_cuda:
93 | mask = mask.cuda()
94 | mask = torch.nn.functional.grid_sample(mask, vgrid)
95 |
96 | mask[mask < 0.999] = 0
97 | mask[mask > 0] = 1
98 |
99 | return output, mask
100 |
101 |
102 | def warp_np(x, flo):
103 | """
104 | Backward warp an image numpy array (im2) to im1, according to the optical flow from im1 to im2
105 |
106 | x: [B, C, H, W] (im2)
107 | flo: [B, 2, H, W] flow
108 | """
109 | if x.ndim != 4:
110 | assert(x.ndim == 3)
111 |         # Add one dimension for a single image
112 | x = x[None, ...]
113 | flo = flo[None, ...]
114 | add_dim = True
115 | else:
116 | add_dim = False
117 |
118 | output, mask = warp_torch(np_to_torch(x), np_to_torch(flo))
119 | if add_dim:
120 | return output.numpy()[0], mask.numpy()[0]
121 | else:
122 | return output.numpy(), mask.numpy()
123 |
124 |
125 | def check_flow_occlusion(flow_f, flow_b):
126 | """
127 | Compute occlusion map through forward/backward flow consistency check
128 | """
129 | def get_occlusion(flow1, flow2):
130 | grid_flow = grid + flow1
131 | grid_flow[0, :, :] = 2.0 * grid_flow[0, :, :] / max(W - 1, 1) - 1.0
132 | grid_flow[1, :, :] = 2.0 * grid_flow[1, :, :] / max(H - 1, 1) - 1.0
133 | grid_flow = grid_flow.permute(1, 2, 0)
134 | flow2_inter = torch.nn.functional.grid_sample(flow2[None, ...], grid_flow[None, ...])[0]
135 | score = torch.exp(- torch.sum((flow1 + flow2_inter) ** 2, dim=0) / 2.)
136 | occlusion = (score > 0.5)
137 | return occlusion[None, ...].float()
138 |
139 | C, H, W = flow_f.size()
140 | # Mesh grid
141 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
142 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
143 | xx = xx.view(1, H, W)
144 | yy = yy.view(1, H, W)
145 | grid = torch.cat((xx, yy), 0).float()
146 |
147 | occlusion_f = get_occlusion(flow_f, flow_b)
148 | occlusion_b = get_occlusion(flow_b, flow_f)
149 | flow_f = torch.cat((flow_f, occlusion_f), 0)
150 | flow_b = torch.cat((flow_b, occlusion_b), 0)
151 |
152 | return flow_f, flow_b
153 |
154 |
155 | ######################################################################
156 | # Visualization
157 | ######################################################################
158 |
159 | def get_image_grid(images_np, nrow=8, padding=2):
160 | """
161 | Create a grid from a list of images by concatenating them.
162 | """
163 | images_torch = [np_to_torch(x) for x in images_np]
164 | torch_grid = torchvision.utils.make_grid(images_torch, nrow, padding)
165 |
166 | return torch_grid.numpy()
167 |
168 |
169 | def plot_image_grid(images_np, nrow=8, padding=2, factor=1, interpolation='lanczos'):
170 | """
171 | Layout images in a grid
172 |
173 | Args:
174 |         images_np: list of images, each image is a np.array of size 3xHxW or 1xHxW
175 | nrow: how many images will be in one row
176 |         factor: size of the plt.figure
177 | interpolation: interpolation used in plt.imshow
178 | """
179 | n_channels = max(x.shape[0] for x in images_np)
180 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"
181 |
182 | images_np = [np_cvt_color(x) if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
183 |
184 | grid = get_image_grid(images_np, nrow, padding)
185 |
186 | plt.figure(figsize=(len(images_np) + factor, 12 + factor))
187 | plt.axis('off')
188 |
189 | if images_np[0].shape[0] == 1:
190 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
191 | else:
192 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
193 |
194 | plt.show()
195 |
196 | return grid
197 |
198 |
199 | ######################################################################
200 | # Data type transform
201 | ######################################################################
202 |
203 | def np_cvt_color(img_np):
204 | """
205 |     Convert image from BGR/RGB to RGB/BGR
206 | From B x C x W x H to B x C x W x H
207 | """
208 |     if img_np.ndim == 4:
209 | return [img[::-1] for img in img_np]
210 | else:
211 | return img_np[::-1]
212 |
213 |
214 | def np_to_cv2(img_np):
215 | """
216 | Convert image in numpy.array to cv2 image.
217 | From C x W x H [0..1] to W x H x C [0...255]
218 | """
219 | return np.clip(img_np.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
220 |
221 |
222 | def np_to_torch(img_np):
223 | """
224 | Convert image in numpy.array to torch.Tensor.
225 | From B x C x W x H [0..1] to B x C x W x H [0..1]
226 | """
227 | return torch.from_numpy(np.ascontiguousarray(img_np))
228 |
229 |
230 | def torch_to_np(img_var):
231 | """
232 | Convert an image in torch.Tensor format to numpy.array.
233 | From B x C x W x H [0..1] to B x C x W x H [0..1]
234 | """
235 | return img_var.detach().cpu().numpy()
236 |
237 |
238 | ######################################################################
239 | # Others
240 | ######################################################################
241 |
242 | def mkdir(dir):
243 | os.makedirs(dir, exist_ok=True)
244 | os.chmod(dir, 0o777)
245 |
246 |
247 | def build_dir(res_dir, subpath):
248 | mkdir(os.path.join(res_dir, subpath))
249 | res_type_list = ['stitch', 'full_with_boundary', 'full']
250 | for res_type in res_type_list:
251 | sub_res_dir = os.path.join(res_dir, subpath, res_type)
252 | mkdir(sub_res_dir)
253 | mkdir(os.path.join(sub_res_dir, 'sequence'))
254 | mkdir(os.path.join(sub_res_dir, 'batch'))
255 |
256 |
257 | def get_model_num_parameters(model):
258 | """
259 | Return total number of parameters in model
260 | """
261 | total_num=0
262 | if type(model) == type(dict()):
263 | for key in model:
264 | for p in model[key].parameters():
265 | total_num+=p.nelement()
266 | else:
267 | for p in model.parameters():
268 | total_num+=p.nelement()
269 | return total_num
--------------------------------------------------------------------------------
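A minimal sketch of the input-noise and flow helpers above (all tensors are placeholders; shapes follow the docstrings):

```
# Hedged sketch: get_noise, warp_torch and check_flow_occlusion in isolation.
import sys
sys.path.append('src')
import torch
from utils import get_noise, warp_torch, check_flow_occlusion

net_input = get_noise(1, 1, 'noise', (192, 384), noise_type='u', var=0.1)
print(net_input.shape)                 # torch.Size([1, 1, 192, 384])

frame = torch.rand(1, 3, 192, 384)     # backward-warp a frame with an (all-zero) flow;
flow = torch.zeros(1, 2, 192, 384)     # the mask marks pixels sampled inside the image
warped, mask = warp_torch(frame, flow)

flow_f = torch.zeros(2, 192, 384)      # consistency check appends an occlusion channel
flow_b = torch.zeros(2, 192, 384)
flow_f, flow_b = check_flow_occlusion(flow_f, flow_b)
print(flow_f.shape)                    # torch.Size([3, 192, 384])
```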
/src/models/pwc_net.py:
--------------------------------------------------------------------------------
1 | ############################################################
2 | # Code modified from https://github.com/sniklaus/pytorch-pwc
3 | ############################################################
4 |
5 | import torch
6 | from correlation import correlation
7 |
8 |
9 | class PWC_Net(torch.nn.Module):
10 | def __init__(self):
11 | super(PWC_Net, self).__init__()
12 |
13 | class Extractor(torch.nn.Module):
14 | def __init__(self):
15 | super(Extractor, self).__init__()
16 |
17 | self.moduleOne = torch.nn.Sequential(
18 | torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
19 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
20 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
21 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
22 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
23 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
24 | )
25 |
26 | self.moduleTwo = torch.nn.Sequential(
27 | torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
28 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
29 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
30 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
31 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
32 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
33 | )
34 |
35 | self.moduleThr = torch.nn.Sequential(
36 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
37 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
38 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
39 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
40 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
41 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
42 | )
43 |
44 | self.moduleFou = torch.nn.Sequential(
45 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
46 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
47 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
48 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
49 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
50 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
51 | )
52 |
53 | self.moduleFiv = torch.nn.Sequential(
54 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
55 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
56 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
57 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
58 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
59 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
60 | )
61 |
62 | self.moduleSix = torch.nn.Sequential(
63 | torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
64 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
65 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
66 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
67 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
68 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
69 | )
70 | # end
71 |
72 | def forward(self, tensorInput):
73 | tensorOne = self.moduleOne(tensorInput)
74 | tensorTwo = self.moduleTwo(tensorOne)
75 | tensorThr = self.moduleThr(tensorTwo)
76 | tensorFou = self.moduleFou(tensorThr)
77 | tensorFiv = self.moduleFiv(tensorFou)
78 | tensorSix = self.moduleSix(tensorFiv)
79 |
80 | return [ tensorOne, tensorTwo, tensorThr, tensorFou, tensorFiv, tensorSix ]
81 | # end
82 | # end
83 |
84 | class Backward(torch.nn.Module):
85 | def __init__(self):
86 | super(Backward, self).__init__()
87 | # end
88 |
89 | def forward(self, tensorInput, tensorFlow):
90 | if hasattr(self, 'tensorPartial') == False or self.tensorPartial.size(0) != tensorFlow.size(0) or self.tensorPartial.size(2) != tensorFlow.size(2) or self.tensorPartial.size(3) != tensorFlow.size(3):
91 | self.tensorPartial = torch.FloatTensor().resize_(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3)).fill_(1.0).cuda()
92 | # end
93 |
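# Build (and cache across calls) a base sampling grid with x/y coordinates in [-1, 1],
# the normalized coordinate system expected by torch.nn.functional.grid_sample.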
94 | if hasattr(self, 'tensorGrid') == False or self.tensorGrid.size(0) != tensorFlow.size(0) or self.tensorGrid.size(2) != tensorFlow.size(2) or self.tensorGrid.size(3) != tensorFlow.size(3):
95 | torchHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3))
96 | torchVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3))
97 |
98 | self.tensorGrid = torch.cat([ torchHorizontal, torchVertical ], 1).cuda()
99 | # end
100 |
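# Append a channel of ones to the input so that, after warping, out-of-bounds samples can be
# detected (the ones decay under zero padding), and convert the flow from pixel offsets into
# grid_sample's normalized units by dividing by (size - 1) / 2.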
101 | tensorInput = torch.cat([ tensorInput, self.tensorPartial ], 1)
102 | tensorFlow = torch.cat([ tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0), tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0) ], 1)
103 |
104 | tensorOutput = torch.nn.functional.grid_sample(input=tensorInput, grid=(self.tensorGrid + tensorFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros')
105 |
106 | tensorMask = tensorOutput[:, -1:, :, :]; tensorMask[tensorMask > 0.999] = 1.0; tensorMask[tensorMask < 1.0] = 0.0
107 |
108 | return tensorOutput[:, :-1, :, :] * tensorMask
109 | # end
110 | # end
111 |
112 | class Decoder(torch.nn.Module):
113 | def __init__(self, intLevel):
114 | super(Decoder, self).__init__()
115 |
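# Per-level input channel counts: 81 cost-volume channels (a 9x9 correlation window) plus,
# for the levels below the top one, the first image's feature channels at that level
# (32/64/96/128), 2 upsampled flow channels and 2 upsampled feature channels.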
116 | intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]
117 | intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]
118 |
119 | if intLevel < 6: self.moduleUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
120 | if intLevel < 6: self.moduleUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
121 |
122 | if intLevel < 6: self.dblBackward = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]
123 | if intLevel < 6: self.moduleBackward = Backward()
124 |
125 | self.moduleCorrelation = correlation.ModuleCorrelation()
126 | self.moduleCorreleaky = torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
127 |
128 | self.moduleOne = torch.nn.Sequential(
129 | torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
130 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
131 | )
132 |
133 | self.moduleTwo = torch.nn.Sequential(
134 | torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
135 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
136 | )
137 |
138 | self.moduleThr = torch.nn.Sequential(
139 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
140 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
141 | )
142 |
143 | self.moduleFou = torch.nn.Sequential(
144 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
145 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
146 | )
147 |
148 | self.moduleFiv = torch.nn.Sequential(
149 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
150 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
151 | )
152 |
153 | self.moduleSix = torch.nn.Sequential(
154 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
155 | )
156 | # end
157 |
158 | def forward(self, tensorFirst, tensorSecond, objectPrevious):
159 | tensorFlow = None
160 | tensorFeat = None
161 |
162 | if objectPrevious is None:
163 | tensorFlow = None
164 | tensorFeat = None
165 |
166 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, tensorSecond))
167 |
168 | tensorFeat = torch.cat([ tensorVolume ], 1)
169 |
170 | elif objectPrevious is not None:
171 | tensorFlow = self.moduleUpflow(objectPrevious['tensorFlow'])
172 | tensorFeat = self.moduleUpfeat(objectPrevious['tensorFeat'])
173 |
174 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, self.moduleBackward(tensorSecond, tensorFlow * self.dblBackward)))
175 |
176 | tensorFeat = torch.cat([ tensorVolume, tensorFirst, tensorFlow, tensorFeat ], 1)
177 |
178 | # end
179 |
180 | tensorFeat = torch.cat([ self.moduleOne(tensorFeat), tensorFeat ], 1)
181 | tensorFeat = torch.cat([ self.moduleTwo(tensorFeat), tensorFeat ], 1)
182 | tensorFeat = torch.cat([ self.moduleThr(tensorFeat), tensorFeat ], 1)
183 | tensorFeat = torch.cat([ self.moduleFou(tensorFeat), tensorFeat ], 1)
184 | tensorFeat = torch.cat([ self.moduleFiv(tensorFeat), tensorFeat ], 1)
185 |
186 | tensorFlow = self.moduleSix(tensorFeat)
187 |
188 | return {
189 | 'tensorFlow': tensorFlow,
190 | 'tensorFeat': tensorFeat
191 | }
192 | # end
193 | # end
194 |
195 | class Refiner(torch.nn.Module):
196 | def __init__(self):
197 | super(Refiner, self).__init__()
198 |
199 | self.moduleMain = torch.nn.Sequential(
200 | torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
201 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
202 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
203 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
204 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
205 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
206 | torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
207 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
208 | torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
209 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
210 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
211 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
212 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
213 | )
214 | # end
215 |
216 | def forward(self, tensorInput):
217 | return self.moduleMain(tensorInput)
218 | # end
219 | # end
220 |
221 | self.moduleExtractor = Extractor()
222 |
223 | self.moduleTwo = Decoder(2)
224 | self.moduleThr = Decoder(3)
225 | self.moduleFou = Decoder(4)
226 | self.moduleFiv = Decoder(5)
227 | self.moduleSix = Decoder(6)
228 |
229 | self.moduleRefiner = Refiner()
230 | # end
231 |
232 | def forward(self, tensorFirst, tensorSecond):
233 | tensorFirst = self.moduleExtractor(tensorFirst)
234 | tensorSecond = self.moduleExtractor(tensorSecond)
235 |
236 | objectEstimate = self.moduleSix(tensorFirst[-1], tensorSecond[-1], None)
237 | objectEstimate = self.moduleFiv(tensorFirst[-2], tensorSecond[-2], objectEstimate)
238 | objectEstimate = self.moduleFou(tensorFirst[-3], tensorSecond[-3], objectEstimate)
239 | objectEstimate = self.moduleThr(tensorFirst[-4], tensorSecond[-4], objectEstimate)
240 | objectEstimate = self.moduleTwo(tensorFirst[-5], tensorSecond[-5], objectEstimate)
241 |
242 | return objectEstimate['tensorFlow'] + self.moduleRefiner(objectEstimate['tensorFeat'])
243 | # end
244 | # end
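# A minimal usage sketch (not part of the original file; the weight path is hypothetical —
# in this repo the network is driven through FlowEstimator in src/flow_estimator.py, which
# loads the pre-trained weights and returns full-resolution flow):
#
#   net = PWC_Net().cuda().eval()
#   net.load_state_dict(torch.load('pretrained_models/pwc_net.pth'))  # hypothetical path
#   with torch.no_grad():
#       # frame_first / frame_second: 1 x 3 x H x W tensors in [0, 1]; H and W are assumed
#       # to be multiples of 64 so the decoder's upsampling matches the encoder sizes
#       flow_coarse = net(frame_first, frame_second)  # flow at 1/4 resolution
#   # the caller is expected to upsample and rescale this to full resolution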
--------------------------------------------------------------------------------
/src/inpainting_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from scipy import ndimage
5 | import cv2
6 | import random
7 |
8 | from flow_estimator import FlowEstimator
9 | from models.perceptual import LossNetwork
10 | from utils import *
11 |
12 |
13 | class InpaintingDataset(object):
14 | """
15 | Data loader for the input video
16 | """
17 | def __init__(self, cfg):
18 | self.cfg = cfg
19 | if not os.path.exists(self.cfg['video_path']):
20 | raise Exception("Input video not found: {}".format(self.cfg['video_path']))
21 | if not os.path.exists(self.cfg['mask_path']):
22 | raise Exception("Input mask not found: {}".format(self.cfg['mask_path']))
23 |
24 | cap = cv2.VideoCapture(self.cfg['video_path'])
25 | frame_sum_true = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
26 | self.frame_W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
27 | self.frame_H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
28 | cap.release()
29 |
30 | if not self.cfg['resize'] is None:
31 | self.frame_H, self.frame_W = self.cfg['resize']
32 | self.frame_sum = self.cfg['frame_sum'] = min(self.cfg['frame_sum'], frame_sum_true)
33 | self.frame_size = self.cfg['frame_size'] = (self.frame_H , self.frame_W)
34 | self.batch_size = self.cfg['batch_size']
35 | self.batch_idx = 0
36 | self.batch_list_train = None
37 |
38 | if self.cfg['use_perceptual']:
39 | self.netL = LossNetwork(None, self.cfg['perceptual_layers'])
40 |
41 | self.init_frame_mask()
42 | self.init_input()
43 | self.init_flow()
44 | self.init_flow_mask()
45 | self.init_perceptual_mask()
46 | self.init_batch_list()
47 |
48 |
49 | def init_frame_mask(self):
50 | """
51 | Load input video and mask
52 | """
53 | self.image_all = []
54 | self.mask_all = []
55 | self.contour_all = []
56 |
57 | cap_video = cv2.VideoCapture(self.cfg['video_path'])
58 | cap_mask = cv2.VideoCapture(self.cfg['mask_path'])
59 | for fid in range(self.frame_sum):
60 | frame, mask = self.load_single_frame(cap_video, cap_mask)
61 | contour, hier = cv2.findContours(mask[0].astype('uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
62 | self.image_all.append(frame)
63 | self.mask_all.append(mask)
64 | self.contour_all.append(contour)
65 | cap_video.release()
66 | cap_mask.release()
67 | self.image_all = np.array(self.image_all)
68 | self.mask_all = np.array(self.mask_all)
69 |
70 |
71 | def init_input(self):
72 | """
73 | Generate input noise map
74 | """
75 | input_noise = get_noise(self.frame_sum, self.cfg['input_channel'], 'noise', self.frame_size, var=self.cfg['input_ratio']).float().detach()
76 | input_noise = torch_to_np(input_noise) # N x C x H x W
77 | self.input_noise = input_noise
78 |
79 |
80 | def init_flow(self):
81 | """
82 | Estimate flow using PWC-Net
83 | """
84 | self.cfg['flow_value_max'] = self.flow_value_max = None
85 | if self.cfg['use_flow']:
86 | flow_estimator = FlowEstimator()
87 | print('Loading input video and estimating flow...')
88 | self.flow_all = { ft : [] for ft in self.cfg['flow_type']}
89 | self.flow_value_max = {}
90 |
91 | for bs in self.cfg['batch_stride']:
92 | f, b = 'f' + str(bs), 'b' + str(bs)
93 | for fid in range(0, self.frame_sum - bs):
94 | frame_first = np_to_torch(self.image_all[fid]).clone()
95 | frame_second = np_to_torch(self.image_all[fid + bs]).clone()
96 | frame_first = frame_first.cuda()
97 | frame_second = frame_second.cuda()
98 | flow_f = flow_estimator.estimate_flow_pair(frame_first, frame_second).detach().cpu()
99 | flow_b = flow_estimator.estimate_flow_pair(frame_second, frame_first).detach().cpu()
100 | torch.cuda.empty_cache()
101 |
102 | flow_f, flow_b = check_flow_occlusion(flow_f, flow_b)
103 | self.flow_all[f].append(flow_f.numpy())
104 | self.flow_all[b].append(flow_b.numpy())
105 |
106 | for bs in self.cfg['batch_stride']:
107 | f, b = 'f' + str(bs), 'b' + str(bs)
108 | self.flow_all[f] = np.array(self.flow_all[f] + [self.flow_all[f][0]] * bs, dtype=np.float32)
109 | self.flow_all[b] = np.array([self.flow_all[b][0]] * bs + self.flow_all[b], dtype=np.float32)
110 | self.flow_value_max[bs] = max(np.abs(self.flow_all[f]).max().astype('float'), \
111 | np.abs(self.flow_all[b]).max().astype('float'))
112 | self.cfg['flow_value_max'] = self.flow_value_max
113 |
114 |
115 | def init_flow_mask(self):
116 | """
117 | Pre-compute warped mask and intersection of warped mask with original mask
118 | """
119 | if self.cfg['use_flow']:
120 | self.mask_warp_all = { ft : [] for ft in self.cfg['flow_type']}
121 | self.mask_flow_all = { ft : [] for ft in self.cfg['flow_type']}
122 | for bs in self.cfg['batch_stride']:
123 | f, b = 'f' + str(bs), 'b' + str(bs)
124 | # forward
125 | mask_warpf, _ = warp_np(self.mask_all[bs:], self.flow_all[f][:-bs][:, :2, ...])
126 | mask_warpf = (mask_warpf > 0).astype(np.float32)
127 | mask_flowf = 1. - (1. - mask_warpf) * (1. - self.mask_all[:-bs])
128 | self.mask_warp_all[f] = np.concatenate((mask_warpf, self.mask_all[-bs:]), 0)
129 | self.mask_flow_all[f] = np.concatenate((mask_flowf, self.mask_all[-bs:]), 0)
130 | # backward
131 | mask_warpb, _ = warp_np(self.mask_all[:-bs], self.flow_all[b][bs:][:, :2, ...])
132 | mask_warpb = (mask_warpb > 0).astype(np.float32)
133 | mask_flowb = 1. - (1. - mask_warpb) * (1. - self.mask_all[bs:])
134 | self.mask_warp_all[b] = np.concatenate((self.mask_all[:bs], mask_warpb), 0)
135 | self.mask_flow_all[b] = np.concatenate((self.mask_all[:bs], mask_flowb), 0)
136 |
137 |
138 | def init_perceptual_mask(self):
139 | """
140 | Pre-compute shrunk masks for the perceptual loss
141 | """
142 | if self.cfg['use_perceptual']:
143 | self.mask_per_all = []
144 | mask_per = self.netL(np_to_torch(self.mask_all))
145 | for i, mask in enumerate(mask_per):
146 | self.mask_per_all.append((mask.detach().numpy() > 0).astype(np.float32))
147 |
148 |
149 | def init_batch_list(self):
150 | """
151 | List all the possible batch permutations
152 | """
153 | if self.cfg['use_flow']:
154 | self.batch_list = []
155 | for flow_type in self.cfg['flow_type']:
156 | batch_stride = int(flow_type[1])
157 | for batch_idx in range(0, self.frame_sum - (self.batch_size - 1) * batch_stride, self.cfg['traverse_step']):
158 | self.batch_list.append((batch_idx, batch_stride, [flow_type]))
159 | if self.cfg['batch_mode'] == 'random':
160 | random.shuffle(self.batch_list)
161 | else:
162 | for bs in self.cfg['batch_stride']:
163 | self.batch_list = [(i, bs, []) for i in range(self.frame_sum - self.batch_size + 1)]
164 | if self.cfg['batch_mode'] == 'random':
165 | median = self.batch_list[len(self.batch_list) // 2]
166 | random.shuffle(self.batch_list)
167 | self.batch_list.remove(median)
168 | self.batch_list.append(median)
169 |
170 |
171 | def set_mode(self, mode):
172 | if mode == 'infer':
173 | self.batch_list_train = self.batch_list
174 | self.batch_list = [(i, 1, ['f1']) for i in range(self.frame_sum - self.batch_size + 1)]
175 | elif mode == 'train':
176 | if not self.batch_list_train is None:
177 | self.batch_list = self.batch_list_train
178 | else:
179 | self.init_batch_list()
180 |
181 |
182 | def next_batch(self):
183 | if len(self.batch_list) == 0:
184 | self.init_batch_list()
185 | return None
186 | else:
187 | (batch_idx, batch_stride, flow_type) = self.batch_list[0]
188 | self.batch_list = self.batch_list[1:]
189 | return self.get_batch_data(batch_idx, batch_stride, flow_type)
190 |
191 |
192 | def get_batch_data(self, batch_idx=0, batch_stride=1, flow_type=[]):
193 | """
194 | Collect batch data for the given batch
195 | """
196 | cur_batch = range(batch_idx, batch_idx + self.batch_size*batch_stride, batch_stride)
197 | batch_data = {}
198 | input_batch, img_batch, mask_batch, contour_batch = [], [], [], []
199 | if self.cfg['use_flow']:
200 | flow_batch = { ft : [] for ft in flow_type}
201 | mask_flow_batch = { ft : [] for ft in flow_type}
202 | mask_warp_batch = { ft : [] for ft in flow_type}
203 | if self.cfg['use_perceptual']:
204 | mask_per_batch = [ [] for _ in self.cfg['perceptual_layers']]
205 |
206 | for i, fid in enumerate(cur_batch):
207 | input_batch.append(self.input_noise[fid])
208 | img_batch.append(self.image_all[fid])
209 | mask_batch.append(self.mask_all[fid])
210 | contour_batch.append(self.contour_all[fid])
211 | if self.cfg['use_flow']:
212 | for ft in flow_type:
213 | flow_batch[ft].append(self.flow_all[ft][fid])
214 | mask_flow_batch[ft].append(self.mask_flow_all[ft][fid])
215 | mask_warp_batch[ft].append(self.mask_warp_all[ft][fid])
216 | if self.cfg['use_perceptual']:
217 | for l in range(len(self.cfg['perceptual_layers'])):
218 | mask_per_batch[l].append(self.mask_per_all[l][fid])
219 |
220 |
221 | if self.cfg['use_flow']:
222 | for ft in flow_type:
223 | idx1, idx2 = (0, -1) if 'f' in ft else (1, self.batch_size)
224 | flow_batch[ft] = np.array(flow_batch[ft][idx1:idx2])
225 | mask_flow_batch[ft] = np.array(mask_flow_batch[ft][idx1:idx2])
226 | mask_warp_batch[ft] = np.array(mask_warp_batch[ft][idx1:idx2])
227 | if self.cfg['use_perceptual']:
228 | for l in range(len(self.cfg['perceptual_layers'])):
229 | mask_per_batch[l] = np.array(mask_per_batch[l])
230 |
231 | batch_data['cur_batch'] = cur_batch
232 | batch_data['batch_idx'] = batch_idx
233 | batch_data['batch_stride'] = batch_stride
234 | batch_data['input_batch'] = np.array(input_batch)
235 | batch_data['img_batch'] = np.array(img_batch)
236 | batch_data['mask_batch'] = np.array(mask_batch)
237 | batch_data['contour_batch'] = contour_batch
238 | if self.cfg['use_flow']:
239 | batch_data['flow_type'] = flow_type
240 | batch_data['flow_batch'] = flow_batch
241 | batch_data['mask_flow_batch'] = mask_flow_batch
242 | batch_data['mask_warp_batch'] = mask_warp_batch
243 | batch_data['flow_value_max'] = self.flow_value_max[batch_stride]
244 | if self.cfg['use_perceptual']:
245 | batch_data['mask_per_batch'] = mask_per_batch
246 |
247 | return batch_data
248 |
249 |
250 | def get_median_batch(self):
251 | return self.get_batch_data(int((self.cfg['frame_sum']) // 2), 1, ['f1'])
252 |
253 |
254 | def get_all_data(self):
255 | """
256 | Return a batch containing all the frames
257 | """
258 | batch_data = {}
259 | batch_data['input_batch'] = np.array(self.input_noise[:self.frame_sum])
260 | batch_data['img_batch'] = self.image_all
261 | batch_data['mask_batch'] = self.mask_all
262 | batch_data['contour_batch'] = self.contour_all
263 | if self.cfg['use_perceptual']:
264 | batch_data['mask_per_batch'] = self.mask_per_all
265 | if self.cfg['use_flow']:
266 | batch_data['flow_type'] = self.cfg['flow_type']
267 | batch_data['flow_batch'] = self.flow_all
268 | batch_data['mask_flow_batch'] = self.mask_flow_all
269 | batch_data['mask_warp_batch'] = self.mask_warp_all
270 | return batch_data
271 |
272 |
273 | def load_single_frame(self, cap_video, cap_mask):
274 | gt = self.load_image(cap_video, False, self.frame_size)
275 | mask = self.load_image(cap_mask, True, self.frame_size)
276 | if self.cfg['dilation_iter'] > 0:
277 | mask = ndimage.binary_dilation(mask > 0, iterations=self.cfg['dilation_iter']).astype(np.float32)
278 | return gt, mask
279 |
280 |
281 | def load_image(self, cap, is_mask, resize=None):
282 | _, img = cap.read()
283 | if not resize is None:
284 | img = self.crop_and_resize(img, resize)
285 | if is_mask:
286 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None]
287 | img = (img > 127) * 255
288 | img = img.astype('uint8')
289 | img_convert = img.transpose(2, 0, 1)
290 | return img_convert.astype(np.float32) / 255
291 |
292 |
293 | def crop_and_resize(self, img, resize):
294 | """
295 | Crop and resize img, keeping the aspect ratio unchanged
296 | """
297 | h, w = img.shape[:2]
298 | source = 1. * h / w
299 | target = 1. * resize[0] / resize[1]
300 | if source > target:
301 | margin = int((h - w * target) // 2)
302 | img = img[margin:h-margin]
303 | elif source < target:
304 | margin = int((w - h / target) // 2)
305 | img = img[:, margin:w-margin]
306 | img = cv2.resize(img, (resize[1], resize[0]), interpolation=self.cfg['interpolation'])
307 | return img
--------------------------------------------------------------------------------
/src/inpainting_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | import time
5 | import os
6 | import cv2
7 | import logging
8 |
9 | from inpainting_dataset import InpaintingDataset
10 | from models.encoder_decoder_2d import EncoderDecoder2D
11 | from models.encoder_decoder_3d import EncoderDecoder3D
12 | from models.perceptual import LossNetwork
13 | from utils import *
14 |
15 |
16 | class InpaintingTest(object):
17 | """
18 | Internal learning framework
19 | """
20 | def __init__(self, cfg):
21 | self.cfg = cfg
22 | self.log = {}
23 | self.log['cfg'] = self.cfg.copy()
24 | for loss_name in self.cfg['loss_weight']:
25 | self.log['loss_' + loss_name + '_train'] = [ [] for _ in range(self.cfg['num_pass'])]
26 | self.log['loss_' + loss_name + '_infer'] = [ [] for _ in range(self.cfg['num_pass'])]
27 | self.data_loader = None
28 | self.netG = None
29 | self.optimizer_G = None
30 |
31 | # Pre-process cfg
32 | self.cfg['use_perceptual'] = self.cfg['loss_weight']['perceptual'] > 0
33 | self.cfg['use_flow'] = self.cfg['loss_weight']['recon_flow'] > 0
34 |
35 | if self.cfg['use_flow']:
36 | self.cfg['flow_type'] = []
37 | for bs in self.cfg['batch_stride']:
38 | self.cfg['flow_type'] += ['f' + str(bs), 'b' + str(bs)]
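# Map each flow type to its slice of the generator's output channels: the first
# output_channel_img channels are the image, followed by 2 channels per flow type,
# e.g. with output_channel_img = 3 and batch_stride = [1]: 'f1' -> (3, 5), 'b1' -> (5, 7).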
39 | self.cfg['flow_channel_map'] = {}
40 | channel_idx = self.cfg['output_channel_img']
41 | for ft in self.cfg['flow_type']:
42 | self.cfg['flow_channel_map'][ft] = (channel_idx, channel_idx+2)
43 | channel_idx += 2
44 | else:
45 | self.cfg['flow_type'] = None
46 |
47 | # Build result folder
48 | res_dir = self.cfg['res_dir']
49 | if not res_dir is None:
50 | res_dir = os.path.join(res_dir, os.path.basename(self.cfg['video_path']).split('.')[0])
51 | self.cfg['res_dir'] = res_dir
52 | if os.path.exists(res_dir):
53 | print("Warning: Video folder existed!")
54 | mkdir(res_dir)
55 | mkdir(os.path.join(res_dir, 'model'))
56 | for pass_idx in range(self.cfg['num_pass']):
57 | if (pass_idx + 1) % self.cfg['save_every_pass'] == 0:
58 | res_dir = os.path.join(self.cfg['res_dir'], '{:03}'.format(pass_idx + 1))
59 | mkdir(res_dir)
60 | iter = self.cfg['save_every_iter']
61 | while iter <= self.cfg['num_iter']:
62 | build_dir(res_dir, '{:05}'.format(iter))
63 | iter += self.cfg['save_every_iter']
64 | build_dir(res_dir, 'final')
65 | if self.cfg['train_mode'] == 'DIP':
66 | build_dir(res_dir, 'best_nonhole')
67 |
68 | # Setup logging
69 | logging.basicConfig(level=self.cfg['logging_level'], format='%(message)s')
70 | self.logger = logging.getLogger(__name__)
71 | self.log_handler = None
72 | if not self.cfg['res_dir'] is None:
73 | self.log_handler = logging.FileHandler(os.path.join(self.cfg['res_dir'], 'log.txt'))
74 | # self.log_handler.setLevel(logging.DEBUG)
75 | # formatter = logging.Formatter('%(message)s')
76 | # self.log_handler.setFormatter(formatter)
77 | self.logger.addHandler(self.log_handler)
78 | self.logger.info('========================================== Config ==========================================')
79 | for key in sorted(self.cfg):
80 | self.logger.info('[{}]: {}'.format(key, str(self.cfg[key])))
81 |
82 |
83 | def create_data_loader(self):
84 | self.logger.info('========================================== Dataset ==========================================')
85 |
86 | self.data_loader = InpaintingDataset(self.cfg)
87 |
88 | self.logger.info("[Video name]: {}".format(os.path.basename(self.cfg['video_path'])))
89 | self.logger.info("[Mask name]: {}".format(os.path.basename(self.cfg['mask_path'])))
90 | self.logger.info("[Frame sum]: {}".format(self.cfg['frame_sum']))
91 | self.logger.info("[Batch size]: {}".format(self.cfg['batch_size']))
92 | self.logger.info("[Frame size]: {}".format(self.cfg['frame_size']))
93 | self.logger.info("[Flow type]: {}".format(self.cfg['flow_type']))
94 | self.logger.info("[Flow_value_max]: {}".format(self.cfg['flow_value_max']))
95 |
96 | self.log['input_noise'] = self.data_loader.input_noise
97 |
98 |
99 | def visualize_single_batch(self):
100 | """
101 | Visualize one batch of data
102 | """
103 | batch_data = self.data_loader.next_batch()
104 | input_batch = batch_data['input_batch']
105 | img_batch = batch_data['img_batch']
106 | mask_batch = batch_data['mask_batch']
107 | nonhole_batch = img_batch * (1 - mask_batch)
108 |
109 | if self.cfg['batch_size'] == 1:
110 | plot_image_grid(np.concatenate((img_batch, nonhole_batch), 0), 2, padding=3, factor=10)
111 | else:
112 | plot_image_grid(np.concatenate((img_batch, nonhole_batch), 0), self.cfg['batch_size'], padding=3, factor=15)
113 |
114 |
115 | def create_model(self):
116 | if not self.cfg['use_skip']:
117 | num_channels_skip = [0] * self.cfg['net_depth']
118 |
119 | input_channel = self.cfg['input_channel']
120 | if self.cfg['use_flow']:
121 | output_channel_flow = len(self.cfg['flow_type']) * 2
122 | output_channel = self.cfg['output_channel_img'] + output_channel_flow
123 | else:
124 | output_channel = self.cfg['output_channel_img']
125 |
126 | if self.cfg['net_type_G'] == '2d':
127 | self.netG = EncoderDecoder2D(input_channel, output_channel,
128 | self.cfg['num_channels_down'][:self.cfg['net_depth']],
129 | self.cfg['num_channels_up'][:self.cfg['net_depth']],
130 | self.cfg['num_channels_skip'][:self.cfg['net_depth']],
131 | self.cfg['filter_size_down'], self.cfg['filter_size_up'], self.cfg['filter_size_skip'],
132 | upsample_mode='nearest', downsample_mode='stride',
133 | need1x1_up=True, need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU')
134 | elif self.cfg['net_type_G'] == '3d':
135 | self.netG = EncoderDecoder3D(input_channel, output_channel,
136 | self.cfg['num_channels_down'][:self.cfg['net_depth']],
137 | self.cfg['num_channels_up'][:self.cfg['net_depth']],
138 | self.cfg['num_channels_skip'][:self.cfg['net_depth']],
139 | self.cfg['filter_size_down'], self.cfg['filter_size_up'], self.cfg['filter_size_skip'],
140 | upsample_mode='nearest', downsample_mode='stride',
141 | need1x1_up=True, need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU')
142 | else:
143 | raise Exception("Network not defined!")
144 | self.netG = self.netG.type(self.cfg['dtype'])
145 |
146 | if self.cfg['use_perceptual']:
147 | if self.cfg['net_type_L'] == 'VGG16':
148 | vgg_model = torchvision.models.vgg16(pretrained=True).type(self.cfg['dtype'])
149 | vgg_model.eval()
150 | vgg_model.requires_grad = False
151 | self.netL = LossNetwork(vgg_model)
152 |
153 | self.logger.info('========================================== Network ==========================================')
154 | self.logger.info(self.netG)
155 | self.logger.info("Total number of parameters: {}".format(get_model_num_parameters(self.netG)))
156 |
157 |
158 | def create_loss_function(self):
159 | if self.cfg['loss_recon'] == 'L1':
160 | self.criterion_recon = torch.nn.L1Loss().type(self.cfg['dtype'])
161 | elif self.cfg['loss_recon'] == 'L2':
162 | self.criterion_recon = torch.nn.MSELoss().type(self.cfg['dtype'])
163 | self.criterion_MSE = torch.nn.MSELoss().type(self.cfg['dtype'])
164 |
165 |
166 | def create_optimizer(self):
167 | if self.cfg['optimizer_G'] == 'Adam':
168 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.cfg['LR'])
169 | elif self.cfg['optimizer_G'] == 'SGD':
170 | self.optimizer_G = torch.optim.SGD(self.netG.parameters(), lr=self.cfg['LR'])
171 |
172 |
173 | def prepare_input(self, input_batch):
174 | """
175 | Prepare input noise map based on network type (2D/3D)
176 | """
177 | input_tensor = np_to_torch(input_batch).type(self.cfg['dtype']) # N x C x H x W
178 | if self.cfg['net_type_G'] == '2d':
179 | return input_tensor # N x C x H x W
180 | elif self.cfg['net_type_G'] == '3d':
181 | return input_tensor.transpose(0, 1).unsqueeze(0) # 1 x C x N x H x W
182 |
183 |
184 | def train(self):
185 | """
186 | Main function for internal learning
187 | """
188 | self.start_time = time.time()
189 | self.data_loader.init_batch_list()
190 |
191 | self.logger.info('========================================== Training ==========================================')
192 | if self.cfg['train_mode'] == 'DIP-Vid-Flow':
193 | self.train_with_flow()
194 | else:
195 | self.train_baseline()
196 |
197 | # Save log
198 | if not self.cfg['res_dir'] is None:
199 | torch.save(self.log, os.path.join(self.cfg['res_dir'], 'log.tar'))
200 |
201 | # Report training time
202 | running_time = time.time() - self.start_time
203 | self.logger.info("Training finished! Running time: {}s".format(running_time))
204 |
205 | # Release log file
206 | self.logger.removeHandler(self.log_handler)
207 |
208 |
209 | def train_baseline(self):
210 | """
211 | Training procedure for all baselines
212 | """
213 | for pass_idx in range(self.cfg['num_pass']):
214 | while True:
215 | # Get batch data
216 | batch_data = self.data_loader.next_batch()
217 | if batch_data is None:
218 | break
219 | batch_idx = batch_data['batch_idx']
220 |
221 | self.logger.info('Pass {}, Batch {}'.format(pass_idx + 1, batch_idx))
222 |
223 | # Start train
224 | self.train_batch(pass_idx, batch_idx, batch_data)
225 |
226 | if self.cfg['train_mode'] == 'DIP':
227 | self.create_model()
228 | self.create_optimizer()
229 |
230 | # Start infer
231 | if (pass_idx + 1) % self.cfg['save_every_pass'] == 0 and self.cfg['train_mode'] != 'DIP':
232 | inferred_result = self.infer(pass_idx)
233 |
234 | self.logger.info("Saving latest model at pass {}".format(pass_idx + 1))
235 | if not self.cfg['res_dir'] is None:
236 | # Save model and log
237 | checkpoint_G = os.path.join(self.cfg['res_dir'], 'model', '{:03}.tar'.format(pass_idx + 1))
238 | torch.save(self.netG.state_dict(), checkpoint_G)
239 | torch.save(self.log, os.path.join(self.cfg['res_dir'], 'log.tar'))
240 |
241 | self.logger.info("Running time: {}s".format(time.time() - self.start_time))
242 |
243 |
244 |
245 |
246 | def train_with_flow(self):
247 | """
248 | Training procedure for DIP-Vid-Flow
249 | """
250 | pass_idx = 0
251 | batch_count = 0
252 | while pass_idx < self.cfg['num_pass']:
253 | # Get batch data
254 | batch_data = self.data_loader.next_batch()
255 | if batch_data is None:
256 | continue
257 | batch_idx = batch_data['batch_idx']
258 |
259 | self.logger.info('Pass: {}, Batch: {}, Flow: {}'.format(pass_idx, batch_idx, str(batch_data['flow_type'])))
260 |
261 | # Start train
262 | self.train_batch(pass_idx, batch_idx, batch_data)
263 | batch_count += 1
264 |
265 | # Start infer
266 | if batch_count % self.cfg['save_every_batch'] == 0:
267 | batch_data = self.data_loader.get_median_batch()
268 | batch_idx = batch_data['batch_idx']
269 | self.logger.info('Train the median batch before inferring\nPass: {}, Batch: {}, Flow: {}'.format(pass_idx, batch_idx, str(batch_data['flow_type'])))
270 | self.train_batch(pass_idx, batch_idx, batch_data)
271 |
272 | self.infer(pass_idx)
273 |
274 | # Save model and log
275 | self.logger.info("Running time: {}s".format(time.time() - self.start_time))
276 | if not self.cfg['res_dir'] is None:
277 | checkpoint_G = os.path.join(self.cfg['res_dir'], 'model', '{:03}.tar'.format(pass_idx + 1))
278 | torch.save(self.netG.state_dict(), checkpoint_G)
279 | torch.save(self.log, os.path.join(self.cfg['res_dir'], 'log.tar'))
280 | pass_idx += 1
281 |
282 |
283 | def infer(self, pass_idx):
284 | """
285 | Run inference with the trained model to collect all inpainted frames
286 | """
287 | self.logger.info('Pass {} infer start...'.format(pass_idx))
288 | self.data_loader.set_mode('infer')
289 |
290 | inferred_result = np.empty((self.cfg['frame_sum'], self.cfg['output_channel_img'], self.cfg['frame_size'][0], self.cfg['frame_size'][1]), dtype=np.float32)
291 | while True:
292 | batch_data = self.data_loader.next_batch()
293 | if batch_data is None:
294 | break
295 | batch_idx = batch_data['batch_idx']
296 | self.infer_batch(pass_idx, batch_idx, batch_data, inferred_result)
297 |
298 | self.data_loader.set_mode('train')
299 | return inferred_result
300 |
301 |
302 | def train_batch(self, pass_idx, batch_idx, batch_data):
303 | """
304 | Train the given batch for `num_iter` iterations
305 | """
306 | for loss_name in self.cfg['loss_weight']:
307 | self.log['loss_' + loss_name + '_train'][pass_idx].append([])
308 | best_loss_recon_image = 1e9
309 | best_iter = 0
310 | best_nonhole_batch = None
311 | batch_data['pass_idx'] = pass_idx
312 | batch_data['train'] = True
313 |
314 | # Optimize
315 | for iter_idx in range(self.cfg['num_iter']):
316 | if self.cfg['param_noise']:
317 | for n in [x for x in self.netG.parameters() if len(x.size()) == 4]:
318 | n = n + n.detach().clone().normal_() * n.std() / 50
319 |
320 | # Forward
321 | loss = self.optimize_params(batch_data)
322 |
323 | # Update
324 | for loss_name in self.cfg['loss_weight']:
325 | self.log['loss_' + loss_name + '_train'][pass_idx][-1].append(loss[loss_name].item())
326 | if loss['recon_image'].item() < best_loss_recon_image:
327 | best_loss_recon_image = loss['recon_image'].item()
328 | best_nonhole_batch = batch_data['out_img_batch']
329 | best_iter = iter_idx
330 |
331 | log_str = 'Iteration {:05}'.format(iter_idx)
332 | for loss_name in sorted(self.cfg['loss_weight']):
333 | if self.cfg['loss_weight'][loss_name] != 0:
334 | log_str += ' ' + loss_name + ' {:f}'.format(loss[loss_name].item())
335 | self.logger.info(log_str)
336 |
337 | # Plot and save
338 | if (pass_idx + 1) % self.cfg['save_every_pass'] == 0 and (iter_idx + 1) % self.cfg['save_every_iter'] == 0:
339 | self.plot_and_save(batch_idx, batch_data, '{:03}/{:05}'.format(pass_idx + 1, iter_idx + 1))
340 |
341 | log_str = 'Best at iteration {:05}, recon_image loss {:f}'.format(best_iter, best_loss_recon_image)
342 | self.logger.info(log_str)
343 |
344 | if self.cfg['train_mode'] == 'DIP':
345 | # For DIP, save the result with lowest loss on nonhole region as final result
346 | batch_data['out_img_batch'] = best_nonhole_batch
347 | self.plot_and_save(batch_idx, batch_data, '001/best_nonhole')
348 |
349 |
350 | def infer_batch(self, pass_idx, batch_idx, batch_data, inferred_result):
351 | """
352 | Run inference for the given batch
353 | """
354 | # Forward pass
355 | batch_data['pass_idx'] = pass_idx
356 | batch_data['train'] = False
357 | loss = self.optimize_params(batch_data)
358 |
359 | # Update
360 | for loss_name in self.cfg['loss_weight']:
361 | self.log['loss_' + loss_name + '_infer'][pass_idx].append(loss[loss_name].item())
362 |
363 | # Save inferred result
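# A frame is written only when it falls in the second half of the batch, except for the
# first few batches (batch_idx < batch_stride) whose leading frames would otherwise never
# be covered.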
364 | for i, img in enumerate(batch_data['out_img_batch'][0]):
365 | if batch_idx < batch_data['batch_stride'] or i >= self.cfg['batch_size'] // 2:
366 | inferred_result[batch_idx + i * batch_data['batch_stride']] = img
367 |
368 | log_str = 'Batch {:05}'.format(batch_idx)
369 | for loss_name in sorted(self.cfg['loss_weight']):
370 | if self.cfg['loss_weight'][loss_name] != 0:
371 | log_str += ' ' + loss_name + ' {:f}'.format(loss[loss_name].item())
372 | self.logger.info(log_str)
373 |
374 | # Plot and save
375 | if (self.cfg['plot'] or self.cfg['save']) and (pass_idx + 1) % self.cfg['save_every_pass'] == 0:
376 | self.plot_and_save(batch_idx, batch_data, '{:03}/final'.format(pass_idx + 1))
377 |
378 |
379 | def optimize_params(self, batch_data):
380 | """
381 | Calculate loss and back-propagate the loss
382 | """
383 | pass_idx = batch_data['pass_idx']
384 | batch_idx = batch_data['batch_idx']
385 | net_input = self.prepare_input(batch_data['input_batch'])
386 | img_tensor = np_to_torch(batch_data['img_batch']).type(self.cfg['dtype'])
387 | mask_tensor = np_to_torch(batch_data['mask_batch']).type(self.cfg['dtype'])
388 |
389 | if self.cfg['use_flow']:
390 | flow_tensor, mask_flow_tensor, mask_warp_tensor = {}, {}, {}
391 | for ft in batch_data['flow_type']:
392 | flow_tensor[ft] = np_to_torch(batch_data['flow_batch'][ft]).type(self.cfg['dtype'])
393 | mask_flow_tensor[ft] = np_to_torch(batch_data['mask_flow_batch'][ft]).type(self.cfg['dtype'])
394 | mask_warp_tensor[ft] = np_to_torch(batch_data['mask_warp_batch'][ft]).type(self.cfg['dtype'])
395 | flow_value_max = batch_data['flow_value_max']
396 |
397 | if self.cfg['use_perceptual']:
398 | mask_per_tensor = []
399 | for mask in batch_data['mask_per_batch']:
400 | mask = np_to_torch(mask).type(self.cfg['dtype'])
401 | mask_per_tensor.append(mask)
402 |
403 | # Forward
404 | net_output = self.netG(net_input)
405 | torch.cuda.empty_cache()
406 |
407 | # Collect image/flow from network output
408 | if self.cfg['net_type_G'] == '2d':
409 | out_img_tensor = net_output[:, :self.cfg['output_channel_img'], ...] # N x 3 x H x W
410 | if self.cfg['use_flow']:
411 | out_flow_tensor = {}
412 | for ft in batch_data['flow_type']:
413 | channel_idx1, channel_idx2 = self.cfg['flow_channel_map'][ft]
414 | flow_idx1, flow_idx2 = (0, -1) if 'f' in ft else (1, self.cfg['batch_size'])
415 | out_flow_tensor[ft] = net_output[flow_idx1:flow_idx2, channel_idx1:channel_idx2, ...]
416 |
417 | elif self.cfg['net_type_G'] == '3d':
418 | out_img_tensor = net_output.squeeze(0)[:self.cfg['output_channel_img']].transpose(0, 1) # N x 3 x H x W
419 | if self.cfg['use_flow']:
420 | out_flow_tensor = {}
421 | for ft in batch_data['flow_type']:
422 | channel_idx1, channel_idx2 = self.cfg['flow_channel_map'][ft]
423 | flow_idx1, flow_idx2 = (0, -1) if 'f' in ft else (1, self.cfg['batch_size'])
424 | out_flow_tensor[ft] = net_output.squeeze(0) \
425 | [channel_idx1:channel_idx2, flow_idx1:flow_idx2, ...].transpose(0, 1) # N-1 x 2 x H x W
426 |
427 | # Compute loss
428 | loss = {}
429 | for loss_name in self.cfg['loss_weight']:
430 | loss[loss_name] = torch.zeros([]).float().cuda().detach()
431 |
432 | self.optimizer_G.zero_grad()
433 |
434 | # Image reconstruction loss
435 | if self.cfg['loss_weight']['recon_image'] != 0:
436 | loss['recon_image'] += self.criterion_recon(
437 | out_img_tensor * (1. - mask_tensor), \
438 | img_tensor * (1. - mask_tensor))
439 |
440 | # Flow reconstruction loss
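# The flow target carries a third channel (produced by check_flow_occlusion, presumably
# marking non-occluded pixels); the loss is applied only on those pixels and outside the
# union mask, with the target scaled by 1 / flow_value_max to match the network's
# normalized flow output.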
441 | if self.cfg['loss_weight']['recon_flow'] != 0:
442 | for ft in batch_data['flow_type']:
443 | mask_flow_inv = (1. - mask_flow_tensor[ft]) * flow_tensor[ft][:, 2:3, ...]
444 | loss['recon_flow'] += self.criterion_recon(out_flow_tensor[ft] * mask_flow_inv, \
445 | flow_tensor[ft][:, :2, ...] * mask_flow_inv / flow_value_max)
446 |
447 | # Consistency loss
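# Warp the generated neighbouring frames back with the generated flow (rescaled by
# flow_value_max) and require them to match the generated current frames inside the union
# mask region, i.e. roughly where the image reconstruction loss has no ground truth.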
448 | if self.cfg['loss_weight']['consistency'] != 0:
449 | warped_img, warped_diff = {}, {}
450 | for ft in batch_data['flow_type']:
451 | idx1, idx2 = (0, -1) if 'f' in ft else (1, self.cfg['batch_size'])
452 | idx_inv1, idx_inv2 = (1, self.cfg['batch_size']) if 'f' in ft else (0, -1)
453 | out_img = out_img_tensor[idx_inv1:idx_inv2]
454 | out_flow = out_flow_tensor[ft]
455 | warped_img[ft], flowmask = warp_torch(out_img, out_flow * flow_value_max)
456 | mask = mask_flow_tensor[ft] * flowmask.detach()
457 | loss['consistency'] += self.criterion_recon(
458 | warped_img[ft] * mask,
459 | out_img_tensor[idx1:idx2].detach() * mask)
460 | torch.cuda.empty_cache()
461 |
462 | # Perceptual loss
463 | if self.cfg['use_perceptual']:
464 | feature_src = self.netL(out_img_tensor)
465 | feature_dst = self.netL(img_tensor)
466 | for i, mask in enumerate(mask_per_tensor):
467 | loss['perceptual'] += self.criterion_MSE(
468 | feature_src[i] * (1. - mask), feature_dst[i].detach() * (1. - mask))
469 | torch.cuda.empty_cache()
470 |
471 | # Back-propagation
472 | running_loss = 0
473 | for loss_name, weight in self.cfg['loss_weight'].items():
474 | if weight != 0:
475 | running_loss = running_loss + weight * loss[loss_name]
476 |
477 | if batch_data['train']:
478 | running_loss.backward()
479 | self.optimizer_G.step()
480 | torch.cuda.empty_cache()
481 |
482 | # Save generated image/flow
483 | batch_data['out_img_batch'] = torch_to_np(out_img_tensor)
484 | if self.cfg['use_flow']:
485 | out_flow_batch = {}
486 | for ft in batch_data['flow_type']:
487 | out_flow_batch[ft] = torch_to_np(out_flow_tensor[ft] * flow_value_max)
488 | batch_data['out_flow_batch'] = out_flow_batch
489 | return loss
490 |
491 |
492 | def plot_and_save(self, batch_idx, batch_data, subpath):
493 | """
494 | Plot/save intermediate results
495 | """
496 | def save(imgs, subpath, subsubpath):
497 | res_dir = os.path.join(self.cfg['res_dir'], subpath, subsubpath)
498 | for i, img in enumerate(imgs):
499 | if img is None:
500 | continue
501 | fid = batch_idx + i * batch_data['batch_stride']
502 | batch_path = os.path.join(res_dir, 'batch', '{:03}_{:03}.png'.format(batch_idx, fid))
503 | sequence_path = os.path.join(res_dir, 'sequence', '{:03}.png'.format(fid))
504 | if self.cfg['save_batch']:
505 | cv2.imwrite(batch_path, np_to_cv2(img))
506 | if batch_idx < batch_data['batch_stride'] or i >= self.cfg['batch_size'] // 2:
507 | cv2.imwrite(sequence_path, np_to_cv2(img))
508 |
509 | # Load batch data
510 | input_batch = batch_data['input_batch'] / self.cfg['input_ratio']
511 | img_batch = batch_data['img_batch']
512 | mask_batch = batch_data['mask_batch']
513 | contour_batch = batch_data['contour_batch']
514 | out_img_batch = batch_data['out_img_batch'].copy()
515 | stitch_batch = img_batch * (1 - mask_batch) + out_img_batch * mask_batch
516 |
517 | # Draw mask boundary
518 | for i in range(self.cfg['batch_size']):
519 | for con in contour_batch[i]:
520 | for pt in con:
521 | x, y = pt[0]
522 | out_img_batch[i][:, y, x] = [0, 0, 1]
523 |
524 | # Plot in jupyter
525 | if self.cfg['plot']:
526 | if self.cfg['batch_size'] == 1:
527 | plot_image_grid(out_img_batch, 1, factor=10)
528 | else:
529 | plot_image_grid(out_img_batch, self.cfg['batch_size'], padding=3, factor=15)
530 |
531 | # Save images to disk
532 | if not self.cfg['res_dir'] is None and self.cfg['save']:
533 | save(stitch_batch, subpath, 'stitch')
534 | save(out_img_batch, subpath, 'full_with_boundary')
535 | save(batch_data['out_img_batch'], subpath, 'full')
--------------------------------------------------------------------------------