├── src ├── __init__.py ├── configs │ ├── __init__.py │ ├── DIP.py │ ├── DIP_Vid.py │ ├── DIP_Vid_Flow.py │ ├── DIP_Vid_3DCN.py │ └── base.py ├── models │ ├── __init__.py │ ├── perceptual.py │ ├── common.py │ ├── encoder_decoder_2d.py │ ├── encoder_decoder_3d.py │ └── pwc_net.py ├── correlation │ ├── __init__.py │ ├── README.md │ └── correlation.py ├── flow_estimator.py ├── utils.py ├── inpainting_dataset.py └── inpainting_test.py ├── .gitignore ├── img └── rollerblade.gif ├── data └── README.md ├── requirements.txt ├── pretrained_models └── README.md ├── Dockerfile ├── train.py ├── README.md └── demo.ipynb /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/correlation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .git 2 | data 3 | result 4 | pretrained_models 5 | *.pyc 6 | *.ipynb_checkpoints/ 7 | *__pycache__ -------------------------------------------------------------------------------- /img/rollerblade.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haotianz94/IL_video_inpainting/HEAD/img/rollerblade.gif -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Please download the demo data from [here](https://drive.google.com/open?id=1MJDCjj1aIUbW0OK9UnewhXlkKX9zllQd). -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.0 2 | torchvision==0.2.0 3 | ipykernel==5.1.2 4 | jupyter==1.0.0 5 | opencv-contrib-python==4.1.0 6 | matplotlib==3.0.3 7 | scipy==1.3.1 8 | cupy==6.3.0 -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | Please download the pre-trained model weights for PWC-Net from [here](https://drive.google.com/open?id=1vyoQFBz--DEkUq-0gucbWYrVbfYTOZfz). 2 | (The model was originally trained by Simon Niklaus from [here](https://github.com/sniklaus/pytorch-pwc)). -------------------------------------------------------------------------------- /src/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implemenation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. 
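
Below is a minimal usage sketch (not part of the repository) showing how the module can be called to build a cost volume. The feature-map shapes are illustrative only, and `src` is assumed to be on `sys.path`, as `train.py` arranges.

```python
import torch
from correlation import correlation  # the package in src/correlation

corr = correlation.ModuleCorrelation()

# Two feature maps of identical shape B x C x H x W. The CUDA kernels require
# contiguous tensors on the GPU; the CPU path raises NotImplementedError.
feat_first = torch.randn(1, 64, 48, 96).cuda().contiguous()
feat_second = torch.randn(1, 64, 48, 96).cuda().contiguous()

# The custom backward pass is not implemented for CUDA inputs in this repo,
# so this sketch only runs the correlation forward under no_grad().
with torch.no_grad():
    cost_volume = corr(feat_first, feat_second)

print(cost_volume.size())  # B x 81 x H x W: one channel per offset of a 9 x 9 search window
```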
-------------------------------------------------------------------------------- /src/configs/DIP.py: -------------------------------------------------------------------------------- 1 | from configs.base import cfg 2 | cfg = cfg.copy() 3 | 4 | # Dataset 5 | cfg['batch_size'] = 1 6 | cfg['batch_stride'] = [1] 7 | cfg['batch_mode'] = 'seq' 8 | cfg['resize'] = (192, 384) 9 | 10 | # Model 11 | cfg['net_type_G'] = '2d' 12 | 13 | # Loss 14 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0} 15 | 16 | # Optimize 17 | cfg['train_mode'] = 'DIP' 18 | cfg['num_iter'] = 5000 19 | cfg['num_pass'] = 1 20 | cfg['fine_tune'] = False 21 | cfg['param_noise'] = True -------------------------------------------------------------------------------- /src/configs/DIP_Vid.py: -------------------------------------------------------------------------------- 1 | from configs.base import cfg 2 | cfg = cfg.copy() 3 | 4 | # Dataset 5 | cfg['batch_size'] = 5 6 | cfg['batch_stride'] = [1] 7 | cfg['batch_mode'] = 'random' 8 | cfg['resize'] = (192, 384) 9 | 10 | # Model 11 | cfg['net_type_G'] = '2d' 12 | 13 | # Loss 14 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0} 15 | 16 | # Optimize 17 | cfg['train_mode'] = 'DIP-Vid' 18 | cfg['num_iter'] = 100 19 | cfg['num_pass'] = 20 20 | cfg['fine_tune'] = True 21 | cfg['param_noise'] = False -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-cudnn7-devel 2 | MAINTAINER Haotian zhang 3 | 4 | RUN apt-get update && apt-get install -y rsync openssh-server vim wget unzip htop tmux 5 | RUN apt-get install -y libsm6 libxrender1 libgtk2.0-dev 6 | 7 | RUN apt-get install python3-pip -y 8 | RUN pip3 install --upgrade pip 9 | 10 | RUN pip3 install torch==1.0.0 torchvision==0.2.0 11 | RUN pip3 install ipykernel jupyter 12 | RUN pip3 install opencv-contrib-python matplotlib 13 | RUN pip3 install cupy scipy 14 | 15 | EXPOSE 8888 -------------------------------------------------------------------------------- /src/configs/DIP_Vid_Flow.py: -------------------------------------------------------------------------------- 1 | from configs.base import cfg 2 | cfg = cfg.copy() 3 | 4 | # Dataset 5 | cfg['batch_size'] = 5 6 | cfg['batch_stride'] = [1, 3, 5] 7 | cfg['batch_mode'] = 'random' 8 | cfg['resize'] = (192, 384) 9 | 10 | # Model 11 | cfg['net_type_G'] = '2d' 12 | 13 | # Loss 14 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0.1, 'consistency': 1, 'perceptual': 0.01} 15 | 16 | # Optimize 17 | cfg['train_mode'] = 'DIP-Vid-Flow' 18 | cfg['num_iter'] = 100 19 | cfg['num_pass'] = 20 20 | cfg['fine_tune'] = True 21 | cfg['param_noise'] = False 22 | 23 | # Result 24 | cfg['save_every_batch'] = 50 25 | -------------------------------------------------------------------------------- /src/configs/DIP_Vid_3DCN.py: -------------------------------------------------------------------------------- 1 | from configs.base import cfg 2 | cfg = cfg.copy() 3 | 4 | # Dataset 5 | cfg['batch_size'] = 5 6 | cfg['batch_stride'] = [1] 7 | cfg['batch_mode'] = 'random' 8 | cfg['resize'] = (192, 384) 9 | 10 | # Model 11 | cfg['net_type_G'] = '3d' 12 | cfg['filter_size_down'] = (3, 5, 5) 13 | cfg['filter_size_up'] = (3, 3, 3) 14 | cfg['filter_size_skip'] = (1, 1, 1) 15 | 16 | # Loss 17 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0} 18 | 19 
| # Optimize 20 | cfg['train_mode'] = 'DIP-Vid-3DCN' 21 | cfg['num_iter'] = 100 22 | cfg['num_pass'] = 20 23 | cfg['fine_tune'] = True 24 | cfg['param_noise'] = False -------------------------------------------------------------------------------- /src/models/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LossNetwork(nn.Module): 6 | """ 7 | Extract certain feature maps from pretrained VGG model, used for computing perceptual loss 8 | """ 9 | def __init__(self, vgg_model=None, output_layer=['3', '8', '15']): 10 | super(LossNetwork, self).__init__() 11 | if vgg_model is None: 12 | # prepare fixed VGG16 13 | conv = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False) 14 | conv.weight.data.fill_(1) 15 | pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) 16 | relu = torch.nn.ReLU() 17 | model = [] 18 | model += ([conv, relu] * 2 + [pool]) * 2 19 | model += ([conv, relu] * 4 + [pool]) * 2 20 | model += [conv, relu] * 4 21 | self.vgg_layers = model 22 | else: 23 | self.vgg_layers = vgg_model.features 24 | self.output_layer = output_layer 25 | self.layer_name_mapping = { 26 | '3': "relu1_2", 27 | '8': "relu2_2", 28 | '15': "relu3_3", 29 | '22': "relu4_3" 30 | } 31 | 32 | def forward(self, x): 33 | feature_list = [] 34 | if type(self.vgg_layers) == list: 35 | for layer, module in enumerate(self.vgg_layers): 36 | x = module(x) 37 | if str(layer) in self.output_layer: 38 | feature_list.append(x) 39 | else: 40 | for name, module in self.vgg_layers._modules.items(): 41 | x = module(x) 42 | if name in self.output_layer: 43 | feature_list.append(x) 44 | return feature_list 45 | -------------------------------------------------------------------------------- /src/configs/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import logging 4 | 5 | cfg = {} 6 | # Dataset 7 | cfg['video_path'] = 'data/bmx-trees.avi' 8 | cfg['mask_path'] = 'data/bmx-trees_mask.avi' 9 | cfg['batch_size'] = 5 10 | cfg['batch_stride'] = [1] 11 | cfg['batch_mode'] = 'random' 12 | cfg['traverse_step'] = 1 13 | cfg['frame_sum'] = 100 14 | cfg['resize'] = None 15 | cfg['interpolation'] = cv2.INTER_AREA 16 | cfg['input_type'] = 'noise' # 'mesh_grid' 17 | cfg['input_ratio'] = 0.1 18 | cfg['dilation_iter'] = 0 19 | cfg['input_noise_std'] = 0 20 | 21 | # Model 22 | cfg['net_type_G'] = '2d' # 3d 23 | cfg['net_type_L'] = 'VGG16' 24 | cfg['net_depth'] = 6 25 | cfg['input_channel'] = 1 26 | cfg['output_channel_img'] = 3 27 | cfg['num_channels_down'] = [16, 32, 64, 128, 128, 128] 28 | cfg['num_channels_up'] = [16, 32, 64, 128, 128, 128] 29 | cfg['num_channels_skip'] = [4, 4, 4, 4, 4, 4] 30 | cfg['filter_size_down'] = 5 31 | cfg['filter_size_up'] = 3 32 | cfg['filter_size_skip'] = 1 33 | cfg['use_skip'] = True 34 | cfg['dtype'] = torch.cuda.FloatTensor 35 | 36 | # Loss 37 | cfg['loss_weight'] = {'recon_image': 1, 'recon_flow': 0, 'consistency': 0, 'perceptual': 0} 38 | cfg['loss_recon'] = 'L2' 39 | cfg['perceptual_layers'] = ['3', '8', '15'] 40 | 41 | # Optimize 42 | cfg['train_mode'] = 'DIP-Vid-Flow' 43 | cfg['baseline'] = True 44 | cfg['LR'] = 1e-2 45 | cfg['optimizer_G'] = 'Adam' 46 | cfg['fine_tune'] = True 47 | cfg['param_noise'] = True 48 | cfg['num_iter'] = 100 49 | cfg['num_pass'] = 20 50 | 51 | # Result 52 | cfg['save_every_iter'] = 100 53 | cfg['save_every_pass'] = 1 54 | cfg['plot'] = False 55 | cfg['save'] = True 56 | cfg['save_batch'] = False 57 | 
cfg['res_dir'] = None 58 | 59 | # Log 60 | cfg['logging_level'] = logging.INFO # logging.DEBUG -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | sys.path.append('src') 4 | 5 | from inpainting_test import InpaintingTest 6 | from configs.DIP import cfg as DIP_cfg 7 | from configs.DIP_Vid import cfg as DIP_Vid_cfg 8 | from configs.DIP_Vid_3DCN import cfg as DIP_Vid_3DCN_cfg 9 | from configs.DIP_Vid_Flow import cfg as DIP_Vid_Flow_cfg 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--train_mode', type=str, default='DIP_Vid_Flow', help='mode of the experiment: (DIP|DIP_Vid|DIP_Vid_3DCN|DIP_Vid_Flow)', metavar='') 16 | parser.add_argument('--resize', nargs='+', type=int, default=None, help='height and width of the output', metavar='') 17 | parser.add_argument('--video_path', type=str, default='data/bmx-trees.avi', help='path of the input video', metavar='') 18 | parser.add_argument('--mask_path', type=str, default='data/bmx-trees_mask.avi', help='path of the input mask', metavar='') 19 | parser.add_argument('--res_dir', type=str, default='result', help='path to save the result', metavar='') 20 | parser.add_argument('--frame_sum', type=int, default=100, help='number of frames to load', metavar='') 21 | parser.add_argument('--dilation_iter', type=int, default=0, help='number of steps to dilate the mask', metavar='') 22 | 23 | if len(sys.argv) == 1: 24 | parser.print_help() 25 | sys.exit(1) 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def main(args): 31 | if args.train_mode == 'DIP': 32 | cfg = DIP_cfg 33 | elif args.train_mode == 'DIP-Vid': 34 | cfg = DIP_Vid_cfg 35 | elif args.train_mode == 'DIP-Vid-3DCN': 36 | cfg = DIP_Vid_3DCN_cfg 37 | elif args.train_mode == 'DIP-Vid-Flow': 38 | cfg = DIP_Vid_Flow_cfg 39 | else: 40 | raise Exception("Train mode {} not implemented!".format(args.train_mode)) 41 | 42 | cfg['resize'] = tuple(args.resize) 43 | if len(cfg['resize']) != 2: raise Exception("Resize must be a tuple of length 2!") 44 | cfg['video_path'] = args.video_path 45 | cfg['mask_path'] = args.mask_path 46 | cfg['res_dir'] = args.res_dir 47 | cfg['frame_sum'] = args.frame_sum 48 | cfg['dilation_iter'] = args.dilation_iter 49 | 50 | test = InpaintingTest(cfg) 51 | test.create_data_loader() 52 | test.visualize_single_batch() 53 | test.create_model() 54 | test.create_optimizer() 55 | test.create_loss_function() 56 | test.train() 57 | 58 | 59 | if __name__ == '__main__': 60 | args = parse_args() 61 | main(args) -------------------------------------------------------------------------------- /src/flow_estimator.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Code modified from https://github.com/sniklaus/pytorch-pwc 3 | ############################################################ 4 | 5 | import math 6 | import torch 7 | 8 | from models.pwc_net import PWC_Net 9 | 10 | 11 | class FlowEstimator(object): 12 | 13 | def __init__(self): 14 | self.model_pwc = PWC_Net().type(torch.cuda.FloatTensor) 15 | self.model_pwc.load_state_dict(torch.load('pretrained_models/pwc_net.tar')) 16 | 17 | 18 | def estimate_flow_pair(self, tensorInputFirst, tensorInputSecond): 19 | ### tensor format 20 | # C x H x W 21 | # BGR 22 | # 0-1 23 | # FloatTensor.cuda 24 | ### 25 | 26 | 
moduleNetwork = self.model_pwc 27 | tensorOutput = torch.FloatTensor().cuda() 28 | 29 | assert(tensorInputFirst.size(1) == tensorInputSecond.size(1)) 30 | assert(tensorInputFirst.size(2) == tensorInputSecond.size(2)) 31 | 32 | intWidth = tensorInputFirst.size(2) 33 | intHeight = tensorInputFirst.size(1) 34 | 35 | # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 36 | # assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 37 | 38 | if True: 39 | tensorPreprocessedFirst = tensorInputFirst.view(1, 3, intHeight, intWidth) 40 | tensorPreprocessedSecond = tensorInputSecond.view(1, 3, intHeight, intWidth) 41 | 42 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) 43 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) 44 | 45 | tensorPreprocessedFirst = torch.nn.functional.upsample(input=tensorPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 46 | tensorPreprocessedSecond = torch.nn.functional.upsample(input=tensorPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 47 | 48 | tensorFlow = 20.0 * torch.nn.functional.upsample(input=moduleNetwork(tensorPreprocessedFirst, tensorPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False) 49 | 50 | tensorFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 51 | tensorFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 52 | 53 | tensorOutput.resize_(2, intHeight, intWidth).copy_(tensorFlow[0, :, :, :]) 54 | # end 55 | 56 | return tensorOutput # C x H x W 57 | 58 | 59 | def estimate_flow_batch(self, out_tensor): 60 | N, C, H, W = out_tensor.size() 61 | flow_tensor = torch.FloatTensor(N-1, 2, H, W) 62 | for i in range(N-1): 63 | first = out_tensor[i] 64 | second = out_tensor[i+1] 65 | flow_tensor[i] = self.estimate_flow_pair(first, second) 66 | return flow_tensor # N-1 x 2 x H x W 67 | -------------------------------------------------------------------------------- /src/models/common.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior 3 | ###################################################################### 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import init 8 | import numpy as np 9 | 10 | 11 | def add_module(self, module): 12 | self.add_module(str(len(self) + 1), module) 13 | torch.nn.Module.add = add_module 14 | 15 | 16 | class Concat(nn.Module): 17 | """ 18 | Concatenate the output of multiple nn modules 19 | """ 20 | def __init__(self, dim, *args): 21 | super(Concat, self).__init__() 22 | self.dim = dim 23 | 24 | for idx, module in enumerate(args): 25 | self.add_module(str(idx), module) 26 | 27 | def forward(self, input): 28 | inputs = [] 29 | for module in self._modules.values(): 30 | inputs.append(module(input)) 31 | 32 | inputs_shapes2 = [x.shape[2] for x in inputs] 33 | inputs_shapes3 = [x.shape[3] for x in inputs] 34 | 35 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)): 36 | inputs_ = inputs 37 | else: 38 | target_shape2 = min(inputs_shapes2) 39 | target_shape3 = 
min(inputs_shapes3) 40 | 41 | inputs_ = [] 42 | for inp in inputs: 43 | diff2 = (inp.size(2) - target_shape2) // 2 44 | diff3 = (inp.size(3) - target_shape3) // 2 45 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 46 | 47 | return torch.cat(inputs_, dim=self.dim) 48 | 49 | def __len__(self): 50 | return len(self._modules) 51 | 52 | 53 | def act(act_fun = 'LeakyReLU'): 54 | """ 55 | Return an activation function or module (e.g. nn.ReLU) 56 | """ 57 | if isinstance(act_fun, str): 58 | if act_fun == 'LeakyReLU': 59 | return nn.LeakyReLU(0.2, inplace=True) 60 | elif act_fun == 'Swish': 61 | return Swish() 62 | elif act_fun == 'ReLU': 63 | return nn.ReLU() 64 | elif act_fun == 'none': 65 | return nn.Sequential() 66 | else: 67 | assert False 68 | else: 69 | return act_fun() 70 | 71 | 72 | def init_net(net, init_type='normal', gain=0.02): 73 | """ 74 | Initialize the network parameters 75 | """ 76 | def init_func(m): 77 | classname = m.__class__.__name__ 78 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 79 | if init_type == 'normal': 80 | init.normal_(m.weight.data, 0.0, gain) 81 | elif init_type == 'xavier': 82 | init.xavier_normal_(m.weight.data, gain=gain) 83 | elif init_type == 'kaiming': 84 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 85 | elif init_type == 'orthogonal': 86 | init.orthogonal_(m.weight.data, gain=gain) 87 | else: 88 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 89 | if hasattr(m, 'bias') and m.bias is not None: 90 | init.constant_(m.bias.data, 0.0) 91 | elif classname.find('BatchNorm2d') != -1: 92 | init.normal_(m.weight.data, 1.0, gain) 93 | init.constant_(m.bias.data, 0.0) 94 | 95 | print('initialize network with %s' % init_type) 96 | net.apply(init_func) 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An Internal Learning Approach to Video Inpainting 2 | Haotian zhang, Long Mai, Ning Xu, Zhaowen Wang, John Collomosse, Hailin Jin 3 | 4 | International Conference on Computer Vision (ICCV) 2019. 5 | 6 | [Paper](https://arxiv.org/abs/1909.07957) | 7 | [Project](https://cs.stanford.edu/~haotianz/research/video_inpainting/) | 8 | [Video](https://youtu.be/-MQZayP5tc0) 9 | 10 | 11 | 12 | 13 | 14 | ## Install 15 | The code has been tested on pytorch 1.0.0 with python 3.5 and cuda 9.0. Please refer to [requirements.txt](https://github.com/Haotianz94/IL_video_inpainting/blob/master/requirements.txt) for details. Alternatively, you can build a docker image using provided [Dockerfile](https://github.com/Haotianz94/IL_video_inpainting/blob/master/Dockerfile). 16 | 17 | Warning! We have noticed that the optimization may not converge on some GPUs when using pytorch==0.4.0. We have observed the issue on Titan V and Tesla V100. Therefore, we highly recommend upgrading your pytorch version above 1.0.0 to avoid the issue if you are training on these GPUs. 18 | 19 | 20 | 21 | ## Usage 22 | We provide two ways to test our video inpainting approach. Please first download the demo data from [here](https://drive.google.com/file/d/1MJDCjj1aIUbW0OK9UnewhXlkKX9zllQd/view?usp=sharing) into `data/` and download the pretrained model weights for PWC-Net from [here](https://drive.google.com/file/d/1XPaqITtUV11WpOpX1PeCkS4zdjI5tKb8/view?usp=sharing) (No need to unzip the weights file) into `pretrained_models/`. 
(The model was originally trained by Simon Niklaus from [here](https://github.com/sniklaus/pytorch-pwc)). 23 | 24 | * To run our demo, please run the following command: 25 | ``` 26 | python3 train.py --train_mode DIP-Vid-Flow --video_path data/bmx-trees.avi --mask_path data/bmx-trees_mask.avi --resize 192 384 --res_dir result/DIP_Vid_Flow 27 | ``` 28 | 29 | * Alternatively, you can run through our demo step by step using the provided jupyter notebook [demo.ipynb](https://github.com/Haotianz94/IL_video_inpainting/blob/master/demo.ipynb) 30 | 31 | 32 | ## Citation 33 | ``` 34 | @inproceedings{zhang2019internal, 35 | title={An Internal Learning Approach to Video Inpainting}, 36 | author={Zhang, Haotian and Mai, Long and Xu, Ning and Wang, Zhaowen and Collomosse, John and Jin, Hailin}, 37 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 38 | pages={2720--2729}, 39 | year={2019} 40 | } 41 | ``` 42 | 43 | 44 | ## License 45 | © 2019 Adobe. 46 | 47 | Adobe holds the copyright for all the files found in this repository. 48 | 49 | IL_video_inpainting is a project by Adobe Research. It is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) and may only be used for non-commercial purposes. See the LICENSE file for more information. 50 | 51 | Note We are not able to provide access to our Composed Dataset because we could not get copyright permission from the authors of the videos. 52 | 53 | ## Acknowledgement 54 | The implementation of our network architecture is mostly borrowed from the Deep Image Prior [repo](https://github.com/DmitryUlyanov/deep-image-prior). The implementation of the PWC-Net is borrowed from this [repo](https://github.com/sniklaus/pytorch-pwc). Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. 
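
For reference, the demo command in the Usage section can also be reproduced programmatically. The snippet below is a condensed sketch of the calls made by `train.py` and `demo.ipynb`, using the same demo paths and settings; adapt them for your own videos.

```python
import sys
sys.path.append('src')

from inpainting_test import InpaintingTest
from configs.DIP_Vid_Flow import cfg

# Same settings as the demo command above
cfg['video_path'] = 'data/bmx-trees.avi'
cfg['mask_path'] = 'data/bmx-trees_mask.avi'
cfg['resize'] = (192, 384)              # (height, width)
cfg['res_dir'] = 'result/DIP_Vid_Flow'

test = InpaintingTest(cfg)
test.create_data_loader()
test.visualize_single_batch()   # optional preview of one input batch
test.create_model()
test.create_optimizer()
test.create_loss_function()
test.train()
```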
55 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import sys\n", 12 | "sys.path.append('src')\n", 13 | "from inpainting_test import InpaintingTest\n", 14 | "from configs.DIP import cfg as DIP_cfg\n", 15 | "from configs.DIP_Vid import cfg as DIP_Vid_cfg\n", 16 | "from configs.DIP_Vid_3DCN import cfg as DIP_Vid_3DCN_cfg\n", 17 | "from configs.DIP_Vid_Flow import cfg as DIP_Vid_Flow_cfg" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Choose training mode (DIP|DIP-Vid|DIP-Vid-3DCN|DIP-Vid-Flow)\n", 27 | "train_mode = 'DIP-Vid-Flow'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# Set up config\n", 37 | "if train_mode == 'DIP':\n", 38 | " cfg = DIP_cfg\n", 39 | "elif train_mode == 'DIP-Vid':\n", 40 | " cfg = DIP_Vid_cfg\n", 41 | "elif train_mode == 'DIP-Vid-3DCN':\n", 42 | " cfg = DIP_Vid_3DCN_cfg\n", 43 | "elif train_mode == 'DIP-Vid-Flow':\n", 44 | " cfg = DIP_Vid_Flow_cfg\n", 45 | " \n", 46 | "cfg['video_path'] = 'data/bmx-trees.avi'\n", 47 | "cfg['mask_path'] = 'data/bmx-trees_mask.avi'\n", 48 | "cfg['save_every_iter'] = 100\n", 49 | "cfg['resize'] = (192, 384)\n", 50 | "cfg['plot'] = True\n", 51 | "cfg['res_dir'] = 'result/test/DIP_Vid_Flow'" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "test = InpaintingTest(cfg)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "scrolled": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "test.create_data_loader()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "scrolled": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "test.visualize_single_batch()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "test.create_model()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "test.create_optimizer()" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "test.create_loss_function()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "scrolled": true 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "test.train()" 123 | ] 124 | } 125 | ], 126 | "metadata": { 127 | "kernelspec": { 128 | "display_name": "Python 3", 129 | "language": "python", 130 | "name": "python3" 131 | }, 132 | "language_info": { 133 | "codemirror_mode": { 134 | "name": "ipython", 135 | "version": 3 136 | }, 137 | "file_extension": ".py", 138 | "mimetype": "text/x-python", 139 | "name": "python", 140 | "nbconvert_exporter": "python", 141 | "pygments_lexer": "ipython3", 142 | "version": "3.5.2" 143 | } 144 | }, 145 | "nbformat": 4, 146 | "nbformat_minor": 2 147 | } 148 | 
-------------------------------------------------------------------------------- /src/models/encoder_decoder_2d.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior 3 | ###################################################################### 4 | 5 | import torch 6 | import torch.nn as nn 7 | from .common import * 8 | 9 | 10 | def EncoderDecoder2D( 11 | num_input_channels=2, num_output_channels=3, 12 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 13 | filter_size_down=3, filter_size_up=3, filter_size_skip=1, 14 | upsample_mode='nearest', downsample_mode='stride', 15 | need_sigmoid=True, need_bias=True, need1x1_up=True, 16 | pad='zero', act_fun='LeakyReLU'): 17 | """Assembles encoder-decoder with skip connections, using 2D convolutions. 18 | 19 | Arguments: 20 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 21 | pad (string): zero|reflection (default: 'zero') 22 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 23 | downsample_mode (string): 'stride|avg|max' (default: 'stride') 24 | """ 25 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 26 | 27 | n_scales = len(num_channels_down) 28 | 29 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 30 | upsample_mode = [upsample_mode]*n_scales 31 | 32 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 33 | downsample_mode = [downsample_mode]*n_scales 34 | 35 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 36 | filter_size_down = [filter_size_down]*n_scales 37 | 38 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 39 | filter_size_up = [filter_size_up]*n_scales 40 | 41 | last_scale = n_scales - 1 42 | 43 | cur_depth = None 44 | 45 | model = nn.Sequential() 46 | model_tmp = model 47 | 48 | input_depth = num_input_channels 49 | for i in range(len(num_channels_down)): 50 | 51 | deeper = nn.Sequential() 52 | skip = nn.Sequential() 53 | 54 | if num_channels_skip[i] != 0: 55 | model_tmp.add(Concat(1, skip, deeper)) 56 | else: 57 | model_tmp.add(deeper) 58 | 59 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 60 | 61 | if num_channels_skip[i] != 0: 62 | skip.add(conv(input_depth, num_channels_skip[i], filter_size_skip, bias=need_bias, pad=pad)) 63 | skip.add(bn(num_channels_skip[i])) 64 | skip.add(act(act_fun)) 65 | 66 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], stride=2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 67 | deeper.add(bn(num_channels_down[i])) 68 | deeper.add(act(act_fun)) 69 | 70 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 71 | deeper.add(bn(num_channels_down[i])) 72 | deeper.add(act(act_fun)) 73 | 74 | deeper_main = nn.Sequential() 75 | 76 | if i == len(num_channels_down) - 1: 77 | # The deepest 78 | k = num_channels_down[i] 79 | else: 80 | deeper.add(deeper_main) 81 | k = num_channels_up[i + 1] 82 | 83 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 84 | 85 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 86 | 
model_tmp.add(bn(num_channels_up[i])) 87 | model_tmp.add(act(act_fun)) 88 | 89 | if need1x1_up: 90 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 91 | model_tmp.add(bn(num_channels_up[i])) 92 | model_tmp.add(act(act_fun)) 93 | 94 | input_depth = num_channels_down[i] 95 | model_tmp = deeper_main 96 | 97 | model.add(FinalLayer(num_channels_up[0], num_output_channels, need_bias, pad, need_sigmoid)) 98 | return model 99 | 100 | 101 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 102 | downsampler = None 103 | if stride != 1 and downsample_mode != 'stride': 104 | 105 | if downsample_mode == 'avg': 106 | downsampler = nn.AvgPool2d(stride, stride) 107 | elif downsample_mode == 'max': 108 | downsampler = nn.MaxPool2d(stride, stride) 109 | else: 110 | assert False 111 | 112 | stride = 1 113 | 114 | padder = None 115 | to_pad = int((kernel_size - 1) / 2) 116 | if pad == 'reflection': 117 | padder = nn.ReflectionPad2d(to_pad) 118 | to_pad = 0 119 | 120 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 121 | 122 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 123 | return nn.Sequential(*layers) 124 | 125 | 126 | def bn(num_features): 127 | return nn.BatchNorm2d(num_features) 128 | 129 | 130 | class FinalLayer(nn.Module): 131 | """ 132 | Split output into image and flow branch 133 | """ 134 | def __init__(self, Cin, Cout, need_bias, pad, need_sigmoid): 135 | super(FinalLayer, self).__init__() 136 | self.conv_img = conv(Cin, 3, 1, bias=need_bias, pad=pad) 137 | if need_sigmoid: 138 | self.sigmoid = nn.Sigmoid() 139 | else: 140 | self.sigmoid = None 141 | if Cout > 3: 142 | self.conv_flow = conv(Cin, Cout-3, 1, bias=need_bias, pad=pad) 143 | else: 144 | self.conv_flow = None 145 | 146 | def forward(self, x): 147 | y_img = self.conv_img(x) 148 | if not self.sigmoid is None: 149 | y_img = self.sigmoid(y_img) 150 | 151 | if not self.conv_flow is None: 152 | y_flow = self.conv_flow(x) 153 | return torch.cat((y_img, y_flow), 1) 154 | else: 155 | return y_img -------------------------------------------------------------------------------- /src/models/encoder_decoder_3d.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior 3 | ###################################################################### 4 | 5 | import torch 6 | import torch.nn as nn 7 | from .common import * 8 | 9 | 10 | def EncoderDecoder3D( 11 | num_input_channels=1, num_output_channels=3, 12 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 13 | filter_size_down=(3, 3, 3), filter_size_up=(3, 3, 3), filter_size_skip=(1, 1, 1), 14 | upsample_mode='nearest', downsample_mode='stride', 15 | need_sigmoid=True, need_bias=True, need1x1_up=True, 16 | pad='zero', act_fun='LeakyReLU', 17 | ): 18 | 19 | """Assembles encoder-decoder with skip connections, using 3D convolutions. 20 | 21 | Arguments: 22 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. 
nn.ReLU) 23 | pad (string): zero|reflection (default: 'zero') 24 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 25 | downsample_mode (string): 'stride|avg|max' (default: 'stride') 26 | """ 27 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 28 | 29 | n_scales = len(num_channels_down) 30 | 31 | if not isinstance(upsample_mode, list): 32 | upsample_mode = [upsample_mode]*n_scales 33 | 34 | if not isinstance(downsample_mode, list): 35 | downsample_mode = [downsample_mode]*n_scales 36 | 37 | if not isinstance(filter_size_down, list): 38 | filter_size_down = [filter_size_down]*n_scales 39 | 40 | if not isinstance(filter_size_up, list): 41 | filter_size_up = [filter_size_up]*n_scales 42 | 43 | last_scale = n_scales - 1 44 | 45 | cur_depth = None 46 | 47 | model = nn.Sequential() 48 | model_tmp = model 49 | 50 | input_depth = num_input_channels 51 | for i in range(len(num_channels_down)): 52 | 53 | deeper = nn.Sequential() 54 | skip = nn.Sequential() 55 | 56 | if num_channels_skip[i] != 0: 57 | model_tmp.add(Concat(1, skip, deeper)) 58 | else: 59 | model_tmp.add(deeper) 60 | 61 | bn_up = BatchNorm3D(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])) 62 | model_tmp.add(bn_up) 63 | 64 | if num_channels_skip[i] != 0: 65 | conv_skip = conv3d(input_depth, num_channels_skip[i], filter_size_skip, bias=need_bias, pad=pad) 66 | bn_skip = BatchNorm3D(num_channels_skip[i]) 67 | skip.add(conv_skip) 68 | skip.add(bn_skip) 69 | skip.add(act(act_fun)) 70 | 71 | conv_down = conv3d(input_depth, num_channels_down[i], filter_size_down[i], stride=(1, 2, 2), bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]) 72 | bn_down = BatchNorm3D(num_channels_down[i]) 73 | deeper.add(conv_down) 74 | deeper.add(bn_down) 75 | deeper.add(act(act_fun)) 76 | 77 | conv_down = conv3d(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad) 78 | bn_down = BatchNorm3D(num_channels_down[i]) 79 | deeper.add(conv_down) 80 | deeper.add(bn_down) 81 | deeper.add(act(act_fun)) 82 | 83 | deeper_main = nn.Sequential() 84 | 85 | if i == len(num_channels_down) - 1: 86 | # The deepest 87 | k = num_channels_down[i] 88 | else: 89 | deeper.add(deeper_main) 90 | k = num_channels_up[i + 1] 91 | 92 | deeper.add(Upsample3D(scale_factor=2, mode=upsample_mode[i])) 93 | 94 | conv_up = conv3d(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], bias=need_bias, pad=pad) 95 | bn_up = BatchNorm3D(num_channels_up[i]) 96 | model_tmp.add(conv_up) 97 | model_tmp.add(bn_up) 98 | model_tmp.add(act(act_fun)) 99 | 100 | if need1x1_up: 101 | conv_up = conv3d(num_channels_up[i], num_channels_up[i], kernel_size=(1,1,1), bias=need_bias, pad=pad) 102 | bn_up = BatchNorm3D(num_channels_up[i]) 103 | model_tmp.add(conv_up) 104 | model_tmp.add(bn_up) 105 | model_tmp.add(act(act_fun)) 106 | 107 | input_depth = num_channels_down[i] 108 | model_tmp = deeper_main 109 | 110 | conv_final = conv3d(num_channels_up[0], num_output_channels, kernel_size=(1,1,1), bias=need_bias, pad=pad) 111 | model.add(conv_final) 112 | if need_sigmoid: 113 | model.add(nn.Sigmoid()) 114 | 115 | return model 116 | 117 | 118 | def conv3d(in_channels, out_channels, kernel_size=(1,1,1), stride=(1,1,1), bias=True, pad='zero', downsample_mode='stride'): 119 | downsampler = None 120 | if stride != 1 and downsample_mode != 'stride': 121 | 122 | if downsample_mode == 'avg': 123 | downsampler = nn.AvgPool2d(stride, stride) 124 | elif downsample_mode == 'max': 125 | 
downsampler = nn.MaxPool2d(stride, stride) 126 | elif downsample_mode in ['lanczos2', 'lanczos3']: 127 | downsampler = Downsampler(n_planes=out_channels, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True) 128 | else: 129 | assert False 130 | 131 | stride = 1 132 | 133 | padder = None 134 | pad_D = int((kernel_size[0] - 1) / 2) 135 | pad_HW = int((kernel_size[1] - 1) / 2) 136 | to_pad = (pad_D, pad_HW, pad_HW) 137 | if pad == 'reflection': 138 | padder = ReflectionPad3D(pad_D, pad_HW) 139 | to_pad = (0, 0, 0) 140 | 141 | convolver = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=to_pad, bias=bias) 142 | 143 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 144 | return nn.Sequential(*layers) 145 | 146 | 147 | class BatchNorm3D(nn.Module): 148 | def __init__(self, num_features): 149 | super(BatchNorm3D, self).__init__() 150 | self.bn = nn.BatchNorm2d(num_features) 151 | 152 | def forward(self, x): 153 | assert(x.size(0) == 1) # 1 x C x D x H x W 154 | y = x.squeeze(0).transpose(0, 1).contiguous() # D x C x H x W 155 | y = self.bn(y) 156 | y = y.transpose(0, 1).unsqueeze(0) 157 | return y 158 | 159 | 160 | class Upsample3D(nn.Module): 161 | def __init__(self, scale_factor, mode): 162 | super(Upsample3D, self).__init__() 163 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode) 164 | 165 | def forward(self, x): 166 | assert(x.size(0) == 1) # 1 x C x D x H x W 167 | y = x.squeeze(0).transpose(0, 1) # D x C x H x W 168 | y = self.upsample(y) 169 | return y.transpose(0, 1).unsqueeze(0) 170 | 171 | 172 | class ReflectionPad3D(nn.Module): 173 | def __init__(self, pad_D, pad_HW): 174 | super(ReflectionPad3D, self).__init__() 175 | self.padder_HW = nn.ReflectionPad2d(pad_HW) 176 | self.padder_D = nn.ReplicationPad3d((0, 0, 0, 0, pad_D, pad_D)) 177 | 178 | def forward(self, x): 179 | assert(x.size(0) == 1) # 1 x C x D x H x W 180 | y = x.squeeze(0).transpose(0, 1) # D x C x H x W 181 | y = self.padder_HW(y) 182 | y = y.transpose(0, 1).unsqueeze(0) # 1 x C x D x H x W 183 | return self.padder_D(y) -------------------------------------------------------------------------------- /src/correlation/correlation.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Code modified from https://github.com/sniklaus/pytorch-pwc 3 | ############################################################ 4 | 5 | import cupy 6 | import torch 7 | 8 | kernel_Correlation_rearrange = ''' 9 | extern "C" __global__ void kernel_Correlation_rearrange( 10 | const int n, 11 | const float* input, 12 | float* output 13 | ) { 14 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 15 | 16 | if (intIndex >= n) { 17 | return; 18 | } 19 | 20 | int intSample = blockIdx.z; 21 | int intChannel = blockIdx.y; 22 | 23 | float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 24 | 25 | __syncthreads(); 26 | 27 | int intPaddedY = (intIndex / SIZE_3(input)) + 4; 28 | int intPaddedX = (intIndex % SIZE_3(input)) + 4; 29 | int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; 30 | 31 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue; 32 | } 33 | ''' 34 | 35 | kernel_Correlation_updateOutput = ''' 36 | extern "C" __global__ void kernel_Correlation_updateOutput( 37 | const int n, 38 | const float* rbot0, 39 | const float* rbot1, 40 | 
float* top 41 | ) { 42 | extern __shared__ char patch_data_char[]; 43 | 44 | float *patch_data = (float *)patch_data_char; 45 | 46 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 47 | int x1 = blockIdx.x + 4; 48 | int y1 = blockIdx.y + 4; 49 | int item = blockIdx.z; 50 | int ch_off = threadIdx.x; 51 | 52 | // Load 3D patch into shared shared memory 53 | for (int j = 0; j < 1; j++) { // HEIGHT 54 | for (int i = 0; i < 1; i++) { // WIDTH 55 | int ji_off = ((j * 1) + i) * SIZE_3(rbot0); 56 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 57 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 58 | int idxPatchData = ji_off + ch; 59 | patch_data[idxPatchData] = rbot0[idx1]; 60 | } 61 | } 62 | } 63 | 64 | __syncthreads(); 65 | 66 | __shared__ float sum[32]; 67 | 68 | // Compute correlation 69 | for(int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 70 | sum[ch_off] = 0; 71 | 72 | int s2o = (top_channel % 9) - 4; 73 | int s2p = (top_channel / 9) - 4; 74 | 75 | for (int j = 0; j < 1; j++) { // HEIGHT 76 | for (int i = 0; i < 1; i++) { // WIDTH 77 | int ji_off = ((j * 1) + i) * SIZE_3(rbot0); 78 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 79 | int x2 = x1 + s2o; 80 | int y2 = y1 + s2p; 81 | 82 | int idxPatchData = ji_off + ch; 83 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 84 | 85 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 86 | } 87 | } 88 | } 89 | 90 | __syncthreads(); 91 | 92 | if (ch_off == 0) { 93 | float total_sum = 0; 94 | for (int idx = 0; idx < 32; idx++) { 95 | total_sum += sum[idx]; 96 | } 97 | const int sumelems = SIZE_3(rbot0); 98 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 99 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 100 | } 101 | } 102 | } 103 | ''' 104 | 105 | def cupy_kernel(strFunction, objectVariables): 106 | strKernel = globals()[strFunction] 107 | 108 | for strVariable in objectVariables: 109 | strKernel = strKernel.replace('SIZE_0(' + strVariable + ')', str(objectVariables[strVariable].size(0))) 110 | strKernel = strKernel.replace('SIZE_1(' + strVariable + ')', str(objectVariables[strVariable].size(1))) 111 | strKernel = strKernel.replace('SIZE_2(' + strVariable + ')', str(objectVariables[strVariable].size(2))) 112 | strKernel = strKernel.replace('SIZE_3(' + strVariable + ')', str(objectVariables[strVariable].size(3))) 113 | # end 114 | 115 | return strKernel 116 | # end 117 | 118 | @cupy.util.memoize(for_each_device=True) 119 | def cupy_launch(strFunction, strKernel): 120 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 121 | # end 122 | 123 | class FunctionCorrelation(torch.autograd.Function): 124 | @staticmethod 125 | def forward(ctx, first, second): 126 | ctx.save_for_backward(first, second) 127 | 128 | assert(first.is_contiguous() == True) 129 | assert(second.is_contiguous() == True) 130 | 131 | rbot0 = first.new(first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)).zero_() 132 | rbot1 = first.new(first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)).zero_() 133 | 134 | output = first.new(first.size(0), 81, first.size(2), first.size(3)).zero_() 135 | 136 | if first.is_cuda == True: 137 | class Stream: 138 | ptr = torch.cuda.current_stream().cuda_stream 139 | # end 140 | 141 | n = first.size(2) * first.size(3) 142 | 
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 143 | 'input': first, 144 | 'output': rbot0 145 | }))( 146 | grid=tuple([ int((n + 16 - 1) / 16), first.size(1), first.size(0) ]), 147 | block=tuple([ 16, 1, 1 ]), 148 | args=[ n, first.data_ptr(), rbot0.data_ptr() ], 149 | stream=Stream 150 | ) 151 | 152 | n = second.size(2) * second.size(3) 153 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 154 | 'input': second, 155 | 'output': rbot1 156 | }))( 157 | grid=tuple([ int((n + 16 - 1) / 16), second.size(1), second.size(0) ]), 158 | block=tuple([ 16, 1, 1 ]), 159 | args=[ n, second.data_ptr(), rbot1.data_ptr() ], 160 | stream=Stream 161 | ) 162 | 163 | n = output.size(1) * output.size(2) * output.size(3) 164 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 165 | 'rbot0': rbot0, 166 | 'rbot1': rbot1, 167 | 'top': output 168 | }))( 169 | grid=tuple([ first.size(3), first.size(2), first.size(0) ]), 170 | block=tuple([ 32, 1, 1 ]), 171 | shared_mem=first.size(1) * 4, 172 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ], 173 | stream=Stream 174 | ) 175 | 176 | elif first.is_cuda == False: 177 | raise NotImplementedError() 178 | 179 | # end 180 | 181 | return output 182 | # end 183 | 184 | @staticmethod 185 | def backward(ctx, gradOutput): 186 | first, second = ctx.saved_tensors 187 | 188 | assert(gradOutput.is_contiguous() == True) 189 | 190 | gradFirst = first.new(first.size()).zero_() if ctx.needs_input_grad[0] == True else None 191 | gradSecond = first.new(first.size()).zero_() if ctx.needs_input_grad[1] == True else None 192 | 193 | if first.is_cuda == True: 194 | raise NotImplementedError() 195 | 196 | elif first.is_cuda == False: 197 | raise NotImplementedError() 198 | 199 | # end 200 | 201 | return gradFirst, gradSecond 202 | # end 203 | # end 204 | 205 | class ModuleCorrelation(torch.nn.Module): 206 | def __init__(self): 207 | super(ModuleCorrelation, self).__init__() 208 | # end 209 | 210 | def forward(self, tensorFirst, tensorSecond): 211 | correlation = FunctionCorrelation.apply 212 | return correlation(tensorFirst, tensorSecond) 213 | # end 214 | # end -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Code modified from https://github.com/DmitryUlyanov/deep-image-prior 3 | ###################################################################### 4 | 5 | import torch 6 | import torchvision 7 | import numpy as np 8 | import os 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | ###################################################################### 14 | # Network input 15 | ###################################################################### 16 | 17 | def fill_noise(x, noise_type): 18 | """ 19 | Fill tensor `x` with noise of type `noise_type`. 20 | """ 21 | if noise_type == 'u': 22 | x.uniform_() 23 | elif noise_type == 'n': 24 | x.normal_() 25 | else: 26 | assert False 27 | 28 | 29 | def get_noise(batch_size, input_depth, method, spatial_size, noise_type='u', var=1./10): 30 | """ 31 | Return a pytorch.Tensor of size (`batch_size` x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 32 | initialized in a specific way. 
33 | 34 | Args: 35 | input_depth: number of channels in the tensor 36 | method: `noise` for fillting tensor with noise; `meshgrid` for np.meshgrid 37 | spatial_size: spatial size of the tensor to initialize 38 | noise_type: 'u' for uniform; 'n' for normal 39 | var: a factor, a noise will be multiplicated by. Basically it is standard deviation scaler. 40 | """ 41 | if isinstance(spatial_size, int): 42 | spatial_size = (spatial_size, spatial_size) 43 | if method == 'noise': 44 | shape = [batch_size, input_depth, spatial_size[0], spatial_size[1]] 45 | net_input = torch.zeros(shape) 46 | 47 | fill_noise(net_input, noise_type) 48 | net_input *= var 49 | elif method == 'meshgrid': 50 | assert batch_size == 1 51 | assert input_depth == 2 52 | X, Y = np.meshgrid(np.arange(0, spatial_size[1]) / float(spatial_size[1] - 1), np.arange(0, spatial_size[0]) / float(spatial_size[0] - 1)) 53 | meshgrid = np.concatenate([X[None, :], Y[None, :]]) 54 | net_input= np_to_torch(meshgrid) 55 | else: 56 | assert False 57 | 58 | return net_input 59 | 60 | 61 | ###################################################################### 62 | # Flow related 63 | ###################################################################### 64 | 65 | def warp_torch(x, flo): 66 | """ 67 | Backward warp an image tensor (im2) to im1, according to the optical flow from im1 to im2 68 | 69 | x: [B, C, H, W] (im2) 70 | flo: [B, 2, H, W] flow 71 | """ 72 | B, C, H, W = x.size() 73 | # Mesh grid 74 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 75 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 76 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 77 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 78 | grid = torch.cat((xx, yy), 1).float() 79 | 80 | if x.is_cuda: 81 | grid = grid.cuda() 82 | vgrid_ori = grid + flo 83 | vgrid = torch.zeros_like(vgrid_ori) 84 | 85 | # Scale grid to [-1,1] 86 | vgrid[:, 0, :, :] = 2.0 * vgrid_ori[:, 0, :, :] / max(W - 1, 1) - 1.0 87 | vgrid[:, 1, :, :] = 2.0 * vgrid_ori[:, 1, :, :] / max(H - 1, 1) - 1.0 88 | 89 | vgrid = vgrid.permute(0, 2, 3, 1) 90 | output = torch.nn.functional.grid_sample(x, vgrid) 91 | mask = torch.ones(x.size()) 92 | if x.is_cuda: 93 | mask = mask.cuda() 94 | mask = torch.nn.functional.grid_sample(mask, vgrid) 95 | 96 | mask[mask < 0.999] = 0 97 | mask[mask > 0] = 1 98 | 99 | return output, mask 100 | 101 | 102 | def warp_np(x, flo): 103 | """ 104 | Backward warp an image numpy array (im2) to im1, according to the optical flow from im1 to im2 105 | 106 | x: [B, C, H, W] (im2) 107 | flo: [B, 2, H, W] flow 108 | """ 109 | if x.ndim != 4: 110 | assert(x.ndim == 3) 111 | # Add one dimention for single image 112 | x = x[None, ...] 113 | flo = flo[None, ...] 114 | add_dim = True 115 | else: 116 | add_dim = False 117 | 118 | output, mask = warp_torch(np_to_torch(x), np_to_torch(flo)) 119 | if add_dim: 120 | return output.numpy()[0], mask.numpy()[0] 121 | else: 122 | return output.numpy(), mask.numpy() 123 | 124 | 125 | def check_flow_occlusion(flow_f, flow_b): 126 | """ 127 | Compute occlusion map through forward/backward flow consistency check 128 | """ 129 | def get_occlusion(flow1, flow2): 130 | grid_flow = grid + flow1 131 | grid_flow[0, :, :] = 2.0 * grid_flow[0, :, :] / max(W - 1, 1) - 1.0 132 | grid_flow[1, :, :] = 2.0 * grid_flow[1, :, :] / max(H - 1, 1) - 1.0 133 | grid_flow = grid_flow.permute(1, 2, 0) 134 | flow2_inter = torch.nn.functional.grid_sample(flow2[None, ...], grid_flow[None, ...])[0] 135 | score = torch.exp(- torch.sum((flow1 + flow2_inter) ** 2, dim=0) / 2.) 
136 | occlusion = (score > 0.5) 137 | return occlusion[None, ...].float() 138 | 139 | C, H, W = flow_f.size() 140 | # Mesh grid 141 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 142 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 143 | xx = xx.view(1, H, W) 144 | yy = yy.view(1, H, W) 145 | grid = torch.cat((xx, yy), 0).float() 146 | 147 | occlusion_f = get_occlusion(flow_f, flow_b) 148 | occlusion_b = get_occlusion(flow_b, flow_f) 149 | flow_f = torch.cat((flow_f, occlusion_f), 0) 150 | flow_b = torch.cat((flow_b, occlusion_b), 0) 151 | 152 | return flow_f, flow_b 153 | 154 | 155 | ###################################################################### 156 | # Visualization 157 | ###################################################################### 158 | 159 | def get_image_grid(images_np, nrow=8, padding=2): 160 | """ 161 | Create a grid from a list of images by concatenating them. 162 | """ 163 | images_torch = [np_to_torch(x) for x in images_np] 164 | torch_grid = torchvision.utils.make_grid(images_torch, nrow, padding) 165 | 166 | return torch_grid.numpy() 167 | 168 | 169 | def plot_image_grid(images_np, nrow=8, padding=2, factor=1, interpolation='lanczos'): 170 | """ 171 | Layout images in a grid 172 | 173 | Args: 174 | images_np: list of images, each image is np.array of size 3xHxW of 1xHxW 175 | nrow: how many images will be in one row 176 | factor: size if the plt.figure 177 | interpolation: interpolation used in plt.imshow 178 | """ 179 | n_channels = max(x.shape[0] for x in images_np) 180 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels" 181 | 182 | images_np = [np_cvt_color(x) if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np] 183 | 184 | grid = get_image_grid(images_np, nrow, padding) 185 | 186 | plt.figure(figsize=(len(images_np) + factor, 12 + factor)) 187 | plt.axis('off') 188 | 189 | if images_np[0].shape[0] == 1: 190 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation) 191 | else: 192 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation) 193 | 194 | plt.show() 195 | 196 | return grid 197 | 198 | 199 | ###################################################################### 200 | # Data type transform 201 | ###################################################################### 202 | 203 | def np_cvt_color(img_np): 204 | """ 205 | Convert image from BGR/RGB to RGb/BGR 206 | From B x C x W x H to B x C x W x H 207 | """ 208 | if len(img_np) == 4: 209 | return [img[::-1] for img in img_np] 210 | else: 211 | return img_np[::-1] 212 | 213 | 214 | def np_to_cv2(img_np): 215 | """ 216 | Convert image in numpy.array to cv2 image. 217 | From C x W x H [0..1] to W x H x C [0...255] 218 | """ 219 | return np.clip(img_np.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8) 220 | 221 | 222 | def np_to_torch(img_np): 223 | """ 224 | Convert image in numpy.array to torch.Tensor. 225 | From B x C x W x H [0..1] to B x C x W x H [0..1] 226 | """ 227 | return torch.from_numpy(np.ascontiguousarray(img_np)) 228 | 229 | 230 | def torch_to_np(img_var): 231 | """ 232 | Convert an image in torch.Tensor format to numpy.array. 
233 | From B x C x W x H [0..1] to B x C x W x H [0..1] 234 | """ 235 | return img_var.detach().cpu().numpy() 236 | 237 | 238 | ###################################################################### 239 | # Others 240 | ###################################################################### 241 | 242 | def mkdir(dir): 243 | os.makedirs(dir, exist_ok=True) 244 | os.chmod(dir, 0o777) 245 | 246 | 247 | def build_dir(res_dir, subpath): 248 | mkdir(os.path.join(res_dir, subpath)) 249 | res_type_list = ['stitch', 'full_with_boundary', 'full'] 250 | for res_type in res_type_list: 251 | sub_res_dir = os.path.join(res_dir, subpath, res_type) 252 | mkdir(sub_res_dir) 253 | mkdir(os.path.join(sub_res_dir, 'sequence')) 254 | mkdir(os.path.join(sub_res_dir, 'batch')) 255 | 256 | 257 | def get_model_num_parameters(model): 258 | """ 259 | Return total number of parameters in model 260 | """ 261 | total_num=0 262 | if type(model) == type(dict()): 263 | for key in model: 264 | for p in model[key].parameters(): 265 | total_num+=p.nelement() 266 | else: 267 | for p in model.parameters(): 268 | total_num+=p.nelement() 269 | return total_num -------------------------------------------------------------------------------- /src/models/pwc_net.py: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Code modified from https://github.com/sniklaus/pytorch-pwc 3 | ############################################################ 4 | 5 | import torch 6 | from correlation import correlation 7 | 8 | 9 | class PWC_Net(torch.nn.Module): 10 | def __init__(self): 11 | super(PWC_Net, self).__init__() 12 | 13 | class Extractor(torch.nn.Module): 14 | def __init__(self): 15 | super(Extractor, self).__init__() 16 | 17 | self.moduleOne = torch.nn.Sequential( 18 | torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), 19 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 20 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 21 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 22 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 23 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 24 | ) 25 | 26 | self.moduleTwo = torch.nn.Sequential( 27 | torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 28 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 29 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 30 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 31 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 32 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 33 | ) 34 | 35 | self.moduleThr = torch.nn.Sequential( 36 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 37 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 38 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 39 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 40 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 41 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 42 | ) 43 | 44 | self.moduleFou = torch.nn.Sequential( 45 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), 46 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 47 | torch.nn.Conv2d(in_channels=96, 
out_channels=96, kernel_size=3, stride=1, padding=1), 48 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 49 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 50 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 51 | ) 52 | 53 | self.moduleFiv = torch.nn.Sequential( 54 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), 55 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 56 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 57 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 58 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 59 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 60 | ) 61 | 62 | self.moduleSix = torch.nn.Sequential( 63 | torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), 64 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 65 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 66 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 67 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 68 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 69 | ) 70 | # end 71 | 72 | def forward(self, tensorInput): 73 | tensorOne = self.moduleOne(tensorInput) 74 | tensorTwo = self.moduleTwo(tensorOne) 75 | tensorThr = self.moduleThr(tensorTwo) 76 | tensorFou = self.moduleFou(tensorThr) 77 | tensorFiv = self.moduleFiv(tensorFou) 78 | tensorSix = self.moduleSix(tensorFiv) 79 | 80 | return [ tensorOne, tensorTwo, tensorThr, tensorFou, tensorFiv, tensorSix ] 81 | # end 82 | # end 83 | 84 | class Backward(torch.nn.Module): 85 | def __init__(self): 86 | super(Backward, self).__init__() 87 | # end 88 | 89 | def forward(self, tensorInput, tensorFlow): 90 | if hasattr(self, 'tensorPartial') == False or self.tensorPartial.size(0) != tensorFlow.size(0) or self.tensorPartial.size(2) != tensorFlow.size(2) or self.tensorPartial.size(3) != tensorFlow.size(3): 91 | self.tensorPartial = torch.FloatTensor().resize_(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3)).fill_(1.0).cuda() 92 | # end 93 | 94 | if hasattr(self, 'tensorGrid') == False or self.tensorGrid.size(0) != tensorFlow.size(0) or self.tensorGrid.size(2) != tensorFlow.size(2) or self.tensorGrid.size(3) != tensorFlow.size(3): 95 | torchHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3)) 96 | torchVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), 1, tensorFlow.size(2), tensorFlow.size(3)) 97 | 98 | self.tensorGrid = torch.cat([ torchHorizontal, torchVertical ], 1).cuda() 99 | # end 100 | 101 | tensorInput = torch.cat([ tensorInput, self.tensorPartial ], 1) 102 | tensorFlow = torch.cat([ tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0), tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0) ], 1) 103 | 104 | tensorOutput = torch.nn.functional.grid_sample(input=tensorInput, grid=(self.tensorGrid + tensorFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros') 105 | 106 | tensorMask = tensorOutput[:, -1:, :, :]; tensorMask[tensorMask > 0.999] = 1.0; tensorMask[tensorMask < 1.0] = 0.0 107 | 108 | return tensorOutput[:, :-1, :, :] * tensorMask 109 | # end 110 | # end 111 | 112 | class Decoder(torch.nn.Module): 113 | def 
__init__(self, intLevel): 114 | super(Decoder, self).__init__() 115 | 116 | intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] 117 | intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0] 118 | 119 | if intLevel < 6: self.moduleUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1) 120 | if intLevel < 6: self.moduleUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) 121 | 122 | if intLevel < 6: self.dblBackward = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] 123 | if intLevel < 6: self.moduleBackward = Backward() 124 | 125 | self.moduleCorrelation = correlation.ModuleCorrelation() 126 | self.moduleCorreleaky = torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 127 | 128 | self.moduleOne = torch.nn.Sequential( 129 | torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), 130 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 131 | ) 132 | 133 | self.moduleTwo = torch.nn.Sequential( 134 | torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), 135 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 136 | ) 137 | 138 | self.moduleThr = torch.nn.Sequential( 139 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), 140 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 141 | ) 142 | 143 | self.moduleFou = torch.nn.Sequential( 144 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), 145 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 146 | ) 147 | 148 | self.moduleFiv = torch.nn.Sequential( 149 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), 150 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 151 | ) 152 | 153 | self.moduleSix = torch.nn.Sequential( 154 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) 155 | ) 156 | # end 157 | 158 | def forward(self, tensorFirst, tensorSecond, objectPrevious): 159 | tensorFlow = None 160 | tensorFeat = None 161 | 162 | if objectPrevious is None: 163 | tensorFlow = None 164 | tensorFeat = None 165 | 166 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, tensorSecond)) 167 | 168 | tensorFeat = torch.cat([ tensorVolume ], 1) 169 | 170 | elif objectPrevious is not None: 171 | tensorFlow = self.moduleUpflow(objectPrevious['tensorFlow']) 172 | tensorFeat = self.moduleUpfeat(objectPrevious['tensorFeat']) 173 | 174 | tensorVolume = self.moduleCorreleaky(self.moduleCorrelation(tensorFirst, self.moduleBackward(tensorSecond, tensorFlow * self.dblBackward))) 175 | 176 | tensorFeat = torch.cat([ tensorVolume, tensorFirst, tensorFlow, tensorFeat ], 1) 177 | 178 | # end 179 | 180 | tensorFeat = torch.cat([ self.moduleOne(tensorFeat), tensorFeat ], 1) 181 | tensorFeat = torch.cat([ self.moduleTwo(tensorFeat), tensorFeat ], 1) 182 | tensorFeat = torch.cat([ self.moduleThr(tensorFeat), tensorFeat ], 1) 183 | tensorFeat = torch.cat([ self.moduleFou(tensorFeat), tensorFeat ], 1) 184 | tensorFeat = torch.cat([ self.moduleFiv(tensorFeat), tensorFeat ], 1) 185 | 186 | tensorFlow = self.moduleSix(tensorFeat) 187 | 
188 | return { 189 | 'tensorFlow': tensorFlow, 190 | 'tensorFeat': tensorFeat 191 | } 192 | # end 193 | # end 194 | 195 | class Refiner(torch.nn.Module): 196 | def __init__(self): 197 | super(Refiner, self).__init__() 198 | 199 | self.moduleMain = torch.nn.Sequential( 200 | torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), 201 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 202 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), 203 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 204 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), 205 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 206 | torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), 207 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 208 | torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), 209 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 210 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), 211 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 212 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) 213 | ) 214 | # end 215 | 216 | def forward(self, tensorInput): 217 | return self.moduleMain(tensorInput) 218 | # end 219 | # end 220 | 221 | self.moduleExtractor = Extractor() 222 | 223 | self.moduleTwo = Decoder(2) 224 | self.moduleThr = Decoder(3) 225 | self.moduleFou = Decoder(4) 226 | self.moduleFiv = Decoder(5) 227 | self.moduleSix = Decoder(6) 228 | 229 | self.moduleRefiner = Refiner() 230 | # end 231 | 232 | def forward(self, tensorFirst, tensorSecond): 233 | tensorFirst = self.moduleExtractor(tensorFirst) 234 | tensorSecond = self.moduleExtractor(tensorSecond) 235 | 236 | objectEstimate = self.moduleSix(tensorFirst[-1], tensorSecond[-1], None) 237 | objectEstimate = self.moduleFiv(tensorFirst[-2], tensorSecond[-2], objectEstimate) 238 | objectEstimate = self.moduleFou(tensorFirst[-3], tensorSecond[-3], objectEstimate) 239 | objectEstimate = self.moduleThr(tensorFirst[-4], tensorSecond[-4], objectEstimate) 240 | objectEstimate = self.moduleTwo(tensorFirst[-5], tensorSecond[-5], objectEstimate) 241 | 242 | return objectEstimate['tensorFlow'] + self.moduleRefiner(objectEstimate['tensorFeat']) 243 | # end 244 | # end -------------------------------------------------------------------------------- /src/inpainting_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from scipy import ndimage 5 | import cv2 6 | import random 7 | 8 | from flow_estimator import FlowEstimator 9 | from models.perceptual import LossNetwork 10 | from utils import * 11 | 12 | 13 | class InpaintingDataset(object): 14 | """ 15 | Data loader for the input video 16 | """ 17 | def __init__(self, cfg): 18 | self.cfg = cfg 19 | if not os.path.exists(self.cfg['video_path']): 20 | raise Exception("Input video not found: {}".format(self.cfg['video_path'])) 21 | if not os.path.exists(self.cfg['mask_path']): 22 | raise Exception("Input mask not found: {}".format(self.cfg['mask_path'])) 23 | 24 | cap = cv2.VideoCapture(self.cfg['video_path']) 25 | frame_sum_true = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | self.frame_W = 
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 27 | self.frame_H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 28 | cap.release() 29 | 30 | if not self.cfg['resize'] is None: 31 | self.frame_H, self.frame_W = self.cfg['resize'] 32 | self.frame_sum = self.cfg['frame_sum'] = min(self.cfg['frame_sum'], frame_sum_true) 33 | self.frame_size = self.cfg['frame_size'] = (self.frame_H , self.frame_W) 34 | self.batch_size = self.cfg['batch_size'] 35 | self.batch_idx = 0 36 | self.batch_list_train = None 37 | 38 | if self.cfg['use_perceptual']: 39 | self.netL = LossNetwork(None, self.cfg['perceptual_layers']) 40 | 41 | self.init_frame_mask() 42 | self.init_input() 43 | self.init_flow() 44 | self.init_flow_mask() 45 | self.init_perceptual_mask() 46 | self.init_batch_list() 47 | 48 | 49 | def init_frame_mask(self): 50 | """ 51 | Load input video and mask 52 | """ 53 | self.image_all = [] 54 | self.mask_all = [] 55 | self.contour_all = [] 56 | 57 | cap_video = cv2.VideoCapture(self.cfg['video_path']) 58 | cap_mask = cv2.VideoCapture(self.cfg['mask_path']) 59 | for fid in range(self.frame_sum): 60 | frame, mask = self.load_single_frame(cap_video, cap_mask) 61 | contour, hier = cv2.findContours(mask[0].astype('uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 62 | self.image_all.append(frame) 63 | self.mask_all.append(mask) 64 | self.contour_all.append(contour) 65 | cap_video.release() 66 | cap_mask.release() 67 | self.image_all = np.array(self.image_all) 68 | self.mask_all = np.array(self.mask_all) 69 | 70 | 71 | def init_input(self): 72 | """ 73 | Generate input noise map 74 | """ 75 | input_noise = get_noise(self.frame_sum, self.cfg['input_channel'], 'noise', self.frame_size, var=self.cfg['input_ratio']).float().detach() 76 | input_noise = torch_to_np(input_noise) # N x C x H x W 77 | self.input_noise = input_noise 78 | 79 | 80 | def init_flow(self): 81 | """ 82 | Estimate flow using PWC-Net 83 | """ 84 | self.cfg['flow_value_max'] = self.flow_value_max = None 85 | if self.cfg['use_flow']: 86 | flow_estimator = FlowEstimator() 87 | print('Loading input video and estimating flow...') 88 | self.flow_all = { ft : [] for ft in self.cfg['flow_type']} 89 | self.flow_value_max = {} 90 | 91 | for bs in self.cfg['batch_stride']: 92 | f, b = 'f' + str(bs), 'b' + str(bs) 93 | for fid in range(0, self.frame_sum - bs): 94 | frame_first = np_to_torch(self.image_all[fid]).clone() 95 | frame_second = np_to_torch(self.image_all[fid + bs]).clone() 96 | frame_first = frame_first.cuda() 97 | frame_second = frame_second.cuda() 98 | flow_f = flow_estimator.estimate_flow_pair(frame_first, frame_second).detach().cpu() 99 | flow_b = flow_estimator.estimate_flow_pair(frame_second, frame_first).detach().cpu() 100 | torch.cuda.empty_cache() 101 | 102 | flow_f, flow_b = check_flow_occlusion(flow_f, flow_b) 103 | self.flow_all[f].append(flow_f.numpy()) 104 | self.flow_all[b].append(flow_b.numpy()) 105 | 106 | for bs in self.cfg['batch_stride']: 107 | f, b = 'f' + str(bs), 'b' + str(bs) 108 | self.flow_all[f] = np.array(self.flow_all[f] + [self.flow_all[f][0]] * bs, dtype=np.float32) 109 | self.flow_all[b] = np.array([self.flow_all[b][0]] * bs + self.flow_all[b], dtype=np.float32) 110 | self.flow_value_max[bs] = max(np.abs(self.flow_all[f]).max().astype('float'), \ 111 | np.abs(self.flow_all[b]).max().astype('float')) 112 | self.cfg['flow_value_max'] = self.flow_value_max 113 | 114 | 115 | def init_flow_mask(self): 116 | """ 117 | Pre-compute warped mask and intersection of warped mask with original mask 118 | """ 119 | if 
self.cfg['use_flow']: 120 | self.mask_warp_all = { ft : [] for ft in self.cfg['flow_type']} 121 | self.mask_flow_all = { ft : [] for ft in self.cfg['flow_type']} 122 | for bs in self.cfg['batch_stride']: 123 | f, b = 'f' + str(bs), 'b' + str(bs) 124 | # forward 125 | mask_warpf, _ = warp_np(self.mask_all[bs:], self.flow_all[f][:-bs][:, :2, ...]) 126 | mask_warpf = (mask_warpf > 0).astype(np.float32) 127 | mask_flowf = 1. - (1. - mask_warpf) * (1. - self.mask_all[:-bs]) 128 | self.mask_warp_all[f] = np.concatenate((mask_warpf, self.mask_all[-bs:]), 0) 129 | self.mask_flow_all[f] = np.concatenate((mask_flowf, self.mask_all[-bs:]), 0) 130 | # backward 131 | mask_warpb, _ = warp_np(self.mask_all[:-bs], self.flow_all[b][bs:][:, :2, ...]) 132 | mask_warpb = (mask_warpb > 0).astype(np.float32) 133 | mask_flowb = 1. - (1. - mask_warpb) * (1. - self.mask_all[bs:]) 134 | self.mask_warp_all[b] = np.concatenate((self.mask_all[:bs], mask_warpb), 0) 135 | self.mask_flow_all[b] = np.concatenate((self.mask_all[:bs], mask_flowb), 0) 136 | 137 | 138 | def init_perceptual_mask(self): 139 | """ 140 | Pre-compute shrunken masks for the perceptual loss 141 | """ 142 | if self.cfg['use_perceptual']: 143 | self.mask_per_all = [] 144 | mask_per = self.netL(np_to_torch(self.mask_all)) 145 | for i, mask in enumerate(mask_per): 146 | self.mask_per_all.append((mask.detach().numpy() > 0).astype(np.float32)) 147 | 148 | 149 | def init_batch_list(self): 150 | """ 151 | List all the possible batch permutations 152 | """ 153 | if self.cfg['use_flow']: 154 | self.batch_list = [] 155 | for flow_type in self.cfg['flow_type']: 156 | batch_stride = int(flow_type[1]) 157 | for batch_idx in range(0, self.frame_sum - (self.batch_size - 1) * batch_stride, self.cfg['traverse_step']): 158 | self.batch_list.append((batch_idx, batch_stride, [flow_type])) 159 | if self.cfg['batch_mode'] == 'random': 160 | random.shuffle(self.batch_list) 161 | else: 162 | for bs in self.cfg['batch_stride']: 163 | self.batch_list = [(i, bs, []) for i in range(self.frame_sum - self.batch_size + 1)] 164 | if self.cfg['batch_mode'] == 'random': 165 | median = self.batch_list[len(self.batch_list) // 2] 166 | random.shuffle(self.batch_list) 167 | self.batch_list.remove(median) 168 | self.batch_list.append(median) 169 | 170 | 171 | def set_mode(self, mode): 172 | if mode == 'infer': 173 | self.batch_list_train = self.batch_list 174 | self.batch_list = [(i, 1, ['f1']) for i in range(self.frame_sum - self.batch_size + 1)] 175 | elif mode == 'train': 176 | if not self.batch_list_train is None: 177 | self.batch_list = self.batch_list_train 178 | else: 179 | self.init_batch_list() 180 | 181 | 182 | def next_batch(self): 183 | if len(self.batch_list) == 0: 184 | self.init_batch_list() 185 | return None 186 | else: 187 | (batch_idx, batch_stride, flow_type) = self.batch_list[0] 188 | self.batch_list = self.batch_list[1:] 189 | return self.get_batch_data(batch_idx, batch_stride, flow_type) 190 | 191 | 192 | def get_batch_data(self, batch_idx=0, batch_stride=1, flow_type=[]): 193 | """ 194 | Collect batch data for the given batch 195 | """ 196 | cur_batch = range(batch_idx, batch_idx + self.batch_size*batch_stride, batch_stride) 197 | batch_data = {} 198 | input_batch, img_batch, mask_batch, contour_batch = [], [], [], [] 199 | if self.cfg['use_flow']: 200 | flow_batch = { ft : [] for ft in flow_type} 201 | mask_flow_batch = { ft : [] for ft in flow_type} 202 | mask_warp_batch = { ft : [] for ft in flow_type} 203 | if self.cfg['use_perceptual']: 204 | 
mask_per_batch = [ [] for _ in self.cfg['perceptual_layers']] 205 | 206 | for i, fid in enumerate(cur_batch): 207 | input_batch.append(self.input_noise[fid]) 208 | img_batch.append(self.image_all[fid]) 209 | mask_batch.append(self.mask_all[fid]) 210 | contour_batch.append(self.contour_all[fid]) 211 | if self.cfg['use_flow']: 212 | for ft in flow_type: 213 | flow_batch[ft].append(self.flow_all[ft][fid]) 214 | mask_flow_batch[ft].append(self.mask_flow_all[ft][fid]) 215 | mask_warp_batch[ft].append(self.mask_warp_all[ft][fid]) 216 | if self.cfg['use_perceptual']: 217 | for l in range(len(self.cfg['perceptual_layers'])): 218 | mask_per_batch[l].append(self.mask_per_all[l][fid]) 219 | 220 | 221 | if self.cfg['use_flow']: 222 | for ft in flow_type: 223 | idx1, idx2 = (0, -1) if 'f' in ft else (1, self.batch_size) 224 | flow_batch[ft] = np.array(flow_batch[ft][idx1:idx2]) 225 | mask_flow_batch[ft] = np.array(mask_flow_batch[ft][idx1:idx2]) 226 | mask_warp_batch[ft] = np.array(mask_warp_batch[ft][idx1:idx2]) 227 | if self.cfg['use_perceptual']: 228 | for l in range(len(self.cfg['perceptual_layers'])): 229 | mask_per_batch[l] = np.array(mask_per_batch[l]) 230 | 231 | batch_data['cur_batch'] = cur_batch 232 | batch_data['batch_idx'] = batch_idx 233 | batch_data['batch_stride'] = batch_stride 234 | batch_data['input_batch'] = np.array(input_batch) 235 | batch_data['img_batch'] = np.array(img_batch) 236 | batch_data['mask_batch'] = np.array(mask_batch) 237 | batch_data['contour_batch'] = contour_batch 238 | if self.cfg['use_flow']: 239 | batch_data['flow_type'] = flow_type 240 | batch_data['flow_batch'] = flow_batch 241 | batch_data['mask_flow_batch'] = mask_flow_batch 242 | batch_data['mask_warp_batch'] = mask_warp_batch 243 | batch_data['flow_value_max'] = self.flow_value_max[batch_stride] 244 | if self.cfg['use_perceptual']: 245 | batch_data['mask_per_batch'] = mask_per_batch 246 | 247 | return batch_data 248 | 249 | 250 | def get_median_batch(self): 251 | return self.get_batch_data(int((self.cfg['frame_sum']) // 2), 1, ['f1']) 252 | 253 | 254 | def get_all_data(self): 255 | """ 256 | Return a batch containing all the frames 257 | """ 258 | batch_data = {} 259 | batch_data['input_batch'] = np.array(self.input_noise[:self.frame_sum]) 260 | batch_data['img_batch'] = self.image_all 261 | batch_data['mask_batch'] = self.mask_all 262 | batch_data['contour_batch'] = self.contour_all 263 | if self.cfg['use_perceptual']: 264 | batch_data['mask_per_batch'] = self.mask_per_all 265 | if self.cfg['use_flow']: 266 | batch_data['flow_type'] = self.cfg['flow_type'] 267 | batch_data['flow_batch'] = self.flow_all 268 | batch_data['mask_flow_batch'] = self.mask_flow_all 269 | batch_data['mask_warp_batch'] = self.mask_warp_all 270 | return batch_data 271 | 272 | 273 | def load_single_frame(self, cap_video, cap_mask): 274 | gt = self.load_image(cap_video, False, self.frame_size) 275 | mask = self.load_image(cap_mask, True, self.frame_size) 276 | if self.cfg['dilation_iter'] > 0: 277 | mask = ndimage.binary_dilation(mask > 0, iterations=self.cfg['dilation_iter']).astype(np.float32) 278 | return gt, mask 279 | 280 | 281 | def load_image(self, cap, is_mask, resize=None): 282 | _, img = cap.read() 283 | if not resize is None: 284 | img = self.crop_and_resize(img, resize) 285 | if is_mask: 286 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None] 287 | img = (img > 127) * 255 288 | img = img.astype('uint8') 289 | img_convert = img.transpose(2, 0, 1) 290 | return img_convert.astype(np.float32) / 255 291 | 292 | 293 | def 
crop_and_resize(self, img, resize): 294 | """ 295 | Crop and resize img, keeping the aspect ratio unchanged 296 | """ 297 | h, w = img.shape[:2] 298 | source = 1. * h / w 299 | target = 1. * resize[0] / resize[1] 300 | if source > target: 301 | margin = int((h - w * target) // 2) 302 | img = img[margin:h-margin] 303 | elif source < target: 304 | margin = int((w - h / target) // 2) 305 | img = img[:, margin:w-margin] 306 | img = cv2.resize(img, (resize[1], resize[0]), interpolation=self.cfg['interpolation']) 307 | return img -------------------------------------------------------------------------------- /src/inpainting_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | import time 5 | import os 6 | import cv2 7 | import logging 8 | 9 | from inpainting_dataset import InpaintingDataset 10 | from models.encoder_decoder_2d import EncoderDecoder2D 11 | from models.encoder_decoder_3d import EncoderDecoder3D 12 | from models.perceptual import LossNetwork 13 | from utils import * 14 | 15 | 16 | class InpaintingTest(object): 17 | """ 18 | Internal learning framework 19 | """ 20 | def __init__(self, cfg): 21 | self.cfg = cfg 22 | self.log = {} 23 | self.log['cfg'] = self.cfg.copy() 24 | for loss_name in self.cfg['loss_weight']: 25 | self.log['loss_' + loss_name + '_train'] = [ [] for _ in range(self.cfg['num_pass'])] 26 | self.log['loss_' + loss_name + '_infer'] = [ [] for _ in range(self.cfg['num_pass'])] 27 | self.data_loader = None 28 | self.netG = None 29 | self.optimizer_G = None 30 | 31 | # Pre-process cfg 32 | self.cfg['use_perceptual'] = self.cfg['loss_weight']['perceptual'] > 0 33 | self.cfg['use_flow'] = self.cfg['loss_weight']['recon_flow'] > 0 34 | 35 | if self.cfg['use_flow']: 36 | self.cfg['flow_type'] = [] 37 | for bs in self.cfg['batch_stride']: 38 | self.cfg['flow_type'] += ['f' + str(bs), 'b' + str(bs)] 39 | self.cfg['flow_channel_map'] = {} 40 | channel_idx = self.cfg['output_channel_img'] 41 | for ft in self.cfg['flow_type']: 42 | self.cfg['flow_channel_map'][ft] = (channel_idx, channel_idx+2) 43 | channel_idx += 2 44 | else: 45 | self.cfg['flow_type'] = None 46 | 47 | # Build result folder 48 | res_dir = self.cfg['res_dir'] 49 | if not res_dir is None: 50 | res_dir = os.path.join(res_dir, os.path.basename(self.cfg['video_path']).split('.')[0]) 51 | self.cfg['res_dir'] = res_dir 52 | if os.path.exists(res_dir): 53 | print("Warning: Video result folder already exists!") 54 | mkdir(res_dir) 55 | mkdir(os.path.join(res_dir, 'model')) 56 | for pass_idx in range(self.cfg['num_pass']): 57 | if (pass_idx + 1) % self.cfg['save_every_pass'] == 0: 58 | res_dir = os.path.join(self.cfg['res_dir'], '{:03}'.format(pass_idx + 1)) 59 | mkdir(res_dir) 60 | iter = self.cfg['save_every_iter'] 61 | while iter <= self.cfg['num_iter']: 62 | build_dir(res_dir, '{:05}'.format(iter)) 63 | iter += self.cfg['save_every_iter'] 64 | build_dir(res_dir, 'final') 65 | if self.cfg['train_mode'] == 'DIP': 66 | build_dir(res_dir, 'best_nonhole') 67 | 68 | # Setup logging 69 | logging.basicConfig(level=self.cfg['logging_level'], format='%(message)s') 70 | self.logger = logging.getLogger(__name__) 71 | self.log_handler = None 72 | if not self.cfg['res_dir'] is None: 73 | self.log_handler = logging.FileHandler(os.path.join(self.cfg['res_dir'], 'log.txt')) 74 | # self.log_handler.setLevel(logging.DEBUG) 75 | # formatter = logging.Formatter('%(message)s') 76 | # self.log_handler.setFormatter(formatter) 77 | 
self.logger.addHandler(self.log_handler) 78 | self.logger.info('========================================== Config ==========================================') 79 | for key in sorted(self.cfg): 80 | self.logger.info('[{}]: {}'.format(key, str(self.cfg[key]))) 81 | 82 | 83 | def create_data_loader(self): 84 | self.logger.info('========================================== Dataset ==========================================') 85 | 86 | self.data_loader = InpaintingDataset(self.cfg) 87 | 88 | self.logger.info("[Video name]: {}".format(os.path.basename(self.cfg['video_path']))) 89 | self.logger.info("[Mask name]: {}".format(os.path.basename(self.cfg['mask_path']))) 90 | self.logger.info("[Frame sum]: {}".format(self.cfg['frame_sum'])) 91 | self.logger.info("[Batch size]: {}".format(self.cfg['batch_size'])) 92 | self.logger.info("[Frame size]: {}".format(self.cfg['frame_size'])) 93 | self.logger.info("[Flow type]: {}".format(self.cfg['flow_type'])) 94 | self.logger.info("[Flow_value_max]: {}".format(self.cfg['flow_value_max'])) 95 | 96 | self.log['input_noise'] = self.data_loader.input_noise 97 | 98 | 99 | def visualize_single_batch(self): 100 | """ 101 | Randomly visualize one batch data 102 | """ 103 | batch_data = self.data_loader.next_batch() 104 | input_batch = batch_data['input_batch'] 105 | img_batch = batch_data['img_batch'] 106 | mask_batch = batch_data['mask_batch'] 107 | nonhole_batch = img_batch * (1 - mask_batch) 108 | 109 | if self.cfg['batch_size'] == 1: 110 | plot_image_grid(np.concatenate((img_batch, nonhole_batch), 0), 2, padding=3, factor=10) 111 | else: 112 | plot_image_grid(np.concatenate((img_batch, nonhole_batch), 0), self.cfg['batch_size'], padding=3, factor=15) 113 | 114 | 115 | def create_model(self): 116 | if not self.cfg['use_skip']: 117 | num_channels_skip = [0] * self.cfg['net_depth'] 118 | 119 | input_channel = self.cfg['input_channel'] 120 | if self.cfg['use_flow']: 121 | output_channel_flow = len(self.cfg['flow_type']) * 2 122 | output_channel = self.cfg['output_channel_img'] + output_channel_flow 123 | else: 124 | output_channel = self.cfg['output_channel_img'] 125 | 126 | if self.cfg['net_type_G'] == '2d': 127 | self.netG = EncoderDecoder2D(input_channel, output_channel, 128 | self.cfg['num_channels_down'][:self.cfg['net_depth']], 129 | self.cfg['num_channels_up'][:self.cfg['net_depth']], 130 | self.cfg['num_channels_skip'][:self.cfg['net_depth']], 131 | self.cfg['filter_size_down'], self.cfg['filter_size_up'], self.cfg['filter_size_skip'], 132 | upsample_mode='nearest', downsample_mode='stride', 133 | need1x1_up=True, need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU') 134 | elif self.cfg['net_type_G'] == '3d': 135 | self.netG = EncoderDecoder3D(input_channel, output_channel, 136 | self.cfg['num_channels_down'][:self.cfg['net_depth']], 137 | self.cfg['num_channels_up'][:self.cfg['net_depth']], 138 | self.cfg['num_channels_skip'][:self.cfg['net_depth']], 139 | self.cfg['filter_size_down'], self.cfg['filter_size_up'], self.cfg['filter_size_skip'], 140 | upsample_mode='nearest', downsample_mode='stride', 141 | need1x1_up=True, need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU') 142 | else: 143 | raise Exception("Network not defined!") 144 | self.netG = self.netG.type(self.cfg['dtype']) 145 | 146 | if self.cfg['use_perceptual']: 147 | if self.cfg['net_type_L'] == 'VGG16': 148 | vgg_model = torchvision.models.vgg16(pretrained=True).type(self.cfg['dtype']) 149 | vgg_model.eval() 150 | vgg_model.requires_grad = False 151 | 
self.netL = LossNetwork(vgg_model) 152 | 153 | self.logger.info('========================================== Network ==========================================') 154 | self.logger.info(self.netG) 155 | self.logger.info("Total number of parameters: {}".format(get_model_num_parameters(self.netG))) 156 | 157 | 158 | def create_loss_function(self): 159 | if self.cfg['loss_recon'] == 'L1': 160 | self.criterion_recon = torch.nn.L1Loss().type(self.cfg['dtype']) 161 | elif self.cfg['loss_recon'] == 'L2': 162 | self.criterion_recon = torch.nn.MSELoss().type(self.cfg['dtype']) 163 | self.criterion_MSE = torch.nn.MSELoss().type(self.cfg['dtype']) 164 | 165 | 166 | def create_optimizer(self): 167 | if self.cfg['optimizer_G'] == 'Adam': 168 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.cfg['LR']) 169 | elif self.cfg['optimizer_G'] == 'SGD': 170 | self.optimizer_G = torch.optim.SGD(self.netG.parameters(), lr=self.cfg['LR']) 171 | 172 | 173 | def prepare_input(self, input_batch): 174 | """ 175 | Prepare input noise map based on network type (2D/3D) 176 | """ 177 | input_tensor = np_to_torch(input_batch).type(self.cfg['dtype']) # N x C x H x W 178 | if self.cfg['net_type_G'] == '2d': 179 | return input_tensor # N x C x H x W 180 | elif self.cfg['net_type_G'] == '3d': 181 | return input_tensor.transpose(0, 1).unsqueeze(0) # 1 x C x N x H x W 182 | 183 | 184 | def train(self): 185 | """ 186 | Main function for internal learning 187 | """ 188 | self.start_time = time.time() 189 | self.data_loader.init_batch_list() 190 | 191 | self.logger.info('========================================== Training ==========================================') 192 | if self.cfg['train_mode'] == 'DIP-Vid-Flow': 193 | self.train_with_flow() 194 | else: 195 | self.train_baseline() 196 | 197 | # Save log 198 | if not self.cfg['res_dir'] is None: 199 | torch.save(self.log, os.path.join(self.cfg['res_dir'], 'log.tar')) 200 | 201 | # Report training time 202 | running_time = time.time() - self.start_time 203 | self.logger.info("Training finished! 
Running time: {}s".format(running_time)) 204 | 205 | # Release log file 206 | self.logger.removeHandler(self.log_handler) 207 | 208 | 209 | def train_baseline(self): 210 | """ 211 | Training procedure for all baselines 212 | """ 213 | for pass_idx in range(self.cfg['num_pass']): 214 | while True: 215 | # Get batch data 216 | batch_data = self.data_loader.next_batch() 217 | if batch_data is None: 218 | break 219 | batch_idx = batch_data['batch_idx'] 220 | 221 | self.logger.info('Pass {}, Batch {}'.format(pass_idx + 1, batch_idx)) 222 | 223 | # Start train 224 | self.train_batch(pass_idx, batch_idx, batch_data) 225 | 226 | if self.cfg['train_mode'] == 'DIP': 227 | self.create_model() 228 | self.create_optimizer() 229 | 230 | # Start infer 231 | if (pass_idx + 1) % self.cfg['save_every_pass'] == 0 and self.cfg['train_mode'] != 'DIP': 232 | inferred_result = self.infer(pass_idx) 233 | 234 | self.logger.info("Saving latest model at pass {}".format(pass_idx + 1)) 235 | if not self.cfg['res_dir'] is None: 236 | # Save model and log 237 | checkpoint_G = os.path.join(self.cfg['res_dir'], 'model', '{:03}.tar'.format(pass_idx + 1)) 238 | torch.save(self.netG.state_dict(), checkpoint_G) 239 | torch.save(self.log, os.path.join(self.cfg['res_dir'], 'log.tar')) 240 | 241 | self.logger.info("Running time: {}s".format(time.time() - self.start_time)) 242 | 243 | 244 | 245 | 246 | def train_with_flow(self): 247 | """ 248 | Training procedure for DIP-Vid-Flow 249 | """ 250 | pass_idx = 0 251 | batch_count = 0 252 | while pass_idx < self.cfg['num_pass']: 253 | # Get batch data 254 | batch_data = self.data_loader.next_batch() 255 | if batch_data is None: 256 | continue 257 | batch_idx = batch_data['batch_idx'] 258 | 259 | self.logger.info('Pass: {}, Batch: {}, Flow: {}'.format(pass_idx, batch_idx, str(batch_data['flow_type']))) 260 | 261 | # Start train 262 | self.train_batch(pass_idx, batch_idx, batch_data) 263 | batch_count += 1 264 | 265 | # Start infer 266 | if batch_count % self.cfg['save_every_batch'] == 0: 267 | batch_data = self.data_loader.get_median_batch() 268 | batch_idx = batch_data['batch_idx'] 269 | self.logger.info('Train the median batch before inferring\nPass: {}, Batch: {}, Flow: {}'.format(pass_idx, batch_idx, str(batch_data['flow_type']))) 270 | self.train_batch(pass_idx, batch_idx, batch_data) 271 | 272 | self.infer(pass_idx) 273 | 274 | # Save model and log 275 | self.logger.info("Running time: {}s".format(time.time() - self.start_time)) 276 | if not self.cfg['res_dir'] is None: 277 | checkpoint_G = os.path.join(self.cfg['res_dir'], 'model', '{:03}.tar'.format(pass_idx + 1)) 278 | torch.save(self.netG.state_dict(), checkpoint_G) 279 | torch.save(self.log, os.path.join(self.cfg['res_dir'], 'log.tar')) 280 | pass_idx += 1 281 | 282 | 283 | def infer(self, pass_idx): 284 | """ 285 | Run inference with the trained model to collect all inpainted frames 286 | """ 287 | self.logger.info('Pass {} infer start...'.format(pass_idx)) 288 | self.data_loader.set_mode('infer') 289 | 290 | inferred_result = np.empty((self.cfg['frame_sum'], self.cfg['output_channel_img'], self.cfg['frame_size'][0], self.cfg['frame_size'][1]), dtype=np.float32) 291 | while True: 292 | batch_data = self.data_loader.next_batch() 293 | if batch_data is None: 294 | break 295 | batch_idx = batch_data['batch_idx'] 296 | self.infer_batch(pass_idx, batch_idx, batch_data, inferred_result) 297 | 298 | self.data_loader.set_mode('train') 299 | return inferred_result 300 | 301 | 302 | def train_batch(self, pass_idx, batch_idx, 
batch_data): 303 | """ 304 | Train the given batch for `num_iter` iterations 305 | """ 306 | for loss_name in self.cfg['loss_weight']: 307 | self.log['loss_' + loss_name + '_train'][pass_idx].append([]) 308 | best_loss_recon_image = 1e9 309 | best_iter = 0 310 | best_nonhole_batch = None 311 | batch_data['pass_idx'] = pass_idx 312 | batch_data['train'] = True 313 | 314 | # Optimize 315 | for iter_idx in range(self.cfg['num_iter']): 316 | if self.cfg['param_noise']: 317 | for n in [x for x in self.netG.parameters() if len(x.size()) == 4]: 318 | n = n + n.detach().clone().normal_() * n.std() / 50 319 | 320 | # Forward 321 | loss = self.optimize_params(batch_data) 322 | 323 | # Update 324 | for loss_name in self.cfg['loss_weight']: 325 | self.log['loss_' + loss_name + '_train'][pass_idx][-1].append(loss[loss_name].item()) 326 | if loss['recon_image'].item() < best_loss_recon_image: 327 | best_loss_recon_image = loss['recon_image'].item() 328 | best_nonhole_batch = batch_data['out_img_batch'] 329 | best_iter = iter_idx 330 | 331 | log_str = 'Iteration {:05}'.format(iter_idx) 332 | for loss_name in sorted(self.cfg['loss_weight']): 333 | if self.cfg['loss_weight'][loss_name] != 0: 334 | log_str += ' ' + loss_name + ' {:f}'.format(loss[loss_name].item()) 335 | self.logger.info(log_str) 336 | 337 | # Plot and save 338 | if (pass_idx + 1) % self.cfg['save_every_pass'] == 0 and (iter_idx + 1) % self.cfg['save_every_iter'] == 0: 339 | self.plot_and_save(batch_idx, batch_data, '{:03}/{:05}'.format(pass_idx + 1, iter_idx + 1)) 340 | 341 | log_str = 'Best at iteration {:05}, recon_image loss {:f}'.format(best_iter, best_loss_recon_image) 342 | self.logger.info(log_str) 343 | 344 | if self.cfg['train_mode'] == 'DIP': 345 | # For DIP, save the result with lowest loss on nonhole region as final result 346 | batch_data['out_img_batch'] = best_nonhole_batch 347 | self.plot_and_save(batch_idx, batch_data, '001/best_nonhole') 348 | 349 | 350 | def infer_batch(self, pass_idx, batch_idx, batch_data, inferred_result): 351 | """ 352 | Run inference for the given batch 353 | """ 354 | # Forward pass 355 | batch_data['pass_idx'] = pass_idx 356 | batch_data['train'] = False 357 | loss = self.optimize_params(batch_data) 358 | 359 | # Update 360 | for loss_name in self.cfg['loss_weight']: 361 | self.log['loss_' + loss_name + '_infer'][pass_idx].append(loss[loss_name].item()) 362 | 363 | # Save inferred result 364 | for i, img in enumerate(batch_data['out_img_batch'][0]): 365 | if batch_idx < batch_data['batch_stride'] or i >= self.cfg['batch_size'] // 2: 366 | inferred_result[batch_idx + i * batch_data['batch_stride']] = img 367 | 368 | log_str = 'Batch {:05}'.format(batch_idx) 369 | for loss_name in sorted(self.cfg['loss_weight']): 370 | if self.cfg['loss_weight'][loss_name] != 0: 371 | log_str += ' ' + loss_name + ' {:f}'.format(loss[loss_name].item()) 372 | self.logger.info(log_str) 373 | 374 | # Plot and save 375 | if (self.cfg['plot'] or self.cfg['save']) and (pass_idx + 1) % self.cfg['save_every_pass'] == 0: 376 | self.plot_and_save(batch_idx, batch_data, '{:03}/final'.format(pass_idx + 1)) 377 | 378 | 379 | def optimize_params(self, batch_data): 380 | """ 381 | Calculate loss and back-propagate the loss 382 | """ 383 | pass_idx = batch_data['pass_idx'] 384 | batch_idx = batch_data['batch_idx'] 385 | net_input = self.prepare_input(batch_data['input_batch']) 386 | img_tensor = np_to_torch(batch_data['img_batch']).type(self.cfg['dtype']) 387 | mask_tensor = 
np_to_torch(batch_data['mask_batch']).type(self.cfg['dtype']) 388 | 389 | if self.cfg['use_flow']: 390 | flow_tensor, mask_flow_tensor, mask_warp_tensor = {}, {}, {} 391 | for ft in batch_data['flow_type']: 392 | flow_tensor[ft] = np_to_torch(batch_data['flow_batch'][ft]).type(self.cfg['dtype']) 393 | mask_flow_tensor[ft] = np_to_torch(batch_data['mask_flow_batch'][ft]).type(self.cfg['dtype']) 394 | mask_warp_tensor[ft] = np_to_torch(batch_data['mask_warp_batch'][ft]).type(self.cfg['dtype']) 395 | flow_value_max = batch_data['flow_value_max'] 396 | 397 | if self.cfg['use_perceptual']: 398 | mask_per_tensor = [] 399 | for mask in batch_data['mask_per_batch']: 400 | mask = np_to_torch(mask).type(self.cfg['dtype']) 401 | mask_per_tensor.append(mask) 402 | 403 | # Forward 404 | net_output = self.netG(net_input) 405 | torch.cuda.empty_cache() 406 | 407 | # Collect image/flow from network output 408 | if self.cfg['net_type_G'] == '2d': 409 | out_img_tensor = net_output[:, :self.cfg['output_channel_img'], ...] # N x 3 x H x W 410 | if self.cfg['use_flow']: 411 | out_flow_tensor = {} 412 | for ft in batch_data['flow_type']: 413 | channel_idx1, channel_idx2 = self.cfg['flow_channel_map'][ft] 414 | flow_idx1, flow_idx2 = (0, -1) if 'f' in ft else (1, self.cfg['batch_size']) 415 | out_flow_tensor[ft] = net_output[flow_idx1:flow_idx2, channel_idx1:channel_idx2, ...] 416 | 417 | elif self.cfg['net_type_G'] == '3d': 418 | out_img_tensor = net_output.squeeze(0)[:self.cfg['output_channel_img']].transpose(0, 1) # N x 3 x H x W 419 | if self.cfg['use_flow']: 420 | out_flow_tensor = {} 421 | for ft in batch_data['flow_type']: 422 | channel_idx1, channel_idx2 = self.cfg['flow_channel_map'][ft] 423 | flow_idx1, flow_idx2 = (0, -1) if 'f' in ft else (1, self.cfg['batch_size']) 424 | out_flow_tensor[ft] = net_output.squeeze(0) \ 425 | [channel_idx1:channel_idx2, flow_idx1:flow_idx2, ...].transpose(0, 1) # N-1 x 2 x H x W 426 | 427 | # Compute loss 428 | loss = {} 429 | for loss_name in self.cfg['loss_weight']: 430 | loss[loss_name] = torch.zeros([]).float().cuda().detach() 431 | 432 | self.optimizer_G.zero_grad() 433 | 434 | # Image reconstruction loss 435 | if self.cfg['loss_weight']['recon_image'] != 0: 436 | loss['recon_image'] += self.criterion_recon( 437 | out_img_tensor * (1. - mask_tensor), \ 438 | img_tensor * (1. - mask_tensor)) 439 | 440 | # Flow reconstruction loss 441 | if self.cfg['loss_weight']['recon_flow'] != 0: 442 | for ft in batch_data['flow_type']: 443 | mask_flow_inv = (1. - mask_flow_tensor[ft]) * flow_tensor[ft][:, 2:3, ...] 444 | loss['recon_flow'] += self.criterion_recon(out_flow_tensor[ft] * mask_flow_inv, \ 445 | flow_tensor[ft][:, :2, ...] 
* mask_flow_inv / flow_value_max) 446 | 447 | # Consistency loss 448 | if self.cfg['loss_weight']['consistency'] != 0: 449 | warped_img, warped_diff = {}, {} 450 | for ft in batch_data['flow_type']: 451 | idx1, idx2 = (0, -1) if 'f' in ft else (1, self.cfg['batch_size']) 452 | idx_inv1, idx_inv2 = (1, self.cfg['batch_size']) if 'f' in ft else (0, -1) 453 | out_img = out_img_tensor[idx_inv1:idx_inv2] 454 | out_flow = out_flow_tensor[ft] 455 | warped_img[ft], flowmask = warp_torch(out_img, out_flow * flow_value_max) 456 | mask = mask_flow_tensor[ft] * flowmask.detach() 457 | loss['consistency'] += self.criterion_recon( 458 | warped_img[ft] * mask, 459 | out_img_tensor[idx1:idx2].detach() * mask) 460 | torch.cuda.empty_cache() 461 | 462 | # Perceptual loss 463 | if self.cfg['use_perceptual']: 464 | feature_src = self.netL(out_img_tensor) 465 | feature_dst = self.netL(img_tensor) 466 | for i, mask in enumerate(mask_per_tensor): 467 | loss['perceptual'] += self.criterion_MSE( 468 | feature_src[i] * (1. - mask), feature_dst[i].detach() * (1. - mask)) 469 | torch.cuda.empty_cache() 470 | 471 | # Back-propagation 472 | running_loss = 0 473 | for loss_name, weight in self.cfg['loss_weight'].items(): 474 | if weight != 0: 475 | running_loss = running_loss + weight * loss[loss_name] 476 | 477 | if batch_data['train']: 478 | running_loss.backward() 479 | self.optimizer_G.step() 480 | torch.cuda.empty_cache() 481 | 482 | # Save generated image/flow 483 | batch_data['out_img_batch'] = torch_to_np(out_img_tensor) 484 | if self.cfg['use_flow']: 485 | out_flow_batch = {} 486 | for ft in batch_data['flow_type']: 487 | out_flow_batch[ft] = torch_to_np(out_flow_tensor[ft] * flow_value_max) 488 | batch_data['out_flow_batch'] = out_flow_batch 489 | return loss 490 | 491 | 492 | def plot_and_save(self, batch_idx, batch_data, subpath): 493 | """ 494 | Plot/save intermediate results 495 | """ 496 | def save(imgs, subpath, subsubpath): 497 | res_dir = os.path.join(self.cfg['res_dir'], subpath, subsubpath) 498 | for i, img in enumerate(imgs): 499 | if img is None: 500 | continue 501 | fid = batch_idx + i * batch_data['batch_stride'] 502 | batch_path = os.path.join(res_dir, 'batch', '{:03}_{:03}.png'.format(batch_idx, fid)) 503 | sequence_path = os.path.join(res_dir, 'sequence', '{:03}.png'.format(fid)) 504 | if self.cfg['save_batch']: 505 | cv2.imwrite(batch_path, np_to_cv2(img)) 506 | if batch_idx < batch_data['batch_stride'] or i >= self.cfg['batch_size'] // 2: 507 | cv2.imwrite(sequence_path, np_to_cv2(img)) 508 | 509 | # Load batch data 510 | input_batch = batch_data['input_batch'] / self.cfg['input_ratio'] 511 | img_batch = batch_data['img_batch'] 512 | mask_batch = batch_data['mask_batch'] 513 | contour_batch = batch_data['contour_batch'] 514 | out_img_batch = batch_data['out_img_batch'].copy() 515 | stitch_batch = img_batch * (1 - mask_batch) + out_img_batch * mask_batch 516 | 517 | # Draw mask boundary 518 | for i in range(self.cfg['batch_size']): 519 | for con in contour_batch[i]: 520 | for pt in con: 521 | x, y = pt[0] 522 | out_img_batch[i][:, y, x] = [0, 0, 1] 523 | 524 | # Plot in jupyter 525 | if self.cfg['plot']: 526 | if self.cfg['batch_size'] == 1: 527 | plot_image_grid(out_img_batch, 1, factor=10) 528 | else: 529 | plot_image_grid(out_img_batch, self.cfg['batch_size'], padding=3, factor=15) 530 | 531 | # Save images to disk 532 | if not self.cfg['res_dir'] is None and self.cfg['save']: 533 | save(stitch_batch, subpath, 'stitch') 534 | save(out_img_batch, subpath, 'full_with_boundary') 535 | 
save(batch_data['out_img_batch'], subpath, 'full') --------------------------------------------------------------------------------
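The repository's own entry points (train.py and demo.ipynb) are not reproduced in this section. For orientation, the snippet below is a minimal, hypothetical driver sketch showing how the classes above are typically wired together. It assumes the script is run from the src/ directory so that the flat imports used above (e.g. `from inpainting_test import InpaintingTest`) resolve, that configs/base.py supplies the remaining default keys of cfg, and that the video/mask paths shown are placeholders for the demo data; it is not the repository's actual train.py.

# Hypothetical driver sketch; not the repository's train.py.
# Assumes: run from src/, a CUDA GPU with CuPy available (needed by the PWC-Net
# flow estimation), and that configs/base.py provides the remaining cfg defaults.
from configs.DIP_Vid_Flow import cfg   # any of the provided configs can be used
from inpainting_test import InpaintingTest

cfg['video_path'] = '../data/rollerblade.avi'      # placeholder path
cfg['mask_path'] = '../data/rollerblade_mask.avi'  # placeholder path
cfg['res_dir'] = '../result'

test = InpaintingTest(cfg)
test.create_data_loader()    # load frames/masks, estimate flow, pre-compute flow/perceptual masks
test.create_model()          # build the 2D/3D encoder-decoder generator
test.create_loss_function()  # L1/L2 reconstruction criterion plus MSE for the perceptual term
test.create_optimizer()      # Adam or SGD over the generator parameters
test.train()                 # run the internal-learning optimization (DIP-Vid-Flow here)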