├── figures ├── result1.png ├── result2.png ├── teaser.png └── teaser2.png ├── Dockerfile ├── datasets ├── __init__.py ├── VideoInterp.py └── data_transforms.py ├── models ├── __init__.py ├── model_utils.py ├── CycleHJSuperSloMo.py └── HJSuperSloMo.py ├── LICENSE ├── utils.py ├── parser.py ├── eval.py ├── README.md └── train.py /figures/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/unsupervised-video-interpolation/HEAD/figures/result1.png -------------------------------------------------------------------------------- /figures/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/unsupervised-video-interpolation/HEAD/figures/result2.png -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/unsupervised-video-interpolation/HEAD/figures/teaser.png -------------------------------------------------------------------------------- /figures/teaser2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/unsupervised-video-interpolation/HEAD/figures/teaser2.png -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # =========== 2 | # base images 3 | # =========== 4 | FROM nvcr.io/nvidia/pytorch:19.04-py3 5 | 6 | 7 | # =============== 8 | # system packages 9 | # =============== 10 | RUN apt-get update 11 | RUN apt-get install -y bash-completion \ 12 | emacs \ 13 | ffmpeg \ 14 | git \ 15 | graphviz \ 16 | htop \ 17 | libopenexr-dev \ 18 | openssh-server \ 19 | rsync \ 20 | wget \ 21 | curl 22 | 23 | 24 | # =========== 25 | # latest apex 26 | # =========== 27 | RUN pip uninstall -y apex 28 | RUN git clone https://github.com/NVIDIA/apex.git ~/apex && \ 29 | cd ~/apex && \ 30 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 31 | 32 | 33 | # ============ 34 | # pip packages 35 | # ============ 36 | RUN pip install --upgrade pip 37 | RUN pip install --upgrade ffmpeg==1.4 38 | RUN pip install --upgrade imageio==2.6.1 39 | RUN pip install --upgrade natsort==6.2.0 40 | RUN pip install --upgrade numpy==1.18.1 41 | RUN pip install --upgrade pillow==6.1 42 | RUN pip install --upgrade scikit-image==0.16.2 43 | RUN pip install --upgrade tensorboardX==2.0 44 | RUN pip install --upgrade torchvision==0.4.2 45 | RUN pip install --upgrade tqdm==4.41.1 46 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 
8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | from .VideoInterp import * 28 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | # 26 | # ***************************************************************************** 27 | from .HJSuperSloMo import * 28 | from .CycleHJSuperSloMo import * 29 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import numpy as np 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | # The baseline Super SloMo relies on torch.nn.functional.grid_sample to implement a warping module. 
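# (torch.nn.functional.grid_sample expects a sampling grid whose x/y coordinates are normalized to [-1, 1];
# the MyResample2D module below builds such a grid from the predicted flow field and samples the source
# image at the flow-displaced pixel locations.)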
36 | # To ensure that our results replicate published accuracy numbers, we also implement a Resample2D layer 37 | # in a similar way, completely with torch tensors, as is done in: 38 | # https://github.com/avinashpaliwal/Super-SloMo/blob/master/model.py#L213 39 | # 40 | # However, for faster training, we suggest to use our CUDA kernels for Resample2D, here: 41 | # https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/resample2d_package/resample2d.py 42 | # 43 | # from flownet2_pytorch.networks.resample2d_package.resample2d import Resample2d 44 | # 45 | 46 | 47 | class MyResample2D(nn.Module): 48 | def __init__(self, width, height): 49 | super(MyResample2D, self).__init__() 50 | 51 | self.width = width 52 | self.height = height 53 | 54 | # make grids for horizontal and vertical displacements 55 | grid_w, grid_h = np.meshgrid(np.arange(width), np.arange(height)) 56 | grid_w, grid_h = grid_w.reshape((1,) + grid_w.shape), grid_h.reshape((1,) + grid_h.shape) 57 | 58 | self.register_buffer("grid_w", torch.tensor(grid_w, requires_grad=False, dtype=torch.float32)) 59 | self.register_buffer("grid_h", torch.tensor(grid_h, requires_grad=False, dtype=torch.float32)) 60 | 61 | def forward(self, im, uv): 62 | 63 | # Get relative displacement 64 | u = uv[:, 0, ...] 65 | v = uv[:, 1, ...] 66 | 67 | # Calculate absolute displacement along height and width axis -> (batch_size, height, width) 68 | ww = self.grid_w.expand_as(u) + u 69 | hh = self.grid_h.expand_as(v) + v 70 | 71 | # Normalize indices to [-1,1] 72 | ww = 2 * ww / (self.width - 1) - 1 73 | hh = 2 * hh / (self.height - 1) - 1 74 | 75 | # Form a grid of shape (batch_size, height, width, 2) 76 | norm_grid_wh = torch.stack((ww, hh), dim=-1) 77 | 78 | # Perform a resample 79 | reampled_im = torch.nn.functional.grid_sample(im, norm_grid_wh) 80 | 81 | return reampled_im 82 | 83 | 84 | class DummyModel(nn.Module): 85 | def __init__(self): 86 | super(DummyModel, self).__init__() 87 | 88 | def forward(self, inputs, target_index): 89 | return {}, inputs['image'][1], inputs['image'][1] 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Nvidia Source Code License (1-Way Commercial) – NVIDIA CONFIDENTIAL 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | “Software” means the original work of authorship made available under this License. 7 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 8 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates. 9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 10 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 11 | 12 | 2. License Grants 13 | 14 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 15 | 16 | 3. Limitations 17 | 18 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 19 | 20 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 21 | 22 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 23 | 24 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately. 25 | 26 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 27 | 28 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately. 29 | 30 | 4. Disclaimer of Warranty. 31 | 32 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 33 | 34 | 5. Limitation of Liability. 35 | 36 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
37 | 38 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import subprocess 4 | import time 5 | from inspect import isclass 6 | import numpy as np 7 | 8 | 9 | class TimerBlock: 10 | def __init__(self, title): 11 | print(("{}".format(title))) 12 | 13 | def __enter__(self): 14 | self.start = time.clock() 15 | return self 16 | 17 | def __exit__(self, exc_type, exc_value, traceback): 18 | self.end = time.clock() 19 | self.interval = self.end - self.start 20 | 21 | if exc_type is not None: 22 | self.log("Operation failed\n") 23 | else: 24 | self.log("Operation finished\n") 25 | 26 | def log(self, string): 27 | duration = time.clock() - self.start 28 | units = 's' 29 | if duration > 60: 30 | duration = duration / 60. 31 | units = 'm' 32 | print(" [{:.3f}{}] {}".format(duration, units, string), flush=True) 33 | 34 | 35 | def module_to_dict(module, exclude=[]): 36 | return dict([(x, getattr(module, x)) for x in dir(module) 37 | if isclass(getattr(module, x)) 38 | and x not in exclude 39 | and getattr(module, x) not in exclude]) 40 | 41 | 42 | # AverageMeter: adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py 43 | class AverageMeter(object): 44 | """Computes and stores the average and current value""" 45 | def __init__(self): 46 | self.reset() 47 | 48 | def reset(self): 49 | self.val = 0 50 | self.avg = 0 51 | self.sum = 0 52 | self.count = 0 53 | 54 | def update(self, val, n=1): 55 | self.val = val 56 | self.sum += val * n 57 | self.count += n 58 | self.avg = self.sum / self.count 59 | 60 | 61 | # creat_pipe: adapted from https://stackoverflow.com/questions/23709893/popen-write-operation-on-closed-file-images-to-video-using-ffmpeg/23709937#23709937 62 | # start an ffmpeg pipe for creating RGB8 for color images or FFV1 for depth 63 | # NOTE: this is REALLY lossy and not optimal for HDR data. when it comes time to train 64 | # on HDR data, you'll need to figure out the way to save to pix_fmt=rgb48 or something 65 | # similar 66 | def create_pipe(pipe_filename, width, height, frame_rate=60, quite=True): 67 | # default extension and tonemapper 68 | pix_fmt = 'rgb24' 69 | out_fmt = 'yuv420p' 70 | codec = 'h264' 71 | 72 | command = ['ffmpeg', 73 | '-threads', '2', # number of threads to start 74 | '-y', # (optional) overwrite output file if it exists 75 | '-f', 'rawvideo', # input format 76 | '-vcodec', 'rawvideo', # input codec 77 | '-s', str(width) + 'x' + str(height), # size of one frame 78 | '-pix_fmt', pix_fmt, # input pixel format 79 | '-r', str(frame_rate), # frames per second 80 | '-i', '-', # The imput comes from a pipe 81 | '-an', # Tells FFMPEG not to expect any audio 82 | '-codec:v', codec, # output codec 83 | '-crf', '18', 84 | # compression quality for h264 (maybe h265 too?) 
- http://slhck.info/video/2017/02/24/crf-guide.html 85 | # '-compression_level', '10', # compression level for libjpeg if doing lossy depth 86 | '-strict', '-2', # experimental 16 bit support nessesary for gray16le 87 | '-pix_fmt', out_fmt, # output pixel format 88 | '-s', str(width) + 'x' + str(height), # output size 89 | pipe_filename] 90 | cmd = ' '.join(command) 91 | if not quite: 92 | print('openning a pip ....\n' + cmd + '\n') 93 | 94 | # open the pipe, and ignore stdout and stderr output 95 | DEVNULL = open(os.devnull, 'wb') 96 | return subprocess.Popen(command, stdin=subprocess.PIPE, stdout=DEVNULL, stderr=DEVNULL, close_fds=True) 97 | 98 | 99 | 100 | def get_pred_flag(height, width): 101 | pred_flag = np.ones((height, width, 3), dtype=np.uint8) 102 | pred_values = np.zeros((height, width, 3), dtype=np.uint8) 103 | 104 | hstart = int((192. / 1200) * height) 105 | wstart = int((224. / 1920) * width) 106 | h_step = int((24. / 1200) * height) 107 | w_step = int((32. / 1920) * width) 108 | 109 | pred_flag[hstart:hstart + h_step, -wstart + 0 * w_step:-wstart + 1 * w_step, :] = np.asarray([0, 0, 0]) 110 | pred_flag[hstart:hstart + h_step, -wstart + 1 * w_step:-wstart + 2 * w_step, :] = np.asarray([0, 0, 0]) 111 | pred_flag[hstart:hstart + h_step, -wstart + 2 * w_step:-wstart + 3 * w_step, :] = np.asarray([0, 0, 0]) 112 | 113 | pred_values[hstart:hstart + h_step, -wstart + 0 * w_step:-wstart + 1 * w_step, :] = np.asarray([0, 0, 255]) 114 | pred_values[hstart:hstart + h_step, -wstart + 1 * w_step:-wstart + 2 * w_step, :] = np.asarray([0, 255, 0]) 115 | pred_values[hstart:hstart + h_step, -wstart + 2 * w_step:-wstart + 3 * w_step, :] = np.asarray([255, 0, 0]) 116 | return pred_flag, pred_values 117 | 118 | 119 | def copy_arguments(main_dict, main_filepath='', save_dir='./'): 120 | pycmd = 'python3 ' + main_filepath + ' \\\n' 121 | _main_dict = main_dict.copy() 122 | _main_dict['--name'] = _main_dict['--name']+'_replicate' 123 | for k in _main_dict.keys(): 124 | if 'batchNorm' in k: 125 | pycmd += ' ' + k + ' ' + str(_main_dict[k]) + ' \\\n' 126 | elif type(_main_dict[k]) == bool and _main_dict[k]: 127 | pycmd += ' ' + k + ' \\\n' 128 | elif type(_main_dict[k]) == list: 129 | pycmd += ' ' + k + ' ' + ' '.join([str(f) for f in _main_dict[k]]) + ' \\\n' 130 | elif type(_main_dict[k]) != bool: 131 | pycmd += ' ' + k + ' ' + str(_main_dict[k]) + ' \\\n' 132 | pycmd = '#!/bin/bash\n' + pycmd[:-2] 133 | job_script = os.path.join(save_dir, 'job.sh') 134 | 135 | file = open(job_script, 'w') 136 | file.write(pycmd) 137 | file.close() 138 | 139 | return 140 | 141 | 142 | def block_print(): 143 | sys.stdout = open(os.devnull, 'w') 144 | -------------------------------------------------------------------------------- /datasets/VideoInterp.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 
11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import natsort 32 | import numpy as np 33 | from imageio import imread 34 | import torch 35 | from torch.utils import data 36 | 37 | 38 | class VideoInterp(data.Dataset): 39 | def __init__(self, args=None, root='', num_interp=7, sample_rate=1, step_size=1, 40 | is_training=False, transform=None): 41 | 42 | self.num_interp = num_interp 43 | self.sample_rate = sample_rate 44 | self.step_size = step_size 45 | self.transform = transform 46 | self.is_training = is_training 47 | self.transform = transform 48 | 49 | self.start_index = args.start_index 50 | self.stride = args.stride 51 | self.crop_size = args.crop_size 52 | 53 | # argument sanity check 54 | assert (os.path.exists(root)), "Invalid path to input dataset." 
55 | assert self.num_interp > 0, "num_interp must be at least 1" 56 | assert self.step_size > 0, "step_size must be at least 1" 57 | 58 | if self.is_training: 59 | self.start_index = 0 60 | 61 | # collect, colors, motion vectors, and depth 62 | self.ref = self.collect_filelist(root) 63 | 64 | # calculate total number of unique sub-sequences 65 | def calc_subseq_len(n): 66 | return (n - max(1, (self.num_interp + 1) * self.sample_rate) - 1) // self.step_size + 1 67 | self.counts = [calc_subseq_len(len(el)) for el in self.ref] 68 | 69 | self.total = np.sum(self.counts) 70 | self.cum_sum = list(np.cumsum([0] + [el for el in self.counts])) 71 | 72 | def collect_filelist(self, root): 73 | include_ext = [".png", ".jpg", "jpeg", ".bmp"] 74 | # collect subfolders, excluding hidden files, but following symlinks 75 | dirs = [x[0] for x in os.walk(root, followlinks=True) if not x[0].startswith('.')] 76 | 77 | # naturally sort, both dirs and individual images, while skipping hidden files 78 | dirs = natsort.natsorted(dirs) 79 | 80 | datasets = [ 81 | [os.path.join(fdir, el) for el in natsort.natsorted(os.listdir(fdir)) 82 | if os.path.isfile(os.path.join(fdir, el)) 83 | and not el.startswith('.') 84 | and any([el.endswith(ext) for ext in include_ext])] 85 | for fdir in dirs 86 | ] 87 | 88 | return [el for el in datasets if el] 89 | 90 | def get_sample_indices(self, index, tar_index=None): 91 | if self.is_training: 92 | sample_indices = [index, index + self.sample_rate * tar_index, index + 93 | self.sample_rate * (self.num_interp + 1)] 94 | else: 95 | sample_indices = [index + i * self.sample_rate for i in range(0, self.num_interp + 2)] 96 | if self.sample_rate == 0: 97 | sample_indices[-1] += 1 98 | return sample_indices 99 | 100 | def pad_images(self, images): 101 | height, width, _ = images[0].shape 102 | image_count = len(images) 103 | # Pad images with zeros if it is not evenly divisible by args.stride (property of model) 104 | if (height % self.stride) != 0: 105 | new_height = (height // self.stride + 1) * self.stride 106 | for i in range(image_count): 107 | images[i] = np.pad(images[i], ((0, new_height - height), (0, 0), (0, 0)), 'constant', 108 | constant_values=(0, 0)) 109 | 110 | if (width % self.stride) != 0: 111 | new_width = (width // self.stride + 1) * self.stride 112 | for i in range(image_count): 113 | images[i] = np.pad(images[i], ((0, 0), (0, new_width - width), (0, 0)), 'constant', 114 | constant_values=(0, 0)) 115 | return images 116 | 117 | def __len__(self): 118 | return self.total 119 | 120 | def __getitem__(self, index): 121 | # Adjust index 122 | index = len(self) + index if index < 0 else index 123 | index = index + self.start_index 124 | 125 | dataset_index = np.searchsorted(self.cum_sum, index + 1) 126 | index = self.step_size * (index - self.cum_sum[np.maximum(0, dataset_index - 1)]) 127 | 128 | image_list = self.ref[dataset_index - 1] 129 | 130 | # target index, subset of range(1,num_interp+1) 131 | tar_index = 1 + torch.randint(0, max(1, self.num_interp), (1,)).item() 132 | input_indices = self.get_sample_indices(index, tar_index) 133 | 134 | # reverse subsequence for augmentation with a probability of 0.5 135 | if self.is_training and torch.randint(0, 2, (1,)).item(): 136 | input_indices = input_indices[::-1] 137 | tar_index = self.num_interp - tar_index + 1 138 | 139 | image_files = [image_list[i] for i in input_indices] 140 | 141 | # Read images from file 142 | images = [imread(image_file)[:, :, :3] for image_file in image_files] 143 | image_shape = images[0].shape 144 
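# At this point `images` holds the loaded frames in temporal order: images[0] and images[-1] are the two
# low-FPS input frames, and images[1:-1] are the intermediate target frames (a single randomly drawn target
# during training, all num_interp targets during validation). For example, with num_interp=7, sample_rate=1
# and a training tar_index of 3, get_sample_indices() returns [index, index + 3, index + 8], i.e. the first
# input, the chosen target, and the second input of the 9-frame sub-sequence.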
| 145 | # Apply data augmentation if defined. 146 | if self.transform: 147 | input_images, target_images = [images[0], images[-1]], images[1:-1] 148 | input_images, target_images = self.transform(input_images, target_images) 149 | images = [input_images[0]] + target_images + [input_images[-1]] 150 | 151 | # Pad images with zeros, so they fit evenly to model arch in forward pass. 152 | padded_images = self.pad_images(images) 153 | 154 | input_images = [torch.from_numpy(np.ascontiguousarray(tmp.transpose(2, 0, 1).astype(np.float32))).float() for 155 | tmp in padded_images] 156 | 157 | output_dict = { 158 | 'image': input_images, 'tindex': tar_index, 'ishape': image_shape[:2], 'input_files': image_files 159 | } 160 | # print (' '.join([os.path.basename(f) for f in image_files])) 161 | return output_dict 162 | 163 | 164 | class CycleVideoInterp(VideoInterp): 165 | def __init__(self, args=None, root='', num_interp=7, sample_rate=1, step_size=1, 166 | is_training=False, transform=None): 167 | super(CycleVideoInterp, self).__init__(args=args, root=root, num_interp=num_interp, sample_rate=sample_rate, 168 | step_size=step_size, is_training=is_training, transform=transform) 169 | 170 | # # Adjust indices 171 | if self.is_training: 172 | self.counts = [el - 1 for el in self.counts] 173 | self.total = np.sum(self.counts) 174 | self.cum_sum = list(np.cumsum([0] + [el for el in self.counts])) 175 | 176 | def get_sample_indices(self, index, tar_index=None): 177 | if self.is_training: 178 | offset = max(1, self.sample_rate) + self.sample_rate * self.num_interp 179 | sample_indices = [index, index + offset, index + 2 * offset] 180 | else: 181 | sample_indices = [index + i * self.sample_rate for i in range(0, self.num_interp + 2)] 182 | if self.sample_rate == 0: 183 | sample_indices[-1] += 1 184 | return sample_indices 185 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ***************************************************************************** 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # * Neither the name of the NVIDIA CORPORATION nor the 13 | # names of its contributors may be used to endorse or promote products 14 | # derived from this software without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | # DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 20 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | # 27 | # ***************************************************************************** 28 | import argparse 29 | import models 30 | 31 | # Collect all available model classes 32 | model_names = sorted(el for el in models.__dict__ 33 | if not el.startswith("__") and callable(models.__dict__[el])) 34 | 35 | """ 36 | Reda, Fitsum A., et al. "Unsupervised Video Interpolation Using Cycle Consistency." 37 | arXiv preprint arXiv:1906.05928 (2019). 38 | 39 | Jiang, Huaizu, et al. "Super slomo: High quality estimation of multiple 40 | intermediate frames for video interpolation." arXiv pre-print arXiv:1712.00080 (2017). 41 | """ 42 | 43 | parser = argparse.ArgumentParser(description="A PyTorch Implementation of Unsupervised Video Interpolation Using " 44 | "Cycle Consistency") 45 | 46 | parser.add_argument('--model', metavar='MODEL', default='HJSuperSloMo', 47 | choices=model_names, 48 | help='model architecture: ' + 49 | ' | '.join(model_names) + 50 | ' (default: HJSuperSloMo)') 51 | parser.add_argument('-s', '--save', '--save_root', 52 | default='./result_folder', type=str, 53 | help='Path of the output folder', 54 | metavar='SAVE_PATH') 55 | parser.add_argument('--torch_home', default='./.torch', type=str, 56 | metavar='TORCH_HOME', 57 | help='Path to save pre-trained models from torchvision') 58 | parser.add_argument('-n', '--name', default='trial_0', type=str, metavar='EXPERIMENT_NAME', 59 | help='Name of experiment folder.') 60 | parser.add_argument('--dataset', default='VideoInterp', type=str, metavar='TRAINING_DATALOADER_CLASS', 61 | help='Specify training dataset class for loading (Default: VideoInterp)') 62 | parser.add_argument('--resume', default='', type=str, metavar='CHECKPOINT_PATH', 63 | help='path to checkpoint file (default: none)') 64 | 65 | # Resources 66 | parser.add_argument('--distributed_backend', default='nccl', type=str, metavar='DISTRIBUTED_BACKEND', 67 | help='backend used for communication between processes.') 68 | parser.add_argument('-j', '--workers', default=4, type=int, 69 | help='number of data loader workers (default: 10)') 70 | parser.add_argument('-g', '--gpus', type=int, default=-1, 71 | help='number of GPUs to use') 72 | parser.add_argument('--fp16', action='store_true', help='Enable mixed-precision training.') 73 | 74 | # Learning rate parameters. 
75 | parser.add_argument('--lr', '--learning_rate', default=0.0001, type=float, 76 | metavar='LR', help='initial learning rate') 77 | parser.add_argument('--lr_scheduler', default='MultiStepLR', type=str, 78 | metavar='LR_Scheduler', help='Scheduler for learning' + 79 | ' rate (only ExponentialLR and MultiStepLR supported).') 80 | parser.add_argument('--lr_gamma', default=0.1, type=float, 81 | help='learning rate will be multiplied by this gamma') 82 | parser.add_argument('--lr_step', default=200, type=int, 83 | help='stepsize of changing the learning rate') 84 | parser.add_argument('--lr_milestones', type=int, nargs='+', 85 | default=[250, 450], help="Epochs at which to decay " + 86 | "the learning rate by lr_gamma (default: [250, 450])") 87 | # Gradient. 88 | parser.add_argument('--clip_gradients', default=-1.0, type=float, 89 | help='If positive, clip the gradients by this value.') 90 | 91 | # Optimization hyper-parameters 92 | parser.add_argument('-b', '--batch_size', default=4, type=int, metavar='BATCH_SIZE', 93 | help='mini-batch per gpu size (default : 4)') 94 | parser.add_argument('--wd', '--weight_decay', default=0.001, type=float, metavar='WEIGHT_DECAY', 95 | help='weight_decay (default = 0.001)') 96 | parser.add_argument('--seed', default=1234, type=int, metavar="SEED", 97 | help='seed for initializing training. ') 98 | parser.add_argument('--optimizer', default='Adam', type=str, metavar='OPTIMIZER', 99 | help='Specify optimizer from torch.optim (Default: Adam)') 100 | parser.add_argument('--mean_pix', nargs='+', type=float, metavar="RGB_MEAN", 101 | default=[109.93, 109.167, 101.455], 102 | help='mean pixel values carried over from superslomo (default: [109.93, 109.167, 101.455])') 103 | parser.add_argument('--print_freq', default=100, type=int, metavar="PRINT_FREQ", 104 | help='frequency of printing training status (default: 100)') 105 | parser.add_argument('--save_freq', type=int, default=20, metavar="SAVE_FREQ", 106 | help='frequency of saving intermediate models, in epochs (default: 20)') 107 | parser.add_argument('--start_epoch', type=int, default=-1, 108 | help="Set epoch number during resuming") 109 | parser.add_argument('--epochs', default=500, type=int, metavar="EPOCHS", 110 | help='number of total epochs to run (default: 500)') 111 | 112 | # Training sequence, supports a single sequence for now 113 | parser.add_argument('--train_file', required=False, metavar="TRAINING_FILE", 114 | help='training file (default : Required)') 115 | parser.add_argument('--crop_size', type=int, nargs='+', default=[704, 704], metavar="CROP_SIZE", 116 | help="Spatial dimension to crop training samples for training (default : [704, 704])") 117 | parser.add_argument('--train_n_batches', default=-1, type=int, metavar="TRAIN_N_BATCHES", 118 | help="Limit the number of minibatch iterations per epoch. Used for debugging purposes. \ 119 | (default : -1, means use all available mini-batches)") 120 | parser.add_argument('--sample_rate', type=int, default=1, 121 | help='number of frames to skip when sampling input1, {intermediate}, and input2 \ 122 | (default: 1, i.e. input1, the intermediate frames, and input2 are consecutive frames.)') 123 | parser.add_argument('--step_size', type=int, default=-1, metavar="STEP_INTERP", 124 | help='number of frames to skip from one mini-batch to the next mini-batch \ 125 | (default: -1, means step_size = num_interp + 1)') 126 | parser.add_argument('--num_interp', default=7, type=int, metavar="NUM_INTERP", 127 | help='number of intermediate frames to interpolate (default : 7)') 128 | 129 | 130 | # Validation sequence, supports a single sequence for now 131 | parser.add_argument('--val_file', metavar="VALIDATION_FILE", 132 | help='validation file (default : None)') 133 | parser.add_argument('--val_batch_size', type=int, default=1, 134 | help="Batch size to use for validation.") 135 | parser.add_argument('--val_n_batches', default=-1, type=int, 136 | help="Limit the number of minibatch iterations per epoch. Used for debugging purposes.") 137 | parser.add_argument('--video_fps', type=int, default=30, 138 | help="Render predicted video with a specified frame rate") 139 | parser.add_argument('--initial_eval', action='store_true', help='Perform initial evaluation before training.') 140 | parser.add_argument("--start_index", type=int, default=0, metavar="VAL_START_INDEX", 141 | help="Index to start running validation (default : 0)") 142 | parser.add_argument("--val_sample_rate", type=int, default=1, metavar="VAL_SAMPLE_RATE", 143 | help='number of frames to skip when sampling input1, {intermediate}, and input2 (default: 1, \ 144 | i.e. input1, the intermediate frames, and input2 are consecutive frames.)') 145 | parser.add_argument('--val_step_size', type=int, default=-1, metavar="VAL_STEP_INTERP", 146 | help='number of frames to skip from one mini-batch to the next mini-batch \ 147 | (default: -1, means step_size = num_interp + 1)') 148 | parser.add_argument('--val_num_interp', type=int, default=1, 149 | help='number of intermediate frames we want to interpolate for validation. (default: 1)') 150 | 151 | # Misc: undersample large sequences (--step_size), compute flow after downscale (--flow_scale) 152 | parser.add_argument('--flow_scale', type=float, default=1., 153 | help="Flow scale (default: 1.)
for robust interpolation in high resolution images.") 154 | parser.add_argument('--skip_aug', action='store_true', help='Skips expensive geometric or photometric augmentations.') 155 | parser.add_argument('--teacher_weight', type=float, default=-1., 156 | help="Teacher or Pseudo Supervised Loss (PSL)'s weight of contribution to total loss.") 157 | 158 | parser.add_argument('--apply_vidflag', action='store_true', help='Apply applying the BRG flag to interpolated frames.') 159 | 160 | parser.add_argument('--write_video', action='store_true', help='save video to \'args.save/args.name.mp4\'.') 161 | parser.add_argument('--write_images', action='store_true', 162 | help='write to folder \'args.save/args.name\' prediction and ground-truth images.') 163 | parser.add_argument('--stride', type=int, default=64, 164 | help='the largest factor a model reduces spatial size of inputs during a forward pass.') 165 | parser.add_argument('--post_fix', default='Proposed', type=str, 166 | help='tag for predicted frames (default: \'proposed\')') 167 | 168 | # Required for torch distributed launch 169 | parser.add_argument('--local_rank', default=None, type=int, 170 | help='Torch Distributed') 171 | -------------------------------------------------------------------------------- /models/CycleHJSuperSloMo.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | # 26 | # ***************************************************************************** 27 | from __future__ import division 28 | from __future__ import print_function 29 | import torch 30 | import torch.nn.functional as F 31 | from .model_utils import MyResample2D, DummyModel 32 | from .HJSuperSloMo import HJSuperSloMo 33 | 34 | 35 | class CycleHJSuperSloMo(HJSuperSloMo): 36 | def __init__(self, args, mean_pix=[109.93, 109.167, 101.455]): 37 | super(CycleHJSuperSloMo, self).__init__(args=args, mean_pix=mean_pix) 38 | 39 | if args.resume: 40 | self.teacher = HJSuperSloMo(args) 41 | checkpoint = torch.load(args.resume, map_location='cpu') 42 | self.teacher.load_state_dict(checkpoint['state_dict'], strict=False) 43 | for param in self.teacher.parameters(): 44 | param.requires_grad = False 45 | 46 | self.teacher_weight = 0.8 47 | if 'teacher_weight' in args and args.teacher_weight >= 0: 48 | self.teacher_weight = args.teacher_weight 49 | else: 50 | self.teacher = DummyModel() 51 | self.teacher_weight = 0. 52 | 53 | def network_output(self, inputs, target_index): 54 | 55 | im1, im2 = inputs 56 | 57 | # Estimate bi-directional optical flows between input low FPS frame pairs 58 | # Downsample images for robust intermediate flow estimation 59 | ds_im1 = F.interpolate(im1, scale_factor=1./self.scale, mode='bilinear', align_corners=False) 60 | ds_im2 = F.interpolate(im2, scale_factor=1./self.scale, mode='bilinear', align_corners=False) 61 | 62 | uvf, bottleneck_out, uvb = self.make_flow_prediction(torch.cat((ds_im1, ds_im2), dim=1)) 63 | 64 | uvf = self.scale * F.interpolate(uvf, scale_factor=self.scale, mode='bilinear', align_corners=False) 65 | uvb = self.scale * F.interpolate(uvb, scale_factor=self.scale, mode='bilinear', align_corners=False) 66 | bottleneck_out = F.interpolate(bottleneck_out, scale_factor=self.scale, mode='bilinear', align_corners=False) 67 | 68 | t = self.tlinespace[target_index] 69 | t = t.reshape(t.shape[0], 1, 1, 1) 70 | 71 | uvb_t_raw = - (1 - t) * t * uvf + t * t * uvb 72 | uvf_t_raw = (1 - t) * (1 - t) * uvf - (1 - t) * t * uvb 73 | 74 | im1w_raw = self.resample2d(im1, uvb_t_raw) # im1w_raw 75 | im2w_raw = self.resample2d(im2, uvf_t_raw) # im2w_raw 76 | 77 | # Perform intermediate bi-directional flow refinement 78 | uv_t_data = torch.cat((im1, im2, im1w_raw, uvb_t_raw, im2w_raw, uvf_t_raw), dim=1) 79 | uvf_t, uvb_t, t_vis_map = self.make_flow_interpolation(uv_t_data, bottleneck_out) 80 | 81 | uvb_t = uvb_t_raw + uvb_t # uvb_t 82 | uvf_t = uvf_t_raw + uvf_t # uvf_t 83 | 84 | im1w = self.resample2d(im1, uvb_t) # im1w 85 | im2w = self.resample2d(im2, uvf_t) # im2w 86 | 87 | # Compute final intermediate frame via weighted blending 88 | alpha1 = (1 - t) * t_vis_map 89 | alpha2 = t * (1 - t_vis_map) 90 | denorm = alpha1 + alpha2 + 1e-10 91 | im_t_out = (alpha1 * im1w + alpha2 * im2w) / denorm 92 | 93 | return im_t_out, uvb, uvf 94 | 95 | def network_eval(self, inputs, target_index): 96 | _, _, height, width = inputs[0].shape 97 | self.resample2d = MyResample2D(width, height).cuda() 98 | 99 | # Normalize inputs 100 | im1, im_target, im2 = [(im - self.mean_pix) for im in inputs] 101 | 102 | im_t_out, uvb, uvf = self.network_output([im1, im2], target_index) 103 | 104 | # Calculate losses 105 | losses = {} 106 | losses['pix_loss'] = self.L1_loss(im_t_out, im_target) 107 | 108 | im_t_out_features = self.vgg16_features(im_t_out / 255.) 109 | im_target_features = self.vgg16_features(im_target / 255.) 
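# Perceptual loss: L2 loss between VGG-16 feature maps of the interpolated frame and the ground-truth
# frame (both scaled by 1/255 before feature extraction).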
110 | losses['vgg16_loss'] = self.L2_loss(im_t_out_features, im_target_features) 111 | 112 | losses['warp_loss'] = self.L1_loss(self.resample2d(im1, uvb.contiguous()), im2) + \ 113 | self.L1_loss(self.resample2d(im2, uvf.contiguous()), im1) 114 | 115 | smooth_bwd = self.L1_loss(uvb[:, :, :, :-1], uvb[:, :, :, 1:]) + \ 116 | self.L1_loss(uvb[:, :, :-1, :], uvb[:, :, 1:, :]) 117 | smooth_fwd = self.L1_loss(uvf[:, :, :, :-1], uvf[:, :, :, 1:]) + \ 118 | self.L1_loss(uvf[:, :, :-1, :], uvf[:, :, 1:, :]) 119 | 120 | losses['smooth_loss'] = smooth_bwd + smooth_fwd 121 | 122 | # Coefficients for total loss determined empirically using a validation set 123 | losses['tot'] = 0.8 * losses['pix_loss'] + 0.4 * losses['warp_loss'] + 0.005 * losses['vgg16_loss'] + losses[ 124 | 'smooth_loss'] 125 | 126 | # Converts back to (0, 255) range 127 | im_t_out = im_t_out + self.mean_pix 128 | im_target = im_target + self.mean_pix 129 | 130 | return losses, im_t_out, im_target 131 | 132 | def forward(self, inputs, target_index): 133 | if 'image' in inputs: 134 | inputs = inputs['image'] 135 | 136 | if not self.training: 137 | return self.network_eval(inputs, target_index) 138 | self.resample2d = MyResample2D(inputs[0].shape[-1], inputs[0].shape[-2]).cuda() 139 | 140 | # Input frames 141 | im1, im2, im3 = inputs 142 | 143 | # Calculate Pseudo targets at interm_index 144 | with torch.no_grad(): 145 | _, psuedo_gt12, _ = self.teacher({'image': [im1, im1, im2]}, target_index) 146 | _, psuedo_gt23, _ = self.teacher({'image': [im2, im3, im3]}, target_index) 147 | psuedo_gt12, psuedo_gt23 = psuedo_gt12 - self.mean_pix, psuedo_gt23 - self.mean_pix 148 | 149 | im1, im2, im3 = im1 - self.mean_pix, im2 - self.mean_pix, im3 - self.mean_pix 150 | 151 | pred12, pred12_uvb, pred12_uvf = self.network_output([im1, im2], target_index) 152 | pred23, pred23_uvb, pred23_uvf = self.network_output([im2, im3], target_index) 153 | 154 | target_index = (self.args.num_interp + 1) - target_index 155 | 156 | ds_pred12 = F.interpolate(pred12, scale_factor=1./self.scale, mode='bilinear', align_corners=False) 157 | ds_pred23 = F.interpolate(pred23, scale_factor=1./self.scale, mode='bilinear', align_corners=False) 158 | 159 | uvf, bottleneck_out, uvb = self.make_flow_prediction(torch.cat((ds_pred12, ds_pred23), dim=1)) 160 | 161 | uvf = self.scale * F.interpolate(uvf, scale_factor=self.scale, mode='bilinear', align_corners=False) 162 | uvb = self.scale * F.interpolate(uvb, scale_factor=self.scale, mode='bilinear', align_corners=False) 163 | bottleneck_out = F.interpolate(bottleneck_out, scale_factor=self.scale, mode='bilinear', align_corners=False) 164 | 165 | t = self.tlinespace[target_index] 166 | t = t.reshape(t.shape[0], 1, 1, 1) 167 | 168 | uvb_t_raw = - (1 - t) * t * uvf + t * t * uvb 169 | uvf_t_raw = (1 - t) * (1 - t) * uvf - (1 - t) * t * uvb 170 | 171 | im12w_raw = self.resample2d(pred12, uvb_t_raw) # im1w_raw 172 | im23w_raw = self.resample2d(pred23, uvf_t_raw) # im2w_raw 173 | 174 | # Perform intermediate bi-directional flow refinement 175 | uv_t_data = torch.cat((pred12, pred23, im12w_raw, uvb_t_raw, im23w_raw, uvf_t_raw), dim=1) 176 | uvf_t, uvb_t, t_vis_map = self.make_flow_interpolation(uv_t_data, bottleneck_out) 177 | 178 | uvb_t = uvb_t_raw + uvb_t # uvb_t 179 | uvf_t = uvf_t_raw + uvf_t # uvf_t 180 | 181 | im12w = self.resample2d(pred12, uvb_t) # im1w 182 | im23w = self.resample2d(pred23, uvf_t) # im2w 183 | 184 | # Compute final intermediate frame via weighted blending 185 | alpha1 = (1 - t) * t_vis_map 186 | alpha2 = t * (1 
- t_vis_map) 187 | denorm = alpha1 + alpha2 + 1e-10 188 | im_t_out = (alpha1 * im12w + alpha2 * im23w) / denorm 189 | 190 | # Calculate training loss 191 | losses = {} 192 | losses['pix_loss'] = self.L1_loss(im_t_out, im2) 193 | 194 | im_t_out_features = self.vgg16_features(im_t_out/255.) 195 | im2_features = self.vgg16_features(im2/255.) 196 | losses['vgg16_loss'] = self.L2_loss(im_t_out_features, im2_features) 197 | 198 | losses['warp_loss'] = self.L1_loss(im12w_raw, im2) + self.L1_loss(im23w_raw, im2) + \ 199 | self.L1_loss(self.resample2d(pred12, uvb), pred23) + \ 200 | self.L1_loss(self.resample2d(pred23, uvf), pred12) + \ 201 | self.L1_loss(self.resample2d(im1, pred12_uvb), im2) + \ 202 | self.L1_loss(self.resample2d(im2, pred12_uvf), im1) + \ 203 | self.L1_loss(self.resample2d(im2, pred23_uvb), im3) + \ 204 | self.L1_loss(self.resample2d(im3, pred23_uvf), im2) 205 | 206 | smooth_bwd = self.L1_loss(uvb[:, :, :, :-1], uvb[:, :, :, 1:]) + \ 207 | self.L1_loss(uvb[:, :, :-1, :], uvb[:, :, 1:, :]) + \ 208 | self.L1_loss(pred12_uvb[:, :, :, :-1], pred12_uvb[:, :, :, 1:]) + \ 209 | self.L1_loss(pred12_uvb[:, :, :-1, :], pred12_uvb[:, :, 1:, :]) + \ 210 | self.L1_loss(pred23_uvb[:, :, :, :-1], pred23_uvb[:, :, :, 1:]) + \ 211 | self.L1_loss(pred23_uvb[:, :, :-1, :], pred23_uvb[:, :, 1:, :]) 212 | 213 | smooth_fwd = self.L1_loss(uvf[:, :, :, :-1], uvf[:, :, :, 1:]) + \ 214 | self.L1_loss(uvf[:, :, :-1, :], uvf[:, :, 1:, :]) + \ 215 | self.L1_loss(pred12_uvf[:, :, :, :-1], pred12_uvf[:, :, :, 1:]) + \ 216 | self.L1_loss(pred12_uvf[:, :, :-1, :], pred12_uvf[:, :, 1:, :]) + \ 217 | self.L1_loss(pred23_uvf[:, :, :, :-1], pred23_uvf[:, :, :, 1:]) + \ 218 | self.L1_loss(pred23_uvf[:, :, :-1, :], pred23_uvf[:, :, 1:, :]) 219 | 220 | losses['loss_smooth'] = smooth_bwd + smooth_fwd 221 | 222 | losses['teacher'] = self.L1_loss(psuedo_gt12, pred12) + self.L1_loss(psuedo_gt23, pred23) 223 | 224 | # Coefficients for total loss determined empirically using a validation set 225 | losses['tot'] = self.pix_alpha * losses['pix_loss'] + self.warp_alpha * losses['warp_loss'] + \ 226 | self.vgg16_alpha * losses['vgg16_loss'] + self.smooth_alpha * losses['loss_smooth'] + self.teacher_weight * losses['teacher'] 227 | 228 | # Converts back to (0, 255) range 229 | im_t_out = im_t_out + self.mean_pix 230 | im_target = im2 + self.mean_pix 231 | 232 | return losses, im_t_out, im_target 233 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ***************************************************************************** 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # * Neither the name of the NVIDIA CORPORATION nor the 13 | # names of its contributors may be used to endorse or promote products 14 | # derived from this software without specific prior written permission. 
15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 20 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | # 27 | # ***************************************************************************** 28 | import os 29 | import sys 30 | import shutil 31 | import natsort 32 | import numpy as np 33 | from glob import glob 34 | from imageio import imsave 35 | from skimage.measure import compare_psnr, compare_ssim 36 | from tqdm import tqdm 37 | tqdm.monitor_interval = 0 38 | 39 | import torch 40 | import torch.backends.cudnn 41 | import torch.nn.parallel 42 | import torch.optim 43 | import torch.utils.data 44 | 45 | from parser import parser 46 | import datasets 47 | import models 48 | import utils 49 | 50 | """ 51 | Reda, Fitsum A., et al. "Unsupervised Video Interpolation Using Cycle Consistency." 52 | arXiv preprint arXiv:1906.05928 (2019). 53 | 54 | Jiang, Huaizu, et al. "Super slomo: High quality estimation of multiple 55 | intermediate frames for video interpolation." arXiv pre-print arXiv:1712.00080 (2017). 56 | """ 57 | 58 | 59 | def main(): 60 | with utils.TimerBlock("\nParsing Arguments") as block: 61 | args = parser.parse_args() 62 | 63 | args.rank = int(os.getenv('RANK', 0)) 64 | 65 | block.log("Creating save directory: {}".format(args.save)) 66 | args.save_root = os.path.join(args.save, args.name) 67 | if args.write_images or args.write_video: 68 | os.makedirs(args.save_root, exist_ok=True) 69 | assert os.path.exists(args.save_root) 70 | else: 71 | os.makedirs(args.save, exist_ok=True) 72 | assert os.path.exists(args.save) 73 | 74 | os.makedirs(args.torch_home, exist_ok=True) 75 | os.environ['TORCH_HOME'] = args.torch_home 76 | 77 | args.gpus = torch.cuda.device_count() if args.gpus < 0 else args.gpus 78 | block.log('Number of gpus: {} | {}'.format(args.gpus, list(range(args.gpus)))) 79 | 80 | args.network_class = utils.module_to_dict(models)[args.model] 81 | args.dataset_class = utils.module_to_dict(datasets)[args.dataset] 82 | block.log('save_root: {}'.format(args.save_root)) 83 | block.log('val_file: {}'.format(args.val_file)) 84 | 85 | with utils.TimerBlock("Building {} Dataset".format(args.dataset)) as block: 86 | vkwargs = {'batch_size': args.gpus * args.val_batch_size, 87 | 'num_workers': args.gpus * args.workers, 88 | 'pin_memory': True, 'drop_last': True} 89 | step_size = args.val_step_size if args.val_step_size > 0 else (args.num_interp + 1) 90 | val_dataset = args.dataset_class(args=args, root=args.val_file, num_interp=args.num_interp, 91 | sample_rate=args.val_sample_rate, step_size=step_size) 92 | 93 | val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, 94 | **vkwargs) 95 | 96 | args.folder_list = natsort.natsorted( 97 | [os.path.basename(f) for f in sorted(glob(os.path.join(args.val_file, '*')))]) 98 | 99 | block.log('Number of Validation Images: {}:({} 
mini-batches)'.format(len(val_loader.dataset), len(val_loader))) 100 | 101 | with utils.TimerBlock("Building {} Model".format(args.model)) as block: 102 | model = args.network_class(args) 103 | 104 | block.log('Number of parameters: {val:,}'.format(val= 105 | sum([p.data.nelement() if p.requires_grad else 0 for p in model.parameters()]))) 106 | 107 | block.log('Initializing CUDA') 108 | assert torch.cuda.is_available(), 'Code supported for GPUs only at the moment' 109 | model = model.cuda() 110 | model = torch.nn.DataParallel(model, device_ids=list(range(args.gpus))) 111 | torch.manual_seed(args.seed) 112 | 113 | block.log("Attempting to Load checkpoint '{}'".format(args.resume)) 114 | if args.resume and os.path.isfile(args.resume): 115 | checkpoint = torch.load(args.resume) 116 | 117 | # Partial initialization 118 | input_dict = checkpoint['state_dict'] 119 | curr_dict = model.module.state_dict() 120 | state_dict = input_dict.copy() 121 | for key in input_dict: 122 | if key not in curr_dict: 123 | continue 124 | if curr_dict[key].shape != input_dict[key].shape: 125 | state_dict.pop(key) 126 | print("key {} skipped because of size mismatch.".format(key)) 127 | model.module.load_state_dict(state_dict, strict=False) 128 | 129 | epoch = checkpoint['epoch'] 130 | block.log("Successfully loaded checkpoint (at epoch {})".format(epoch)) 131 | elif args.resume: 132 | block.log("No checkpoint found at '{}'.\nAborted.".format(args.resume)) 133 | sys.exit(0) 134 | else: 135 | block.log("Random initialization, checkpoint not provided.") 136 | 137 | with utils.TimerBlock("Inference started ") as block: 138 | evaluate(args, val_loader, model, args.num_interp, epoch, block) 139 | 140 | 141 | def evaluate(args, val_loader, model, num_interp, epoch, block): 142 | in_height, in_width = val_loader.dataset[0]['ishape'] 143 | pred_flag, pred_values = utils.get_pred_flag(in_height, in_width) 144 | 145 | if not args.apply_vidflag: 146 | pred_flag = 0 * pred_flag + 1 147 | pred_values = 0 * pred_values 148 | 149 | if args.rank == 0 and args.write_video: 150 | video_file = os.path.join(args.save_root, '__epoch_%03d.mp4' % epoch) 151 | _pipe = utils.create_pipe(video_file, in_width, in_height, frame_rate=args.video_fps) 152 | 153 | model.eval() 154 | 155 | loss_values = utils.AverageMeter() 156 | avg_metrics = np.zeros((0, 3), dtype=float) 157 | num_batches = len(val_loader) if args.val_n_batches < 0 else args.val_n_batches 158 | 159 | with torch.no_grad(): 160 | for i, batch in enumerate(tqdm(val_loader, total=num_batches)): 161 | 162 | inputs = [b.cuda() for b in batch['image']] 163 | 164 | input_images = [inputs[0], inputs[len(inputs) // 2], inputs[-1]] 165 | inputs_dict = {'image': input_images} 166 | 167 | target_images = inputs[1:-1] 168 | tar_indices = batch['tindex'].cuda() 169 | 170 | # compute loss at mid-way 171 | tar_indices[:] = (num_interp + 1) // 2 172 | loss, outputs, _ = model(inputs_dict, tar_indices) 173 | loss_values.update(loss['tot'].data.item(), outputs.size(0)) 174 | 175 | # compute output for each intermediate timepoint 176 | output_image = inputs[0] 177 | for tarIndex in range(1, num_interp + 1): 178 | tar_indices[:] = tarIndex 179 | _, outputs, _ = model(inputs_dict, tar_indices) 180 | output_image = torch.cat((output_image, outputs), dim=1) 181 | output_image = torch.split(output_image, 3, dim=1)[1:] 182 | 183 | batch_size, _, _, _ = inputs[0].shape 184 | input_filenames = batch['input_files'][1:-1] 185 | in_height, in_width = batch['ishape'] 186 | 187 | for b in range(batch_size): 188 | 
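# For each sample in the batch: convert tensors back to HxWxC uint8 images, crop away the zero padding
# the dataloader added to make the spatial size divisible by args.stride, and compute PSNR, SSIM and IE
# against the corresponding ground-truth (or low-FPS) frame.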
first_target = (input_images[0][b].data.cpu().numpy().transpose(1, 2, 0)).astype(np.uint8) 189 | first_target = first_target[:in_height, :in_width, :] 190 | second_target = (input_images[-1][b].data.cpu().numpy().transpose(1, 2, 0)).astype(np.uint8) 191 | second_target = second_target[:in_height, :in_width, :] 192 | 193 | gt_image = first_target 194 | for index in range(num_interp): 195 | pred_image = (output_image[index][b].data.cpu().numpy().transpose(1, 2, 0)).astype(np.uint8) 196 | pred_image = pred_image[:in_height, :in_width, :] 197 | 198 | # if ground-truth not loaded, treat low FPS frames as targets 199 | if index < len(target_images): 200 | gt_image = (target_images[index][b].data.cpu().numpy().transpose(1, 2, 0)).astype(np.uint8) 201 | gt_filename = '/'.join(input_filenames[index][b].split(os.sep)[-2:]) 202 | gt_image = gt_image[:in_height, :in_width, :] 203 | 204 | # calculate metrics using skimage 205 | psnr = compare_psnr(pred_image, gt_image) 206 | ssim = compare_ssim(pred_image, gt_image, multichannel=True, gaussian_weights=True) 207 | err = pred_image.astype(np.float32) - gt_image.astype(np.float32) 208 | ie = np.mean(np.sqrt(np.sum(err * err, axis=2))) 209 | 210 | avg_metrics = np.vstack((avg_metrics, np.array([psnr, ssim, ie]))) 211 | 212 | # write_images 213 | if args.write_images: 214 | tmp_filename = os.path.join(args.save_root, "%s-%02d-%s.png" % (gt_filename[:-4], (index + 1), args.post_fix)) 215 | os.makedirs(os.path.dirname(tmp_filename), exist_ok=True) 216 | imsave(tmp_filename, pred_image) 217 | 218 | # write video 219 | if args.rank == 0 and args.write_video: 220 | if index == 0: 221 | _pipe.stdin.write(first_target.tobytes()) 222 | try: 223 | _pipe.stdin.write((pred_image * pred_flag + pred_values).tobytes()) 224 | except AttributeError: 225 | raise AttributeError("Error in ffmpeg video creation. Inconsistent image size.") 226 | if args.write_images: 227 | tmp_filename = os.path.join(args.save_root, "%s-%02d-%s.png" % (gt_filename[:-4], 0, "ground_truth")) 228 | os.makedirs(os.path.dirname(tmp_filename), exist_ok=True) 229 | imsave(tmp_filename, first_target) 230 | tmp_filename = os.path.join(args.save_root, "%s-%02d-%s.png" % (gt_filename[:-4], num_interp+1, "ground_truth")) 231 | imsave(tmp_filename, second_target) 232 | if (i + 1) >= num_batches: 233 | break 234 | 235 | if args.write_video: 236 | _pipe.stdin.close() 237 | _pipe.wait() 238 | 239 | """ 240 | Print final accuracy statistics. If intermediate ground truth frames are not available from the input sequence, 241 | the first low FPS frame is treated as a ground-truth frame for all intermediately predicted frames, 242 | as the quantities should not be trusted, in this case. 
243 | """ 244 | for i in range(num_interp): 245 | result2print = 'interm {:02d} PSNR: {:.2f}, SSIM: {:.3f}, IE: {:.2f}'.format(i+1, 246 | np.nanmean(avg_metrics[i::num_interp], axis=0)[0], 247 | np.nanmean(avg_metrics[i::num_interp], axis=0)[1], 248 | np.nanmean(avg_metrics[i::num_interp], axis=0)[2]) 249 | block.log(result2print) 250 | 251 | avg_metrics = np.nanmean(avg_metrics, axis=0) 252 | result2print = 'Overall PSNR: {:.2f}, SSIM: {:.3f}, IE: {:.2f}'.format(avg_metrics[0], avg_metrics[1], 253 | avg_metrics[2]) 254 | v_psnr, v_ssim, v_ie = avg_metrics[0], avg_metrics[1], avg_metrics[2] 255 | block.log(result2print) 256 | 257 | # re-name video with psnr 258 | if args.rank == 0 and args.write_video: 259 | shutil.move(os.path.join(args.save_root, '__epoch_%03d.mp4' % epoch), 260 | os.path.join(args.save_root, '__epoch_%03d_psnr_%1.2f.mp4' % (epoch, avg_metrics[0]))) 261 | 262 | # Move back the model to train mode. 263 | model.train() 264 | 265 | torch.cuda.empty_cache() 266 | block.log('max memory allocated (GB): {:.3f}: '.format( 267 | torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024))) 268 | 269 | return v_psnr, v_ssim, v_ie, loss_values.val 270 | 271 | 272 | if __name__ == '__main__': 273 | main() 274 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Video Interpolation using Cycle Consistency 2 | ### [Project](https://nv-adlr.github.io/publication/2019-UnsupervisedVideoInterpolation) | [Paper](https://arxiv.org/abs/1906.05928) | [YouTube](https://drive.google.com/uc?export=view&id=1DgF-0r1agSy9Y77Bthm_w135qOABc3Xd)
3 | [Unsupervised Video Interpolation using Cycle Consistency](https://nv-adlr.github.io/publication/2019-UnsupervisedVideoInterpolation)
4 | [Fitsum A. Reda](https://scholar.google.com/citations?user=quZ_qLYAAAAJ&hl=en), [Deqing Sun](https://scholar.google.com/citations?user=t4rgICIAAAAJ&hl=en)*, Aysegul Dundar, Mohammad Shoeybi, [Guilin Liu](https://liuguilin1225.github.io/), Kevin J. Shih, Andrew Tao, [Jan Kautz](http://jankautz.com/), [Bryan Catanzaro](http://catanzaro.name/) 5 | NVIDIA Corporation
6 | In International Conference on Computer Vision (**ICCV**) 2019.<br>
7 | ( * Currently affiliated with Google. )
8 | 9 |

10 | 11 | 12 |

= 1.0) 28 | * Python 3 29 | * numpy 30 | * scikit-image 31 | * imageio 32 | * pillow 33 | * tqdm 34 | * tensorboardX 35 | * natsort 36 | * ffmpeg 37 | * torchvision 38 | 39 | To propose a model or change for inclusion, please submit a pull request. 40 | 41 | Multiple GPU training and mixed precision training are supported, and the code provides examples for training and inference. For more help, type
42 | 43 | python3 train.py --help 44 | 45 | 46 | 47 | ## Network Architectures 48 | 49 | Our repo now supports [Super SloMo](https://arxiv.org/abs/1712.00080). Other video interpolation architectures can be integrated into our repo with minimal changes, for instance [DVF](https://arxiv.org/pdf/1702.02463) or [SepConv](https://github.com/sniklaus/sepconv-slomo). 50 | 51 | 52 | ## Pre-trained Models 53 | We've included pre-trained models trained with cycle consistency (CC) alone, or with both cycle consistency and pseudo-supervised (CC + PS) losses.<br>
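These checkpoints are standard PyTorch `.pth` files. As a minimal sketch (the file name is just one of the checkpoints listed below, and the two keys shown are the ones `eval.py` and `train.py` read when restoring weights), you can inspect one before running inference:

```python
import torch

# Any of the released checkpoints has the same layout.
ckpt = torch.load('pretrained_models/unsupervised_adobe2slowflow.pth', map_location='cpu')
print(ckpt['epoch'])              # epoch at which the weights were saved
state_dict = ckpt['state_dict']   # model weights; eval.py loads these with strict=False
print(len(state_dict), 'parameter tensors')
```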
54 | Download checkpoints to a folder `pretrained_models`. 55 | 56 | Supervised Baseline Weights 57 | - [pretrained_models/baseline_superslomo_adobe.pth](https://drive.google.com/open?id=1BKn9UBpXo6nZRjTOk0ruhjT8gEBpmdoB)(Losses with Paired Ground-Truth ) 58 | - [pretrained_models/baseline_superslomo_adobe+youtube.pth](https://drive.google.com/open?id=10MP-NyDTOzQulA1UEGKPHMRL3dgBOey3)(Losses with Paired Ground-Truth ) 59 | 60 | Unsupervised Finetuned Weights 61 | - [pretrained_models/unsupervised_random2slowflow.pth](https://drive.google.com/open?id=1F4VDNzSpxGZ0yk-BWj4aCY5sOEknNpKW)(CC only) 62 | - [pretrained_models/unsupervised_adobe2slowflow.pth](https://drive.google.com/open?id=1fHAIxYfNHPYDXrpEWtKMlYhGDYIy91r3)(CC+PS) 63 | - [pretrained_models/unsupervised_adobe+youtube2slowflow.pth](https://drive.google.com/open?id=1NGWNSPk3Pea1sUe6abnTeq-B51WbAsgY)(CC+PS) 64 | - [pretrained_models/unsupervised_random2sintel.pth](https://drive.google.com/open?id=1G04-z62gJPEaXMMwi0LOVk0h-ov3EsVt)(CC only) 65 | - [pretrained_models/unsupervised_adobe2sintel.pth](https://drive.google.com/open?id=17fmGcon07AGGpjF85xDOtMBx6fQV2p4a)(CC+PS) 66 | - [pretrained_models/unsupervised_adobe+youtube2sintel.pth](https://drive.google.com/open?id=1WJg3V0nSshSYMzMEaqMJ_yvwUUm19E-m)(CC+PS) 67 | 68 | Fully Unsupervised Weights for UCF101 evaluation 69 | - [pretrained_models/fully_unsupervised_adobe30fps.pth](https://drive.google.com/open?id=1E0OhJzu0zxZunpFK3r7MM8cpS9XhX75-)(CC only) 70 | - [pretrained_models/fully_unsupervised_battlefield30fps.pth](https://drive.google.com/open?id=11bIZA2qMrU-CaZdMSWK50EmJHV9WxVSr)(CC only) 71 | 72 | ## Data Loaders 73 | 74 | We use `VideoInterp` and `CycleVideoInterp` (in [datasets](./datasets)) dataloaders for all frame sequences, i.e. [Adobe](http://www.cs.ubc.ca/labs/imager/tr/2017/DeepVideoDeblurring/), [YouTube](https://research.google.com/youtube8m/), [SlowFlow](http://www.cvlibs.net/projects/slow_flow/), [Sintel](http://www.cvlibs.net/projects/slow_flow/), and [UCF101](http://crcv.ucf.edu/data/UCF101.php).
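For reference, here is a minimal sketch of how a training loader is assembled; it mirrors `get_train_and_valid_data_loaders` in `train.py`, with `args` being the argument namespace produced by `parser.py` and only a reduced augmentation pipeline shown:

```python
import torch.utils.data

import datasets
from datasets import data_transforms
from parser import parser

args = parser.parse_args()  # expects the same CLI flags train.py uses (e.g. --train_file)

# Reduced augmentation pipeline; train.py composes a longer list of geometric and
# photometric transforms.
transform = data_transforms.Compose([
    data_transforms.NumpyToPILImage(),
    data_transforms.RandomCrop2D(crop_height=args.crop_size[0], crop_width=args.crop_size[1]),
    data_transforms.PILImageToNumpy(),
])

# Same keyword arguments train.py passes to the selected dataset class
# (VideoInterp or CycleVideoInterp).
train_dataset = datasets.VideoInterp(args=args, root=args.train_file,
                                     num_interp=args.num_interp, sample_rate=args.sample_rate,
                                     step_size=args.num_interp + 1, is_training=True,
                                     transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                           num_workers=args.workers, pin_memory=True,
                                           drop_last=True, shuffle=True)
```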
75 | 76 | We split the SlowFlow dataset into two disjoint sets: a low-FPS training subset (3.4K frames) and a high-FPS 77 | test subset (414 frames). 78 | We form the test set by selecting the first nine frames in each of the 46 clips, and the train set by temporally sub-sampling 79 | the remaining frames from 240 fps to 30 fps, as illustrated in the sketch below. 80 | During evaluation, our models take as input the first and ninth frame in 81 | each test clip and interpolate the seven intermediate frames. 82 | We follow a similar procedure for Sintel-1008fps, but interpolate 41 intermediate frames, i.e., a conversion of frame rate 83 | from 24 fps to 1008 fps. 84 | Note that, since SlowFlow and Sintel are of high resolution, we downsample all frames by a factor of 2 isotropically.<br>
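To make the frame indexing concrete, here is a small illustration (plain Python, not code from this repo) of the 240-to-30-fps sub-sampling and of how a nine-frame test clip is consumed with `--num_interp 7`:

```python
# 240 fps -> 30 fps keeps every 8th frame (240 / 30 = 8).
clip_240fps = list(range(48))            # 48 consecutive frame indices from one clip
train_30fps = clip_240fps[::8]           # [0, 8, 16, 24, 32, 40]

# A nine-frame test clip: the first and last frames are the model inputs,
# the seven frames in between are the ground-truth targets.
test_clip = list(range(9))
inputs = [test_clip[0], test_clip[-1]]   # [0, 8]
targets = test_clip[1:-1]                # [1, 2, 3, 4, 5, 6, 7]
```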
85 | All training and evaluations presented in the paper are done on the spatially downsampled sequences. 86 | 87 | For UCF101, we simply use the test set provided [here](https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip). 88 | 89 | ## Generating Interpolated Frames or Videos 90 | - `--write_video` and `--write_images`, if enabled, will create an interpolated video and interpolated frame sequences, respectively. 91 | ``` 92 | # Example creation of interpolated videos, where we interleave low FPS input frames with one or more interpolated intermediate frames. 93 | python3 eval.py --model CycleHJSuperSloMo --num_interp 7 --flow_scale 2 --val_file ${/path/to/input/sequences} \ 94 | --name ${video_name} --save ${/path/to/output/folder} --post_fix ${output_image_tag} \ 95 | --resume ${/path/to/pre-trained/model} --write_video 96 | ``` 97 | - If input sequences for interpolation do not contain ground-truth intermediate frames, add `--val_sample_rate 0` and `--val_step_size 1` to the example script above. 98 | - For a simple test on two input frames, set `--val_file` to the folder containing both frames, and set `--val_sample_rate 0`, `--val_step_size 1`. 99 | 100 | ## Images: Results and Comparisons 101 |

102 | Qualitative results and comparisons are shown on the [project page](https://nv-adlr.github.io/publication/2019-UnsupervisedVideoInterpolation) and in the paper.
115 | 116 | ## Inference for Unsupervised Models 117 | - UCF101: A total of 379 folders, each with three frames, with the middle frame being the ground-truth for a single-frame interpolation. 118 | ``` 119 | # Evaluation of model trained with CC alone on Adobe-30fps dataset 120 | # PSNR: 34.47, SSIM: 0.946, IE: 5.50 121 | python3 eval.py --model CycleHJSuperSloMo --num_interp 1 --flow_scale 1 --val_file /path/to/ucf/root \ 122 | --resume ./pretrained_models/fully_unsupervised_adobe30fps.pth 123 | ``` 124 | ``` 125 | # Evaluation of model trained with CC alone on Battlefield-30fps dataset 126 | # PSNR: 34.55, SSIM: 0.947, IE: 5.38 127 | python3 eval.py --model CycleHJSuperSloMo --num_interp 1 --flow_scale 1 --val_file /path/to/ucf/root \ 128 | --resume ./pretrained_models/fully_unsupervised_battlefield30fps.pth 129 | ``` 130 | - SlowFlow: A total of 46 folders, each with nine frames, with the intermediate seven frames being ground-truths for a 30->240FPS multi-frame interpolation. 131 | ``` 132 | # Evaluation of model trained with CC alone on SlowFlow-30fps train split 133 | # PSNR: 32.35, SSIM: 0.886, IE: 6.78 134 | python3 eval.py --model CycleHJSuperSloMo --num_interp 7 --flow_scale 2 --val_file /path/to/SlowFlow/val \ 135 | --resume ./pretrained_models/unsupervised_random2slowflow.pth 136 | ``` 137 | ``` 138 | # Evaluation of model finetuned with CC+PS losses on SlowFlow-30fps train split. 139 | # Model pre-trained with supervision on Adobe-240fps. 140 | # PSNR: 33.05, SSIM: 0.890, IE: 6.62 141 | python3 eval.py --model CycleHJSuperSloMo --num_interp 7 --flow_scale 2 --val_file /path/to/SlowFlow/val \ 142 | --resume ./pretrained_models/unsupervised_adobe2slowflow.pth 143 | ``` 144 | ``` 145 | # Evaluation of model finetuned with CC+PS losses on SlowFlow-30fps train split. 146 | # Model pre-trained with supervision on Adobe+YouTube-240fps. 147 | # PSNR: 33.20, SSIM: 0.891, IE: 6.56 148 | python3 eval.py --model CycleHJSuperSloMo --num_interp 7 --flow_scale 2 --val_file /path/to/SlowFlow/val \ 149 | --resume ./pretrained_models/unsupervised_adobe+youtube2slowflow.pth 150 | ``` 151 | - Sintel: A total of 13 folders, each with 43 frames, with the intermediate 41 frames being ground-truths for a 24->1008FPS multi-frame interpolation. 152 | We use the same commands as for SlowFlow, but set `--num_interp 41` and pass the corresponding `--resume *2sintel.pth` pre-trained models; this should reproduce the numbers reported in our paper. 156 | ## Inference for Supervised Baseline Models 157 | - UCF101: A total of 379 folders, each with three frames, with the middle frame being the ground-truth for a single-frame interpolation. 158 | ``` 159 | # Evaluation of model trained with paired ground-truth on Adobe-240fps dataset 160 | # PSNR: 34.63, SSIM: 0.946, IE: 5.48 161 | python3 eval.py --model HJSuperSloMo --num_interp 1 --flow_scale 1 --val_file /path/to/ucf/root \ 162 | --resume ./pretrained_models/baseline_superslomo_adobe.pth 163 | ``` 164 | - SlowFlow: A total of 46 folders, each with nine frames, with the intermediate seven frames being ground-truths for a 30->240FPS multi-frame interpolation.
165 | ``` 166 | # Evaluation of model trained with paird-GT on Adobe-240fps dataset 167 | # PSNR: 32.84, SSIM: 0.887, IE: 6.67 168 | python3 eval.py --model HJSuperSloMo --num_interp 7 --flow_scale 2 --val_file /path/to/SlowFlow/val \ 169 | --resume ./pretrained_models/baseline_superslomo_adobe.pth 170 | ``` 171 | ``` 172 | # Evaluation of model trained with paird-GT on Adobe+YouTube-240fps dataset 173 | # PSNR: 33.13, SSIM: 0.889, IE: 6.63 174 | python3 eval.py --model HJSuperSloMo --num_interp 7 --flow_scale 2 --val_file /path/to/SlowFlow/val \ 175 | --resume ./pretrained_models/baseline_superslomo_adobe+youtube.pth 176 | ``` 177 | - Sintel: We use commands similar to SlowFlow, but setting `--num_interp 41`. 178 | 179 | ## Training and Reproducing Our Results 180 | ``` 181 | # CC alone: Fully unsupervised training on SlowFlow and evaluation on SlowFlow 182 | # SlowFlow/val target PSNR: 32.35, SSIM: 0.886, IE: 6.78 183 | python3 -m torch.distributed.launch --nproc_per_node=16 train.py --model CycleHJSuperSloMo \ 184 | --flow_scale 2.0 --batch_size 2 --crop_size 384 384 --print_freq 1 --dataset CycleVideoInterp \ 185 | --step_size 1 --sample_rate 0 --num_interp 7 --val_num_interp 7 --skip_aug --save_freq 20 --start_epoch 0 \ 186 | --train_file /path/to/SlowFlow/train --val_file SlowFlow/val --name unsupervised_slowflow --save /path/to/output 187 | 188 | # --nproc_per_node=16, we use a total of 16 V100 GPUs over two nodes. 189 | ``` 190 | 191 | ``` 192 | # CC + PS: Unsupervised fine-tuning on SlowFlow with a baseline model pre-trained on Adobe+YouTube-240fps. 193 | # SlowFlow/val target PSNR: 33.20, SSIM: 0.891, IE: 6.56 194 | python3 -m torch.distributed.launch --nproc_per_node=16 train.py --model CycleHJSuperSloMo \ 195 | --flow_scale 2.0 --batch_size 2 --crop_size 384 384 --print_freq 1 --dataset CycleVideoInterp \ 196 | --step_size 1 --sample_rate 0 --num_interp 7 --val_num_interp 7 --skip_aug --save_freq 20 --start_epoch 0 \ 197 | --train_file /path/to/SlowFlow/train --val_file /path/to/SlowFlow/val --name finetune_slowflow \ 198 | --save /path/to/output --resume ./pretrained_models/baseline_superslomo_adobe+youtube.pth 199 | ``` 200 | 201 | ``` 202 | # Supervised baseline training on Adobe240-fps and evaluation on SlowFlow 203 | # SlowFlow/val target PSNR: 32.84, SSIM: 0.887, IE: 6.67 204 | python3 -m torch.distributed.launch --nproc_per_node=16 train.py --model HJSuperSloMo \ 205 | --flow_scale 2.0 --batch_size 2 --crop_size 352 352 --print_freq 1 --dataset VideoInterp \ 206 | --num_interp 7 --val_num_interp 7 --skip_aug --save_freq 20 --start_epoch 0 --stride 32 \ 207 | --train_file /path/to/Adobe-240fps/train --val_file /path/to/SlowFlow/val --name supervised_adobe \ 208 | --save /path/to/output 209 | ``` 210 | 211 | ## Reference 212 | 213 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper or code accordingly: 214 | 215 | ``` 216 | @InProceedings{Reda_2019_ICCV, 217 | author = {Fitsum A Reda and Deqing Sun and Aysegul Dundar and Mohammad Shoeybi and Guilin Liu and Kevin J Shih and Andrew Tao and Jan Kautz and Bryan Catanzaro}, 218 | title = {Unsupervised Video Interpolation Using Cycle Consistency}, 219 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 220 | month = {October}, 221 | year = {2019}, 222 | url={https://nv-adlr.github.io/publication/2019-UnsupervisedVideoInterpolation} 223 | } 224 | ``` 225 | We encourage people to contribute to our code base and provide suggestions, point any issues, 
or solution using merge request, and we hope this repo is useful. 226 | 227 | ## Acknowledgments 228 | Parts of the code were inspired by [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch), [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch), and [avinashpaliwal/Super-SloMo](https://github.com/avinashpaliwal/Super-SloMo). 229 | 230 | We would also like to thank Huaizu Jiang. 231 | 232 | ## Coding style 233 | * 4 spaces for indentation rather than tabs 234 | * 80 character line length 235 | * PEP8 formatting 236 | -------------------------------------------------------------------------------- /models/HJSuperSloMo.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | # 26 | # ***************************************************************************** 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import numpy as np 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | import torchvision 36 | 37 | from .model_utils import MyResample2D 38 | 39 | 40 | class HJSuperSloMo(nn.Module): 41 | def __init__(self, args, mean_pix=[109.93, 109.167, 101.455], in_channel=6): 42 | super(HJSuperSloMo, self).__init__() 43 | self.is_output_flow = False 44 | 45 | # --------------------- encoder -------------------- 46 | # conv1 47 | self.flow_pred_encoder_layer1 = self.make_flow_pred_encoder_layer(in_channel, 32, 7, 3) 48 | self.flow_pred_encoder_layer2 = self.make_flow_pred_encoder_layer(32, 64, 5, 2) 49 | self.flow_pred_encoder_layer3 = self.make_flow_pred_encoder_layer(64, 128) 50 | self.flow_pred_encoder_layer4 = self.make_flow_pred_encoder_layer(128, 256) 51 | self.flow_pred_encoder_layer5 = self.make_flow_pred_encoder_layer(256, 512) 52 | 53 | self.flow_pred_bottleneck = self.make_flow_pred_encoder_layer(512, 512) 54 | 55 | self.flow_pred_decoder_layer5 = self.make_flow_pred_decoder_layer(512, 512) 56 | self.flow_pred_decoder_layer4 = self.make_flow_pred_decoder_layer(1024, 256) 57 | self.flow_pred_decoder_layer3 = self.make_flow_pred_decoder_layer(512, 128) 58 | self.flow_pred_decoder_layer2 = self.make_flow_pred_decoder_layer(256, 64) 59 | self.flow_pred_decoder_layer1 = self.make_flow_pred_decoder_layer(128, 32) 60 | 61 | self.flow_pred_refine_layer = nn.Sequential( 62 | nn.Conv2d(64, 32, 3, padding=1), 63 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 64 | 65 | self.forward_flow_conv = nn.Conv2d(32, 2, 1) 66 | self.backward_flow_conv = nn.Conv2d(32, 2, 1) 67 | 68 | # -------------- flow interpolation encoder-decoder -------------- 69 | self.flow_interp_encoder_layer1 = self.make_flow_interp_encoder_layer(16, 32, 7, 3) 70 | self.flow_interp_encoder_layer2 = self.make_flow_interp_encoder_layer(32, 64, 5, 2) 71 | self.flow_interp_encoder_layer3 = self.make_flow_interp_encoder_layer(64, 128) 72 | self.flow_interp_encoder_layer4 = self.make_flow_interp_encoder_layer(128, 256) 73 | self.flow_interp_encoder_layer5 = self.make_flow_interp_encoder_layer(256, 512) 74 | 75 | self.flow_interp_bottleneck = self.make_flow_interp_encoder_layer(512, 512) 76 | 77 | self.flow_interp_decoder_layer5 = self.make_flow_interp_decoder_layer(1024, 512) 78 | self.flow_interp_decoder_layer4 = self.make_flow_interp_decoder_layer(1024, 256) 79 | self.flow_interp_decoder_layer3 = self.make_flow_interp_decoder_layer(512, 128) 80 | self.flow_interp_decoder_layer2 = self.make_flow_interp_decoder_layer(256, 64) 81 | self.flow_interp_decoder_layer1 = self.make_flow_interp_decoder_layer(128, 32) 82 | 83 | self.flow_interp_refine_layer = nn.Sequential( 84 | nn.Conv2d(64, 32, 3, padding=1), 85 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 86 | 87 | self.flow_interp_forward_out_layer = nn.Conv2d(32, 2, 1) 88 | self.flow_interp_backward_out_layer = nn.Conv2d(32, 2, 1) 89 | 90 | # visibility 91 | self.flow_interp_vis_layer = nn.Conv2d(32, 1, 1) 92 | 93 | self.resample2d_train = MyResample2D(args.crop_size[1], args.crop_size[0]) 94 | 95 | mean_pix = torch.from_numpy(np.array(mean_pix)).float() 96 | mean_pix = mean_pix.view(1, 3, 1, 1) 97 | self.register_buffer('mean_pix', mean_pix) 98 | 99 | self.args = args 100 | self.scale = args.flow_scale 101 | 102 | self.L1_loss = nn.L1Loss() 103 | self.L2_loss = 
nn.MSELoss() 104 | self.ignore_keys = ['vgg', 'grid_w', 'grid_h', 'tlinespace', 'resample2d_train', 'resample2d'] 105 | self.register_buffer('tlinespace', torch.linspace(0, 1, 2 + args.num_interp).float()) 106 | 107 | vgg16 = torchvision.models.vgg16(pretrained=True) 108 | self.vgg16_features = nn.Sequential(*list(vgg16.children())[0][:22]) 109 | for param in self.vgg16_features.parameters(): 110 | param.requires_grad = False 111 | 112 | # loss weights 113 | self.pix_alpha = 0.8 114 | self.warp_alpha = 0.4 115 | self.vgg16_alpha = 0.005 116 | self.smooth_alpha = 1. 117 | 118 | def make_flow_pred_encoder_layer(self, in_chn, out_chn, kernel_size=3, padding=1): 119 | layer = nn.Sequential( 120 | nn.Conv2d(in_chn, out_chn, kernel_size, padding=padding), 121 | nn.LeakyReLU(inplace=True, negative_slope=0.1), 122 | nn.Conv2d(out_chn, out_chn, kernel_size, padding=padding), 123 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 124 | return layer 125 | 126 | def make_flow_pred_decoder_layer(self, in_chn, out_chn): 127 | layer = nn.Sequential( 128 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), 129 | nn.Conv2d(in_chn, out_chn, 3, padding=1), 130 | nn.LeakyReLU(inplace=True, negative_slope=0.1), 131 | nn.Conv2d(out_chn, out_chn, 3, padding=1), 132 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 133 | return layer 134 | 135 | def make_flow_interp_encoder_layer(self, in_chn, out_chn, kernel_size=3, padding=1): 136 | layer = nn.Sequential( 137 | nn.Conv2d(in_chn, out_chn, kernel_size, padding=padding), 138 | nn.LeakyReLU(inplace=True, negative_slope=0.1), 139 | nn.Conv2d(out_chn, out_chn, kernel_size, padding=padding), 140 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 141 | return layer 142 | 143 | def make_flow_interp_decoder_layer(self, in_chn, out_chn): 144 | layer = nn.Sequential( 145 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), 146 | nn.Conv2d(in_chn, out_chn, 3, padding=1), 147 | nn.LeakyReLU(inplace=True, negative_slope=0.1), 148 | nn.Conv2d(out_chn, out_chn, 3, padding=1), 149 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 150 | return layer 151 | 152 | def make_flow_interpolation(self, in_data, flow_pred_bottleneck_out): 153 | flow_interp_encoder_out1 = self.flow_interp_encoder_layer1(in_data) 154 | flow_interp_encoder_out1_pool = F.avg_pool2d(flow_interp_encoder_out1, 2, stride=2) 155 | 156 | flow_interp_encoder_out2 = self.flow_interp_encoder_layer2(flow_interp_encoder_out1_pool) 157 | flow_interp_encoder_out2_pool = F.avg_pool2d(flow_interp_encoder_out2, 2, stride=2) 158 | 159 | flow_interp_encoder_out3 = self.flow_interp_encoder_layer3(flow_interp_encoder_out2_pool) 160 | flow_interp_encoder_out3_pool = F.avg_pool2d(flow_interp_encoder_out3, 2, stride=2) 161 | 162 | flow_interp_encoder_out4 = self.flow_interp_encoder_layer4(flow_interp_encoder_out3_pool) 163 | flow_interp_encoder_out4_pool = F.avg_pool2d(flow_interp_encoder_out4, 2, stride=2) 164 | 165 | flow_interp_encoder_out5 = self.flow_interp_encoder_layer5(flow_interp_encoder_out4_pool) 166 | flow_interp_encoder_out5_pool = F.avg_pool2d(flow_interp_encoder_out5, 2, stride=2) 167 | 168 | flow_interp_bottleneck_out = self.flow_interp_bottleneck(flow_interp_encoder_out5_pool) 169 | flow_interp_bottleneck_out = torch.cat((flow_pred_bottleneck_out, 170 | flow_interp_bottleneck_out), dim=1) 171 | 172 | flow_interp_decoder_out5 = self.flow_interp_decoder_layer5(flow_interp_bottleneck_out) 173 | flow_interp_decoder_out5 = torch.cat((flow_interp_encoder_out5, flow_interp_decoder_out5), 
dim=1) 174 | 175 | flow_interp_decoder_out4 = self.flow_interp_decoder_layer4(flow_interp_decoder_out5) 176 | flow_interp_decoder_out4 = torch.cat((flow_interp_encoder_out4, flow_interp_decoder_out4), dim=1) 177 | 178 | flow_interp_decoder_out3 = self.flow_interp_decoder_layer3(flow_interp_decoder_out4) 179 | flow_interp_decoder_out3 = torch.cat((flow_interp_encoder_out3, flow_interp_decoder_out3), dim=1) 180 | 181 | flow_interp_decoder_out2 = self.flow_interp_decoder_layer2(flow_interp_decoder_out3) 182 | flow_interp_decoder_out2 = torch.cat((flow_interp_encoder_out2, flow_interp_decoder_out2), dim=1) 183 | 184 | flow_interp_decoder_out1 = self.flow_interp_decoder_layer1(flow_interp_decoder_out2) 185 | flow_interp_decoder_out1 = torch.cat((flow_interp_encoder_out1, flow_interp_decoder_out1), dim=1) 186 | 187 | flow_interp_motion_rep = self.flow_interp_refine_layer(flow_interp_decoder_out1) 188 | 189 | flow_interp_forward_flow = self.flow_interp_forward_out_layer(flow_interp_motion_rep) 190 | flow_interp_backward_flow = self.flow_interp_backward_out_layer(flow_interp_motion_rep) 191 | 192 | flow_interp_vis_map = self.flow_interp_vis_layer(flow_interp_motion_rep) 193 | flow_interp_vis_map = torch.sigmoid(flow_interp_vis_map) 194 | 195 | return flow_interp_forward_flow, flow_interp_backward_flow, flow_interp_vis_map 196 | 197 | def make_flow_prediction(self, x): 198 | 199 | encoder_out1 = self.flow_pred_encoder_layer1(x) 200 | encoder_out1_pool = F.avg_pool2d(encoder_out1, 2, stride=2) 201 | 202 | encoder_out2 = self.flow_pred_encoder_layer2(encoder_out1_pool) 203 | encoder_out2_pool = F.avg_pool2d(encoder_out2, 2, stride=2) 204 | 205 | encoder_out3 = self.flow_pred_encoder_layer3(encoder_out2_pool) 206 | encoder_out3_pool = F.avg_pool2d(encoder_out3, 2, stride=2) 207 | 208 | encoder_out4 = self.flow_pred_encoder_layer4(encoder_out3_pool) 209 | encoder_out4_pool = F.avg_pool2d(encoder_out4, 2, stride=2) 210 | 211 | encoder_out5 = self.flow_pred_encoder_layer5(encoder_out4_pool) 212 | encoder_out5_pool = F.avg_pool2d(encoder_out5, 2, stride=2) 213 | 214 | bottleneck_out = self.flow_pred_bottleneck(encoder_out5_pool) 215 | 216 | decoder_out5 = self.flow_pred_decoder_layer5(bottleneck_out) 217 | decoder_out5 = torch.cat((encoder_out5, decoder_out5), dim=1) 218 | 219 | decoder_out4 = self.flow_pred_decoder_layer4(decoder_out5) 220 | decoder_out4 = torch.cat((encoder_out4, decoder_out4), dim=1) 221 | 222 | decoder_out3 = self.flow_pred_decoder_layer3(decoder_out4) 223 | decoder_out3 = torch.cat((encoder_out3, decoder_out3), dim=1) 224 | 225 | decoder_out2 = self.flow_pred_decoder_layer2(decoder_out3) 226 | decoder_out2 = torch.cat((encoder_out2, decoder_out2), dim=1) 227 | 228 | decoder_out1 = self.flow_pred_decoder_layer1(decoder_out2) 229 | decoder_out1 = torch.cat((encoder_out1, decoder_out1), dim=1) 230 | 231 | motion_rep = self.flow_pred_refine_layer(decoder_out1) 232 | 233 | uvf = self.forward_flow_conv(motion_rep) 234 | uvb = self.backward_flow_conv(motion_rep) 235 | 236 | return uvf, bottleneck_out, uvb 237 | 238 | def forward(self, inputs, target_index): 239 | if 'image' in inputs: 240 | inputs = inputs['image'] 241 | 242 | if self.training: 243 | self.resample2d = self.resample2d_train 244 | else: 245 | _, _, height, width = inputs[0].shape 246 | self.resample2d = MyResample2D(width, height).cuda() 247 | 248 | # Normalize inputs 249 | im1, im_target, im2 = [(im - self.mean_pix) for im in inputs] 250 | 251 | # Estimate bi-directional optical flows between input low FPS frame pairs 252 | 
# Downsample images for robust intermediate flow estimation 253 | ds_im1 = F.interpolate(im1, scale_factor=1./self.scale, mode='bilinear', align_corners=False) 254 | ds_im2 = F.interpolate(im2, scale_factor=1./self.scale, mode='bilinear', align_corners=False) 255 | 256 | uvf, bottleneck_out, uvb = self.make_flow_prediction(torch.cat((ds_im1, ds_im2), dim=1)) 257 | 258 | uvf = self.scale * F.interpolate(uvf, scale_factor=self.scale, mode='bilinear', align_corners=False) 259 | uvb = self.scale * F.interpolate(uvb, scale_factor=self.scale, mode='bilinear', align_corners=False) 260 | bottleneck_out = F.interpolate(bottleneck_out, scale_factor=self.scale, mode='bilinear', align_corners=False) 261 | 262 | t = self.tlinespace[target_index] 263 | t = t.reshape(t.shape[0], 1, 1, 1) 264 | 265 | uvb_t_raw = - (1 - t) * t * uvf + t * t * uvb 266 | uvf_t_raw = (1 - t) * (1 - t) * uvf - (1 - t) * t * uvb 267 | 268 | im1w_raw = self.resample2d(im1, uvb_t_raw) # im1w_raw 269 | im2w_raw = self.resample2d(im2, uvf_t_raw) # im2w_raw 270 | 271 | # Perform intermediate bi-directional flow refinement 272 | uv_t_data = torch.cat((im1, im2, im1w_raw, uvb_t_raw, im2w_raw, uvf_t_raw), dim=1) 273 | uvf_t, uvb_t, t_vis_map = self.make_flow_interpolation(uv_t_data, bottleneck_out) 274 | 275 | uvb_t = uvb_t_raw + uvb_t # uvb_t 276 | uvf_t = uvf_t_raw + uvf_t # uvf_t 277 | 278 | im1w = self.resample2d(im1, uvb_t) # im1w 279 | im2w = self.resample2d(im2, uvf_t) # im2w 280 | 281 | # Compute final intermediate frame via weighted blending 282 | alpha1 = (1 - t) * t_vis_map 283 | alpha2 = t * (1 - t_vis_map) 284 | denorm = alpha1 + alpha2 + 1e-10 285 | im_t_out = (alpha1 * im1w + alpha2 * im2w) / denorm 286 | 287 | # Calculate training loss 288 | losses = {} 289 | losses['pix_loss'] = self.L1_loss(im_t_out, im_target) 290 | 291 | im_t_out_features = self.vgg16_features(im_t_out/255.) 292 | im_target_features = self.vgg16_features(im_target/255.) 293 | losses['vgg16_loss'] = self.L2_loss(im_t_out_features, im_target_features) 294 | 295 | losses['warp_loss'] = self.L1_loss(im1w_raw, im_target) + self.L1_loss(im2w_raw, im_target) + \ 296 | self.L1_loss(self.resample2d(im1, uvb.contiguous()), im2) + \ 297 | self.L1_loss(self.resample2d(im2, uvf.contiguous()), im1) 298 | 299 | smooth_bwd = self.L1_loss(uvb[:, :, :, :-1], uvb[:, :, :, 1:]) + \ 300 | self.L1_loss(uvb[:, :, :-1, :], uvb[:, :, 1:, :]) 301 | smooth_fwd = self.L1_loss(uvf[:, :, :, :-1], uvf[:, :, :, 1:]) + \ 302 | self.L1_loss(uvf[:, :, :-1, :], uvf[:, :, 1:, :]) 303 | 304 | losses['smooth_loss'] = smooth_bwd + smooth_fwd 305 | 306 | # Coefficients for total loss determined empirically using a validation set 307 | losses['tot'] = self.pix_alpha * losses['pix_loss'] + self.warp_alpha * losses['warp_loss'] \ 308 | + self.vgg16_alpha * losses['vgg16_loss'] + self.smooth_alpha * losses['smooth_loss'] 309 | 310 | # Converts back to (0, 255) range 311 | im_t_out = im_t_out + self.mean_pix 312 | im_target = im_target + self.mean_pix 313 | 314 | return losses, im_t_out, im_target 315 | -------------------------------------------------------------------------------- /datasets/data_transforms.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | from __future__ import division 28 | import random 29 | from PIL import Image 30 | import numpy as np 31 | from torchvision.transforms import functional as transf 32 | 33 | """ 34 | Compose for Multiple Arguments 35 | """ 36 | 37 | 38 | class Compose(object): 39 | """Custom class to serialise transformations that 40 | accept multiple input arguments 41 | 42 | Args: 43 | transforms (list of ``Transform`` objects): list of custom transforms to compose 44 | 45 | Example: 46 | composed_transf = data_transforms.Compose( 47 | [NumpyToPILImage(), 48 | RandomScaledCrop2D(crop_height=384, crop_width=384, min_crop_ratio=0.8), 49 | PILImageToNumpy(), 50 | RandomReverseSequence(), 51 | RandomBrightness(brightness_factor=0.1) 52 | ]) 53 | """ 54 | 55 | def __init__(self, transforms): 56 | self.transforms = transforms 57 | 58 | def __call__(self, inputs, targets): 59 | for transform in self.transforms: 60 | inputs, targets = transform(inputs, targets) 61 | return inputs, targets 62 | 63 | 64 | """ 65 | Image Type Conversion 66 | """ 67 | 68 | 69 | class NumpyToPILImage(object): 70 | """Convert numpy array to an instance of PIL Image, so we can use 71 | geometric transformations already available in torchvision.transforms.functional.*. 72 | """ 73 | 74 | def __call__(self, inputs, targets): 75 | inputs = [Image.fromarray(np.clip(im, 0, 255)) for im in inputs] 76 | targets = [Image.fromarray(np.clip(im, 0, 255)) for im in targets] 77 | return inputs, targets 78 | 79 | 80 | class PILImageToNumpy(object): 81 | """Convert PIL Image to a numpy array at the end of geometric transformations. 82 | Note. All photometric transformations currently work on numpy arrays, because for some 83 | transformations, there is an implementation mis-match between torchvision and the ones defined 84 | in flownet2 (Caffe: https://github.com/lmb-freiburg/flownet2), which they are derived/inspired from. 
85 | """ 86 | 87 | def __call__(self, inputs, targets): 88 | inputs = [np.array(im) for im in inputs] 89 | targets = [np.array(im) for im in targets] 90 | return inputs, targets 91 | 92 | 93 | """ 94 | Geometric Augmentation 95 | """ 96 | 97 | 98 | class RandomRotate2D(object): 99 | """Apply random 2D in-plane rotation of on input and target image sequences. 100 | For video interpolation or optical flow studies, we also add a small 101 | offset rotation to each image in the sequence ranging from [-delta, delta] degrees 102 | in a linear fashion, such that networks can learn to recover the added fake rotation. 103 | """ 104 | def __init__(self, base_angle=20, delta_angle=0, resample=Image.BILINEAR): 105 | self.base_angle = base_angle 106 | self.delta_angle = delta_angle 107 | self.resample = resample 108 | 109 | def __call__(self, inputs, targets): 110 | base = random.uniform(-self.base_angle, self.base_angle) 111 | delta = random.uniform(-self.delta_angle, self.delta_angle) 112 | resample = self.resample 113 | 114 | inputs[0] = transf.rotate(inputs[0], angle=(base - delta / 2.), resample=resample) 115 | inputs[-1] = transf.rotate(inputs[1], angle=(base + delta / 2.), resample=resample) 116 | 117 | # Apply linearly varying offset to targets 118 | # calculate offset ~ (-delta/2., delta/2.) 119 | tlinspace = np.linspace(-1, 1, len(targets) + 2) 120 | for i, image in enumerate(targets): 121 | offset = tlinspace[i + 1] * delta / 2. 122 | targets[i] = transf.rotate(image, angle=(base + offset), resample=resample) 123 | 124 | return inputs, targets 125 | 126 | 127 | class RandomTranslate2D(object): 128 | """Apply random 2D translation on input and target image sequences. 129 | For video interpolation or optical flow studies, we also add a small 130 | offset translation to each image in the sequence ranging from [-delta, delta] pixel displacements 131 | in a linear fashion, such that networks can learn to recover the added fake translation. 132 | """ 133 | def __init__(self, max_displ_factor=0.05, resample=Image.NEAREST): 134 | self.max_displ_factor = max_displ_factor 135 | self.resample = resample 136 | 137 | def __call__(self, inputs, targets): 138 | # h, w, _ = inputs[0].shape 139 | w, h = inputs[0].size 140 | max_displ_factor = self.max_displ_factor 141 | resample = self.resample 142 | 143 | # Sample a displacement in [-max_displ, max_displ] for both height and width 144 | max_width_displ = int(w * max_displ_factor) 145 | wd = random.randint(-max_width_displ, max_width_displ) 146 | 147 | max_height_displ = int(h * max_displ_factor) 148 | hd = random.randint(-max_height_displ, max_height_displ) 149 | 150 | inputs[0] = transf.affine(inputs[0], angle=0, translate=(wd, hd), scale=1, shear=0, resample=resample) 151 | inputs[-1] = transf.affine(inputs[-1], angle=0, translate=(-wd, -hd), scale=1, shear=0, resample=resample) 152 | 153 | # Apply linearly varying offset to targets 154 | # calculate offset ~ (-{w|h}_delta, {w|h}_delta}) 155 | tlinspace = -1 * np.linspace(-1, 1, len(targets) + 2) 156 | for i, image in enumerate(targets): 157 | wo, ho = tlinspace[i + 1] * wd, tlinspace[i + 1] * hd 158 | targets[i] = transf.affine(image, angle=0, translate=(wo, ho), scale=1, shear=0, resample=resample) 159 | 160 | return inputs, targets 161 | 162 | 163 | class RandomCrop2D(object): 164 | """A simple random 3D crop with a provided crop_size. 
165 | """ 166 | def __init__(self, crop_height, crop_width): 167 | self.crop_height = crop_height 168 | self.crop_width = crop_width 169 | 170 | def __call__(self, inputs, targets): 171 | width, height = inputs[0].size 172 | crop_width, crop_height = self.crop_width, self.crop_height 173 | 174 | # sample crop indices 175 | left = random.randint(0, width - crop_width) 176 | top = random.randint(0, height - crop_height) 177 | 178 | for i, image in enumerate(inputs): 179 | inputs[i] = transf.crop(image, top, left, crop_height, crop_width) 180 | 181 | for i, image in enumerate(targets): 182 | targets[i] = transf.crop(image, top, left, crop_height, crop_width) 183 | 184 | return inputs, targets 185 | 186 | 187 | class RandomScaledCrop2D(object): 188 | """Apply random 2D crop followed by a scale operation. 189 | Note to simulate a simple crop, set 190 | ``min_crop_ratio=min(crop_height,crop_width)/min(height, width)``. 191 | We basically, first, crop the original image with a size larger or smaller than 192 | the desired crop size. We then scale the images to the desired crop_size. 193 | So, in a way, this transformation encapsulates two augmentations: scale + crop. 194 | """ 195 | 196 | def __init__(self, crop_height, crop_width, min_crop_ratio=0.6, resample=Image.BILINEAR): 197 | # Aspect ratio inherited from (crop_height, crop_width) 198 | self.crop_aspect = crop_height / crop_width 199 | self.crop_shape = (crop_height, crop_width) 200 | self.min_crop_ratio = min_crop_ratio 201 | self.resample = resample 202 | 203 | def __call__(self, inputs, targets): 204 | # height, width, _ = inputs[0].shape 205 | width, height = inputs[0].size 206 | crop_aspect = self.crop_aspect 207 | crop_shape = self.crop_shape 208 | resample = self.resample 209 | min_crop_ratio = self.min_crop_ratio 210 | 211 | source_aspect = height / width 212 | 213 | # sample a crop factor in [min_crop_ratio, 1.) 
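        # A crop_ratio below 1 selects a window smaller than the source frame; the
        # transf.resized_crop calls below rescale that window back to crop_shape,
        # so this transform combines a random zoom with a random crop.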
214 | crop_ratio = random.uniform(min_crop_ratio, 1.0) 215 | 216 | # Preserve aspect ratio provided by (crop_height, crop_width) 217 | # Calculate crop height and with, apply crop_ratio along the min(height,width)'s axis 218 | if crop_aspect < source_aspect: 219 | cwidth = int(width * crop_ratio) 220 | cheight = int(cwidth * crop_aspect) 221 | else: 222 | cheight = int(height * crop_ratio) 223 | cwidth = int(cheight / crop_aspect) 224 | 225 | # Avoid bilinear re-sampling crop_size == full_size 226 | if cheight == cwidth and cwidth == width: 227 | return inputs, targets 228 | 229 | # sample crop indices 230 | left = random.randint(0, width - cwidth) 231 | top = random.randint(0, height - cheight) 232 | 233 | for i, image in enumerate(inputs): 234 | inputs[i] = transf.resized_crop(inputs[i], top, left, cheight, cwidth, crop_shape, interpolation=resample) 235 | for i, image in enumerate(targets): 236 | targets[i] = transf.resized_crop(targets[i], top, left, cheight, cwidth, crop_shape, interpolation=resample) 237 | 238 | return inputs, targets 239 | 240 | 241 | class RandomHorizontalFlip(object): 242 | """Apply a random horizontal flip.""" 243 | 244 | def __init__(self, prob=0.5): 245 | self.prob = prob 246 | 247 | def __call__(self, inputs, targets): 248 | # 249 | if random.random() < self.prob: 250 | return inputs, targets 251 | 252 | # Apply a horizontal flip 253 | for i, image in enumerate(inputs): 254 | inputs[i] = transf.hflip(image) 255 | for i, image in enumerate(targets): 256 | targets[i] = transf.hflip(image) 257 | 258 | return inputs, targets 259 | 260 | 261 | class RandomVerticalFlip(object): 262 | """Apply a random vertical flip.""" 263 | 264 | def __init__(self, prob=0.5): 265 | self.prob = prob 266 | 267 | def __call__(self, inputs, targets): 268 | # 269 | if random.random() < self.prob: 270 | return inputs, targets 271 | 272 | # Apply a vertical flip 273 | for i, image in enumerate(inputs): 274 | inputs[i] = transf.vflip(image) 275 | for i, image in enumerate(targets): 276 | targets[i] = transf.vflip(image) 277 | 278 | return inputs, targets 279 | 280 | 281 | class RandomReverseSequence(object): 282 | """Randomly reverse the order of inputs, and targets""" 283 | 284 | def __init__(self, prob=0.5): 285 | self.prob = prob 286 | 287 | def __call__(self, inputs, targets): 288 | if random.random() < self.prob: 289 | return inputs, targets 290 | 291 | # Reverse sequence 292 | inputs = inputs[::-1] 293 | targets = targets[::-1] 294 | 295 | return inputs, targets 296 | 297 | 298 | """ 299 | Photometric Augmentation 300 | """ 301 | 302 | 303 | class RandomGamma(object): 304 | """Apply a gamma transformation, with gamma factor of (gamma_low, anf gamma_high)""" 305 | 306 | def __init__(self, gamma_low, gamma_high): 307 | self.gamma_low = gamma_low 308 | self.gamma_high = gamma_high 309 | 310 | def __call__(self, inputs, targets): 311 | gamma = random.uniform(self.gamma_low, self.gamma_high) 312 | 313 | if gamma == 1.0: 314 | return inputs, targets 315 | gamma_inv = 1. / gamma 316 | 317 | # Apply a gamma 318 | for i, image in enumerate(inputs): 319 | image = np.power(image / 255.0, gamma_inv) * 255.0 320 | inputs[i] = np.clip(image, 0., 255.) 321 | 322 | for i, image in enumerate(targets): 323 | image = np.power(image / 255.0, gamma_inv) * 255.0 324 | targets[i] = np.clip(image, 0., 255.) 325 | 326 | return inputs, targets 327 | 328 | 329 | class RandomBrightness(object): 330 | """Apply a random brightness to each channel in the image. 
331 | An implementation that is quite distinct from torchvision. 332 | """ 333 | 334 | def __init__(self, brightness_factor=0.1): 335 | self.brightness_factor = brightness_factor 336 | 337 | def __call__(self, inputs, targets): 338 | brighness_factor = [1 + random.uniform(-self.brightness_factor, self.brightness_factor) for _ in range(3)] 339 | brighness_factor = np.array(brighness_factor) 340 | 341 | # Apply a brightness 342 | for i, image in enumerate(inputs): 343 | image = image * brighness_factor 344 | inputs[i] = np.clip(image, 0., 255.) 345 | 346 | for i, image in enumerate(targets): 347 | image = image * brighness_factor 348 | targets[i] = np.clip(image, 0., 255.) 349 | 350 | return inputs, targets 351 | 352 | 353 | class RandomColorOrder(object): 354 | """Randomly re-order the channels of images. 355 | """ 356 | 357 | def __init__(self, prob=0.5): 358 | self.prob = prob 359 | 360 | def __call__(self, inputs, targets): 361 | if random.random() < self.prob: 362 | return inputs, targets 363 | 364 | new_order = np.random.permutation(3) 365 | 366 | # Apply a brightness 367 | for i, image in enumerate(inputs): 368 | inputs[i] = image[..., new_order] 369 | for i, image in enumerate(targets): 370 | targets[i] = image[..., new_order] 371 | 372 | return inputs, targets 373 | 374 | 375 | class RandomContrast(object): 376 | """Apply a random contrast in the range (contrast_low, contrast_high) to all channels. 377 | An implementation that is quite distinct from torchvision. 378 | """ 379 | 380 | def __init__(self, contrast_low, contrast_high): 381 | self.contrast_low = contrast_low 382 | self.contrast_high = contrast_high 383 | 384 | def __call__(self, inputs, targets): 385 | contrast = 1 + random.uniform(self.contrast_low, self.contrast_high) 386 | 387 | # Apply a contrast 388 | for i, image in enumerate(inputs): 389 | gray_img = image[..., 0] * 0.299 + image[..., 1] * 0.587 + image[..., 2] * 0.114 390 | tmp_img = np.ones_like(image) * gray_img.mean() 391 | image = image * contrast + (1 - contrast) * tmp_img 392 | inputs[i] = np.clip(image, 0, 255) 393 | 394 | for i, image in enumerate(targets): 395 | gray_img = image[..., 0] * 0.299 + image[..., 1] * 0.587 + image[..., 2] * 0.114 396 | tmp_img = np.ones_like(image) * gray_img.mean() 397 | image = image * contrast + (1 - contrast) * tmp_img 398 | targets[i] = np.clip(image, 0, 255) 399 | 400 | return inputs, targets 401 | 402 | 403 | class RandomSaturation(object): 404 | """Apply a random saturation in the range (saturation_low, saturation_high) to all channels. 405 | An implementation that is quite distinct from torchvision. 
406 | """ 407 | 408 | def __init__(self, saturation_low, saturation_high): 409 | self.saturation_low = saturation_low 410 | self.saturation_high = saturation_high 411 | 412 | def __call__(self, inputs, targets): 413 | saturation = 1 + random.uniform(self.saturation_low, self.saturation_high) 414 | if saturation == 1.0: 415 | return inputs, targets 416 | 417 | # Apply a saturation 418 | for i, image in enumerate(inputs): 419 | gray_img = image[..., 0] * 0.299 + image[..., 1] * image[..., 2] * 0.114 420 | tmp_img = np.stack((gray_img, gray_img, gray_img), axis=2) 421 | image = image * saturation + (1 - saturation) * tmp_img 422 | inputs[i] = np.clip(image, 0, 255) 423 | 424 | for i, image in enumerate(targets): 425 | gray_img = image[..., 0] * 0.299 + image[..., 1] * image[..., 2] * 0.114 426 | tmp_img = np.stack((gray_img, gray_img, gray_img), axis=2) 427 | image = image * saturation + (1 - saturation) * tmp_img 428 | targets[i] = np.clip(image, 0, 255) 429 | 430 | return inputs, targets 431 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ***************************************************************************** 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # * Neither the name of the NVIDIA CORPORATION nor the 13 | # names of its contributors may be used to endorse or promote products 14 | # derived from this software without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 20 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | # 27 | # ***************************************************************************** 28 | import os 29 | import random 30 | import math 31 | import numpy as np 32 | 33 | import torch.backends.cudnn 34 | import torch.nn.parallel 35 | import torch.optim 36 | import torch.utils.data 37 | from tensorboardX import SummaryWriter 38 | from tqdm import tqdm 39 | tqdm.monitor_interval = 0 40 | 41 | import datasets 42 | import models 43 | import utils 44 | from parser import parser 45 | from eval import evaluate 46 | from datasets import data_transforms 47 | 48 | # Import apex's distributed module. 
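# apex provides DistributedDataParallel for multi-process data parallelism and amp,
# which is used later in this script (amp.initialize / amp.scale_loss) for optional
# mixed-precision (fp16) training.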
49 | try: 50 | from apex.parallel import DistributedDataParallel 51 | except ImportError: 52 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 53 | from apex import amp 54 | 55 | """ 56 | Reda, Fitsum A., et al. "Unsupervised Video Interpolation Using Cycle Consistency." 57 | arXiv preprint arXiv:1906.05928 (2019). 58 | 59 | Jiang, Huaizu, et al. "Super slomo: High quality estimation of multiple 60 | intermediate frames for video interpolation." arXiv pre-print arXiv:1712.00080 (2017). 61 | """ 62 | 63 | 64 | def parse_and_set_args(block): 65 | args = parser.parse_args() 66 | 67 | torch.backends.cudnn.benchmark = True 68 | block.log("Enabling torch.backends.cudnn.benchmark") 69 | 70 | if args.resume != '': 71 | block.log("Setting initial eval to true since checkpoint is provided") 72 | args.initial_eval = True 73 | 74 | args.rank = int(os.getenv('RANK', 0)) 75 | args.world_size = int(os.getenv("WORLD_SIZE", 1)) 76 | 77 | if args.local_rank: 78 | args.rank = args.local_rank 79 | if args.local_rank is not None and args.local_rank != 0: 80 | utils.block_print() 81 | 82 | block.log("Creating save directory: {}".format( 83 | os.path.join(args.save, args.name))) 84 | args.save_root = os.path.join(args.save, args.name) 85 | os.makedirs(args.save_root, exist_ok=True) 86 | assert os.path.exists(args.save_root) 87 | 88 | # temporary directory for torch pre-trained models 89 | os.makedirs(args.torch_home, exist_ok=True) 90 | os.environ['TORCH_HOME'] = args.torch_home 91 | 92 | defaults, input_arguments = {}, {} 93 | for key in vars(args): 94 | defaults[key] = parser.get_default(key) 95 | 96 | for argument, value in sorted(vars(args).items()): 97 | if value != defaults[argument] and argument in vars(parser.parse_args()).keys(): 98 | input_arguments['--' + str(argument)] = value 99 | block.log('{}: {}'.format(argument, value)) 100 | 101 | if args.rank == 0: 102 | utils.copy_arguments(input_arguments, os.path.realpath(__file__), 103 | args.save_root) 104 | 105 | args.network_class = utils.module_to_dict(models)[args.model] 106 | args.optimizer_class = utils.module_to_dict(torch.optim)[args.optimizer] 107 | args.dataset_class = utils.module_to_dict(datasets)[args.dataset] 108 | 109 | return args 110 | 111 | 112 | def initialize_distributed(args): 113 | # Manually set the device ids. 
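    # Pin each process to GPU (rank % number of visible GPUs), so ranks on the same
    # node use distinct devices before the process group is created.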
114 | torch.cuda.set_device(args.rank % torch.cuda.device_count()) 115 | 116 | # Call the init process 117 | if args.world_size > 1: 118 | init_method = 'env://' 119 | torch.distributed.init_process_group( 120 | backend=args.distributed_backend, 121 | world_size=args.world_size, rank=args.rank, 122 | init_method=init_method) 123 | 124 | 125 | def set_random_seed(seed): 126 | random.seed(seed) 127 | np.random.seed(seed) 128 | torch.manual_seed(seed) 129 | torch.cuda.manual_seed(seed) 130 | 131 | 132 | def get_train_and_valid_data_loaders(block, args): 133 | transform = data_transforms.Compose([ 134 | # geometric augmentation 135 | data_transforms.NumpyToPILImage(), 136 | data_transforms.RandomTranslate2D(max_displ_factor=0.05), 137 | data_transforms.RandomRotate2D(base_angle=17, delta_angle=5), 138 | data_transforms.RandomScaledCrop2D(crop_height=args.crop_size[0], 139 | crop_width=args.crop_size[1], min_crop_ratio=0.8), 140 | data_transforms.RandomVerticalFlip(prob=0.5), 141 | data_transforms.RandomHorizontalFlip(prob=0.5), 142 | data_transforms.PILImageToNumpy(), 143 | # photometric augmentation 144 | data_transforms.RandomGamma(gamma_low=0.9, gamma_high=1.1), 145 | data_transforms.RandomBrightness(brightness_factor=0.1), 146 | data_transforms.RandomColorOrder(prob=0.5), 147 | data_transforms.RandomContrast(contrast_low=-0.1, contrast_high=0.1), 148 | data_transforms.RandomSaturation(saturation_low=-0.1, saturation_high=0.1) 149 | ]) 150 | 151 | if args.skip_aug: 152 | transform = data_transforms.Compose([ 153 | # geometric augmentation 154 | data_transforms.NumpyToPILImage(), 155 | data_transforms.RandomCrop2D(crop_height=args.crop_size[0], 156 | crop_width=args.crop_size[1]), 157 | data_transforms.RandomVerticalFlip(prob=0.5), 158 | data_transforms.RandomHorizontalFlip(prob=0.5), 159 | data_transforms.PILImageToNumpy() 160 | ]) 161 | 162 | # training dataloader 163 | tkwargs = {'batch_size': args.batch_size, 164 | 'num_workers': args.workers, 165 | 'pin_memory': True, 'drop_last': True} 166 | step_size = args.step_size if args.step_size > 0 else (args.num_interp + 1) 167 | train_dataset = args.dataset_class(args=args, root=args.train_file, num_interp=args.num_interp, 168 | sample_rate=args.sample_rate, step_size=step_size, is_training=True, 169 | transform=transform) 170 | 171 | if args.world_size > 1: 172 | train_sampler = torch.utils.data.distributed.DistributedSampler( 173 | train_dataset) 174 | else: 175 | train_sampler = None 176 | 177 | train_loader = torch.utils.data.DataLoader( 178 | train_dataset, sampler=train_sampler, 179 | shuffle=(train_sampler is None), **tkwargs) 180 | 181 | block.log('Number of Training Images: {}:({} mini-batches)'.format( 182 | step_size * len(train_loader.dataset), len(train_loader))) 183 | 184 | # validation dataloader 185 | vkwargs = {'batch_size': args.val_batch_size, 186 | 'num_workers': args.workers, 187 | 'pin_memory': True, 'drop_last': True} 188 | step_size = args.val_step_size if args.val_step_size > 0 else (args.val_num_interp + 1) 189 | 190 | val_dataset = args.dataset_class(args=args, root=args.val_file, num_interp=args.val_num_interp, 191 | sample_rate=args.val_sample_rate, step_size=step_size) 192 | 193 | val_loader = torch.utils.data.DataLoader( 194 | val_dataset, shuffle=False, **vkwargs) 195 | 196 | block.log('Number of Validation Images: {}:({} mini-batches)'.format( 197 | step_size * len(val_loader.dataset), len(val_loader))) 198 | args.val_size = val_loader.dataset[0]['image'][0].shape[:2] 199 | 200 | return train_loader, 
train_sampler, val_loader
201 | 
202 | 
203 | def load_model(model, optimizer, block, args):
204 |     # Load trained weights.
205 |     checkpoint = torch.load(args.resume, map_location='cpu')
206 | 
207 |     # Used for partial initialization.
208 |     input_dict = checkpoint['state_dict']
209 |     curr_dict = model.state_dict()
210 |     state_dict = input_dict.copy()
211 |     for key in input_dict:
212 |         if key not in curr_dict:
213 |             print("key {} not found in model, skipped.".format(key))
214 |             continue
215 |         if curr_dict[key].shape != input_dict[key].shape:
216 |             state_dict.pop(key)
217 |             print("key {} skipped because of size mismatch.".format(
218 |                 key))
219 |     model.load_state_dict(state_dict, strict=False)
220 |     if 'optimizer' in checkpoint and args.start_epoch < 0:
221 |         optimizer.load_state_dict(checkpoint['optimizer'])
222 |     if args.start_epoch < 0:
223 |         args.start_epoch = max(0, checkpoint['epoch'])
224 |     block.log("Successfully loaded checkpoint (at epoch {})".format(
225 |         checkpoint['epoch']))
226 | 
227 | 
228 | def build_and_initialize_model_and_optimizer(block, args):
229 |     model = args.network_class(args)
230 |     block.log('Number of parameters: {val:,}'.format(
231 |         val=sum(p.data.nelement() if p.requires_grad else 0
232 |                 for p in model.parameters())))
233 | 
234 |     block.log('Initializing CUDA')
235 |     assert torch.cuda.is_available(), 'only GPUs are supported at the moment'
236 |     model.cuda(torch.cuda.current_device())
237 | 
238 |     optimizer = args.optimizer_class(
239 |         [p for p in model.parameters() if p.requires_grad], lr=args.lr)
240 | 
241 |     block.log("Attempting to load checkpoint '{}'".format(args.resume))
242 |     if args.resume and os.path.isfile(args.resume):
243 |         load_model(model, optimizer, block, args)
244 |     elif args.resume:
245 |         block.log("No checkpoint found at '{}'".format(args.resume))
246 |         exit(1)
247 |     else:
248 |         block.log("Random initialization, checkpoint not provided.")
249 |         args.start_epoch = 0
250 | 
251 |     if args.fp16:
252 |         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
253 | 
254 |     # Wrap the model for multi-process training when needed.
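    # apex's DistributedDataParallel averages gradients across processes with
    # an all-reduce during the backward pass, keeping the replicas in sync.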
255 |     if args.world_size > 1:
256 |         model = DistributedDataParallel(model)
257 | 
258 |     return model, optimizer
259 | 
260 | 
261 | def get_learning_rate_scheduler(optimizer, block, args):
262 |     block.log('Base learning rate {}.'.format(args.lr))
263 |     if args.lr_scheduler == 'ExponentialLR':
264 |         block.log('Using exponential decay learning rate scheduler with '
265 |                   '{} decay rate'.format(args.lr_gamma))
266 |         lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
267 |                                                                args.lr_gamma)
268 |     elif args.lr_scheduler == 'MultiStepLR':
269 |         block.log('Using multi-step learning rate scheduler with {} gamma '
270 |                   'and {} milestones.'.format(args.lr_gamma,
271 |                                               args.lr_milestones))
272 |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
273 |             optimizer, milestones=args.lr_milestones, gamma=args.lr_gamma)
274 |     elif args.lr_scheduler == 'PolyLR':
275 |         block.log('Using polynomial decay learning rate scheduler with {} gamma '
276 |                   'and {} milestones.'.format(args.lr_gamma,
277 |                                               args.lr_milestones))
278 | 
279 |         lr_gamma = math.log(0.1) / math.log(1 - (args.lr_milestones[0] - 1e-6) / args.epochs)
280 | 
281 |         # Polynomial decay with lr_gamma until args.lr_milestones[0], then step decay by a factor of 0.1.
282 |         lambda_map = lambda epoch_index: math.pow(1 - epoch_index / args.epochs, lr_gamma) \
283 |             if np.searchsorted(args.lr_milestones, epoch_index + 1) == 0 \
284 |             else math.pow(10, -1 * np.searchsorted(args.lr_milestones, epoch_index + 1))
285 | 
286 |         lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_map)
287 | 
288 |     else:
289 |         raise NameError('Unknown {} learning rate scheduler'.format(
290 |             args.lr_scheduler))
291 | 
292 |     return lr_scheduler
293 | 
294 | 
295 | def forward_only(inputs_gpu, targets_gpu, model):
296 |     # Forward pass.
297 |     losses, outputs, targets = model(inputs_gpu, targets_gpu)
298 | 
299 |     # Average the losses.
300 |     for k in losses:
301 |         losses[k] = losses[k].mean(dim=0)
302 |     loss = losses['tot']
303 | 
304 |     return loss, outputs, targets
305 | 
306 | 
307 | def calc_linf_grad_norm(args, parameters):
308 |     if isinstance(parameters, torch.Tensor):
309 |         parameters = [parameters]
310 |     parameters = list(filter(lambda p: p.grad is not None, parameters))
311 |     max_norm = max(p.grad.data.abs().max() for p in parameters)
312 |     max_norm_reduced = torch.cuda.FloatTensor([max_norm])
313 |     if args.world_size > 1:
314 |         torch.distributed.all_reduce(max_norm_reduced,
315 |                                      op=torch.distributed.ReduceOp.MAX)
316 |     return max_norm_reduced[0].item()
317 | 
318 | 
319 | def train_step(batch_cpu, model, optimizer, block, args, print_linf_grad=False):
320 |     # Move data to GPU.
321 | 
322 |     inputs = {k: [b.cuda() for b in batch_cpu[k]]
323 |               for k in batch_cpu if k in ['image', 'fwd_mvec', 'bwd_mvec', 'depth']}
324 |     tar_index = batch_cpu['tindex'].cuda()
325 | 
326 |     # Forward pass.
327 |     loss, outputs, targets = forward_only(inputs, tar_index, model)
328 | 
329 |     # Backward pass and optimizer step.
330 |     optimizer.zero_grad()
331 |     if args.fp16:
332 |         with amp.scale_loss(loss, optimizer) as scaled_loss:
333 |             scaled_loss.backward()
334 |     else:
335 |         loss.backward()
336 | 
337 |     # Calculate and print the infinity norm of the gradients.
338 |     if print_linf_grad:
339 |         block.log('gradients Linf: {:0.3f}'.format(calc_linf_grad_norm(args,
340 |                                                    model.parameters())))
341 | 
342 |     # Clip gradients by value.
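    # clip_grad_value_ clamps each gradient element to [-clip_gradients, clip_gradients],
    # unlike norm-based clipping, which rescales the whole gradient vector.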
343 |     if args.clip_gradients > 0:
344 |         torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_gradients)
345 | 
346 |     optimizer.step()
347 | 
348 |     return loss, outputs, targets
349 | 
350 | 
351 | def evaluate_epoch(model, val_loader, block, args, epoch=0):
352 |     # The number of interpolated frames may differ between training and validation.
353 |     if args.val_num_interp != args.num_interp:
354 |         model_ = model
355 |         if args.world_size > 1:
356 |             model_ = model.module
357 |         model_.tlinespace = torch.linspace(
358 |             0, 1, 2 + args.val_num_interp).float().cuda()
359 | 
360 |     # Calculate validation loss, create videos, or dump predicted frames.
361 |     v_psnr, v_ssim, v_ie, loss_values = evaluate(args, val_loader, model, args.val_num_interp, epoch, block)
362 | 
363 |     if args.val_num_interp != args.num_interp:
364 |         model_ = model
365 |         if args.world_size > 1:
366 |             model_ = model.module
367 |         model_.tlinespace = torch.linspace(0, 1,
368 |                                            2 + args.num_interp).float().cuda()
369 |     # Move the model back to train mode.
370 |     model.train()
371 | 
372 |     return v_psnr, v_ssim, v_ie, loss_values
373 | 
374 | 
375 | def write_summary(global_index, learning_rate, t_loss,
376 |                   v_loss, v_psnr, v_ssim, v_ie, args):
377 |     # Write to tensorboard.
378 |     if args.rank == 0:
379 |         args.logger.add_scalar("lr", learning_rate, global_index)
380 |         args.logger.add_scalars("Loss",
381 |                                 {'trainLoss': t_loss, 'valLoss': v_loss},
382 |                                 global_index)
383 |         args.logger.add_scalar("PSNR", v_psnr, global_index)
384 |         args.logger.add_scalar("SSIM", v_ssim, global_index)
385 |         args.logger.add_scalar("RMS", v_ie, global_index)
386 | 
387 | 
388 | def train_epoch(epoch, args, model, optimizer, lr_scheduler,
389 |                 train_sampler, train_loader,
390 |                 v_psnr, v_ssim, v_ie, v_loss, block):
391 |     # Average loss calculator.
392 |     loss_values = utils.AverageMeter()
393 | 
394 |     # Advance the learning rate.
395 |     lr_scheduler.step()
396 | 
397 |     # This will ensure the data is shuffled each epoch.
398 |     if train_sampler is not None:
399 |         train_sampler.set_epoch(epoch)
400 | 
401 |     # Get number of batches in one epoch.
402 |     num_batches = len(train_loader) if args.train_n_batches < 0 \
403 |         else args.train_n_batches
404 | 
405 |     global_index = 0
406 |     for i, batch in enumerate(train_loader):
407 | 
408 |         # Set global index.
409 |         global_index = epoch * num_batches + i
410 | 
411 |         # Take one training step.
412 |         loss, outputs, _ = train_step(
413 |             batch, model, optimizer, block, args,
414 |             ((global_index + 1) % args.print_freq == 0))
415 | 
416 |         # Update the loss accumulator.
417 |         loss_values.update(loss.data.item(), outputs.size(0))
418 | 
419 |         # Summary writer.
420 |         if (global_index + 1) % args.print_freq == 0:
421 | 
422 |             # Reduce the loss.
423 |             if args.world_size > 1:
424 |                 t_loss_gpu = torch.Tensor([loss_values.val]).cuda()
425 |                 torch.distributed.all_reduce(t_loss_gpu)
426 |                 t_loss = t_loss_gpu.item() / args.world_size
427 |             else:
428 |                 t_loss = loss_values.val
429 | 
430 |             # Write to tensorboard.
431 |             write_summary(global_index, lr_scheduler.get_lr()[0], t_loss,
432 |                           v_loss, v_psnr, v_ssim, v_ie, args)
433 | 
434 |             # And reset the loss accumulator.
435 |             loss_values.reset()
436 | 
437 |             # Print some output.
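            # Build a single log line: iteration/epoch/batch counters, the
            # latest train/val metrics, and the current learning rate.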
438 |             dict2print = {'iter': global_index,
439 |                           'epoch': str(epoch) + '/' + str(args.epochs),
440 |                           'batch': str(i + 1) + '/' + str(num_batches)}
441 |             str2print = ' '.join(key + " : " + str(dict2print[key])
442 |                                  for key in dict2print)
443 |             str2print += ' trainLoss:' + ' %1.3f' % t_loss
444 |             str2print += ' valLoss:' + ' %1.3f' % v_loss
445 |             str2print += ' valPSNR:' + ' %1.3f' % v_psnr
446 |             str2print += ' lr:' + ' %1.6f' % (lr_scheduler.get_lr()[0])
447 |             block.log(str2print)
448 | 
449 |         # Break the training loop if we have reached the maximum number of batches.
450 |         if (i + 1) >= num_batches:
451 |             break
452 |     return global_index
453 | 
454 | 
455 | def save_model(model, optimizer, epoch, global_index, max_psnr, block, args):
456 |     # Write on rank zero only.
457 |     if args.rank == 0:
458 |         if args.world_size > 1:
459 |             model_ = model.module
460 |         else:
461 |             model_ = model
462 |         state_dict = model_.state_dict()
463 |         tmp_keys = state_dict.copy()
464 |         for k in state_dict:
465 |             if any(ikey in k for ikey in model_.ignore_keys):
466 |                 tmp_keys.pop(k, None)
467 |         state_dict = tmp_keys.copy()
468 |         # Save checkpoint.
469 |         model_optim_state = {'epoch': epoch,
470 |                              'arch': args.model,
471 |                              'state_dict': state_dict,
472 |                              'optimizer': optimizer.state_dict(),
473 |                              }
474 |         model_name = os.path.join(
475 |             args.save_root, '_ckpt_epoch_%03d_iter_%07d_psnr_%1.2f.pt.tar' % (
476 |                 epoch, global_index, max_psnr))
477 |         torch.save(model_optim_state, model_name)
478 |         block.log('saved model {}'.format(model_name))
479 | 
480 |         return model_name
481 | 
482 | 
483 | def train(model, optimizer, lr_scheduler, train_loader,
484 |           train_sampler, val_loader, block, args):
485 |     # Set the model to train mode.
486 |     model.train()
487 | 
488 |     # Keep track of maximum PSNR.
489 |     max_psnr = -1
490 | 
491 |     # Perform an initial evaluation.
492 |     if args.initial_eval:
493 |         block.log('Initial evaluation.')
494 | 
495 |         v_psnr, v_ssim, v_ie, v_loss = evaluate_epoch(model, val_loader, block, args, args.start_epoch)
496 |     else:
497 |         v_psnr, v_ssim, v_ie, v_loss = 20.0, 0.5, 15.0, 0.0
498 | 
499 |     for epoch in range(args.start_epoch, args.epochs):
500 | 
501 |         # Train for an epoch.
502 |         global_index = train_epoch(epoch, args, model, optimizer, lr_scheduler,
503 |                                    train_sampler, train_loader, v_psnr, v_ssim, v_ie, v_loss, block)
504 | 
505 |         if (epoch + 1) % args.save_freq == 0:
506 |             v_psnr, v_ssim, v_ie, v_loss = evaluate_epoch(model, val_loader, block, args, epoch + 1)
507 |             if v_psnr > max_psnr:
508 |                 max_psnr = v_psnr
509 |                 save_model(model, optimizer, epoch + 1, global_index,
510 |                            max_psnr, block, args)
511 | 
512 |     return 0
513 | 
514 | 
515 | def main():
516 |     # Parse the args.
517 |     with utils.TimerBlock("\nParsing Arguments") as block:
518 |         args = parse_and_set_args(block)
519 | 
520 |     # Initialize torch.distributed.
521 |     with utils.TimerBlock("Initializing Distributed"):
522 |         initialize_distributed(args)
523 | 
524 |     # Set all random seeds for reproducibility.
525 |     with utils.TimerBlock("Setting Random Seed"):
526 |         set_random_seed(args.seed)
527 | 
528 |     # Train and validation data loaders.
529 |     with utils.TimerBlock("Building {} Dataset".format(args.dataset)) as block:
530 |         train_loader, train_sampler, val_loader = get_train_and_valid_data_loaders(block, args)
531 | 
532 |     # Build the model and optimizer.
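    # If --resume points to a checkpoint, weights (and optionally the optimizer
    # state) are restored inside build_and_initialize_model_and_optimizer.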
533 |     with utils.TimerBlock("Building {} Model and {} Optimizer".format(
534 |             args.model, args.optimizer_class.__name__)) as block:
535 |         model, optimizer = build_and_initialize_model_and_optimizer(block, args)
536 | 
537 |     # Learning rate scheduler.
538 |     with utils.TimerBlock("Building {} Learning Rate Scheduler".format(
539 |             args.lr_scheduler)) as block:
540 |         lr_scheduler = get_learning_rate_scheduler(optimizer, block, args)
541 | 
542 |     # Set the TensorBoard writer on rank 0.
543 |     with utils.TimerBlock("Creating Tensorboard Writers"):
544 |         if args.rank == 0:
545 |             args.logger = SummaryWriter(log_dir=args.save_root)
546 | 
547 |     with utils.TimerBlock("Training Model") as block:
548 |         train(model, optimizer, lr_scheduler, train_loader,
549 |               train_sampler, val_loader, block, args)
550 | 
551 |     return 0
552 | 
553 | 
554 | if __name__ == '__main__':
555 |     main()
556 | 
--------------------------------------------------------------------------------