├── .gitignore ├── requirements.txt ├── photos │ ├── one.png │ ├── two.png │ └── output.gif ├── README.md ├── inference.py ├── fusion.py ├── export.py ├── util.py ├── pyramid_flow_estimator.py ├── feature_extractor.py ├── interpolator.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | /photos/output.mp4 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | torch 3 | tqdm -------------------------------------------------------------------------------- /photos/one.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dajes/frame-interpolation-pytorch/HEAD/photos/one.png -------------------------------------------------------------------------------- /photos/two.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dajes/frame-interpolation-pytorch/HEAD/photos/two.png -------------------------------------------------------------------------------- /photos/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dajes/frame-interpolation-pytorch/HEAD/photos/output.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Frame interpolation in PyTorch 3 | 4 | This is an unofficial PyTorch inference implementation 5 | of [FILM: Frame Interpolation for Large Motion, In ECCV 2022](https://film-net.github.io/).\ 6 | [Original repository link](https://github.com/google-research/frame-interpolation) 7 | 8 | The project focuses on providing a simple, TorchScript-compilable inference interface for the original pretrained TF2 9 | model.
10 | 11 | # Quickstart 12 | 13 | Download a compiled model from [the release](https://github.com/dajes/frame-interpolation-pytorch/releases) 14 | and specify the path to the file in the following snippet: 15 | 16 | ```python 17 | import torch 18 | model_path = 'path/to/downloaded/model.pt'  # replace with the actual file path 19 | device = torch.device('cuda') 20 | precision = torch.float16 21 | 22 | model = torch.jit.load(model_path, map_location='cpu') 23 | model.eval().to(device=device, dtype=precision) 24 | 25 | img1 = torch.rand(1, 3, 720, 1080).to(precision).to(device) 26 | img3 = torch.rand(1, 3, 720, 1080).to(precision).to(device) 27 | dt = img1.new_full((1, 1), .5) 28 | 29 | with torch.no_grad(): 30 | img2 = model(img1, img3, dt) # Will be of the same shape as inputs (1, 3, 720, 1080) 31 | 32 | ``` 33 | 34 | # Exporting the model yourself 35 | 36 | You will need to install the TensorFlow version specified in 37 | the [original repo](https://github.com/google-research/frame-interpolation#installation) and download the SavedModel of the 38 | "Style" network from [there](https://github.com/google-research/frame-interpolation#pre-trained-models) 39 | 40 | After you have downloaded the SavedModel and can load it via ```tf.compat.v2.saved_model.load(path)```: 41 | 42 | * Clone the repository 43 | 44 | ``` 45 | git clone https://github.com/dajes/frame-interpolation-pytorch 46 | cd frame-interpolation-pytorch 47 | ``` 48 | 49 | * Install dependencies 50 | 51 | ``` 52 | python -m pip install -r requirements.txt 53 | ``` 54 | 55 | * Run ```export.py```: 56 | 57 | ``` 58 | python export.py "model_path" "save_path" [--statedict] [--fp32] [--skiptest] [--gpu] 59 | ``` 60 | 61 | Argument list: 62 | 63 | * ```model_path``` Path to the TF SavedModel 64 | * ```save_path``` Path to save the exported PyTorch model (TorchScript by default, a state dict with ```--statedict```) 65 | * ```--statedict``` Export to state dict instead of TorchScript 66 | * ```--fp32``` Save weights at full precision (FP16 is the default) 67 | * ```--skiptest``` Skip the comparison against the TF model that otherwise runs after saving 68 | * ```--gpu``` Whether to attempt to use the GPU for export and testing (falls back to the CPU if CUDA is unavailable)
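For reference, `export.py` copies the TF variables onto the PyTorch modules in order and permutes every convolution kernel with `weight.permute(3, 2, 0, 1)`, because TF stores kernels as (H, W, C_in, C_out) while PyTorch's `nn.Conv2d` expects (C_out, C_in, H, W). The following is a minimal, dependency-free sketch of that axis reordering on a tiny nested-list kernel. The helpers `make_tf_kernel` and `tf_to_torch_kernel` are hypothetical illustrations and are not part of this repository:

```python
from itertools import product


def make_tf_kernel(h, w, c_in, c_out):
    """Dummy TF-layout kernel indexed as kernel[y][x][c][k]; each entry stores its own index."""
    return [[[[(y, x, c, k) for k in range(c_out)]
              for c in range(c_in)]
             for x in range(w)]
            for y in range(h)]


def tf_to_torch_kernel(kernel):
    """Reorder (H, W, C_in, C_out) -> (C_out, C_in, H, W), like torch's permute(3, 2, 0, 1)."""
    h, w = len(kernel), len(kernel[0])
    c_in, c_out = len(kernel[0][0]), len(kernel[0][0][0])
    return [[[[kernel[y][x][c][k] for x in range(w)]
              for y in range(h)]
             for c in range(c_in)]
            for k in range(c_out)]


tf_kernel = make_tf_kernel(h=2, w=3, c_in=4, c_out=5)
pt_kernel = tf_to_torch_kernel(tf_kernel)

# The new shape is (C_out, C_in, H, W).
print(len(pt_kernel), len(pt_kernel[0]), len(pt_kernel[0][0]), len(pt_kernel[0][0][0]))  # prints: 5 4 2 3

# Every weight lands at the index PyTorch's conv2d reads it from.
assert all(pt_kernel[k][c][y][x] == tf_kernel[y][x][c][k]
           for y, x, c, k in product(range(2), range(3), range(4), range(5)))
```

A shape check alone, such as the assertion in `translate_state_dict`, would not notice swapped H and W axes on square kernels like the 3x3 ones used in this model; an element-wise check like the one above would.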
69 | 70 | # Testing the exported model 71 | The following script creates an MP4 video of the frames interpolated between two input images, played forward and then in reverse: 72 | ``` 73 | python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--gpu] [--fp16] [--frames FRAMES] [--fps FPS] 74 | ``` 75 | * ```model_path``` Path to the exported TorchScript checkpoint 76 | * ```img1``` Path to the first image 77 | * ```img2``` Path to the second image 78 | * ```--save_path SAVE_PATH``` Path to save the interpolated frames as a video; if absent, the video is saved as ```output.mp4``` in the directory that contains ```img1``` 79 | * ```--gpu``` Whether to attempt to use the GPU for predictions 80 | * ```--fp16``` Whether to use FP16 for calculations; speeds up inference on GPUs with tensor cores and only takes effect together with ```--gpu``` on a CUDA device 81 | * ```--frames FRAMES``` Number of frames to interpolate between the input images (default: 18) 82 | * ```--fps FPS``` FPS of the output video (default: 10) 83 | 84 | ### Results on the 2 example photos from the original repository: 85 |

86 | 87 | 88 |

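The in-between frames produced by `inference.py` are not generated left to right. The two inputs sit at t=0 and t=1, and each iteration picks the still-missing frame whose position inside the gap between two already available frames is closest to that gap's midpoint; the model is then called on those two neighbours with the *relative* position `dt` inside the gap. The following stdlib-only sketch re-implements just that selection rule, without the model, so the generation order and the `dt` values can be inspected. The function name `schedule` is hypothetical, and because it uses Python floats instead of the float32 `torch.linspace` used by the script, exact ties (for example with `--frames 2`) may be broken differently:

```python
import bisect


def schedule(inter_frames):
    """Return (target, left, right, dt) tuples in the order inference.py would generate them."""
    n = inter_frames + 2
    splits = [i / (n - 1) for i in range(n)]  # timesteps, like torch.linspace(0, 1, n)
    idxes = [0, n - 1]  # indices of frames that already exist (kept sorted)
    remains = list(range(1, n - 1))  # indices of frames still to generate
    order = []
    while remains:
        best = None
        # Scan every gap between existing frames in row-major order (like torch.argmin
        # on the flattened distance matrix) and keep the first closest-to-midpoint hit.
        for gap in range(len(idxes) - 1):
            start, end = splits[idxes[gap]], splits[idxes[gap + 1]]
            for step, r in enumerate(remains):
                dist = abs((splits[r] - start) / (end - start) - 0.5)
                if best is None or dist < best[0]:
                    best = (dist, gap, step)
        _, gap, step = best
        target = remains.pop(step)
        left, right = idxes[gap], idxes[gap + 1]
        # The model only sees the relative position of the target inside [left, right].
        dt = (splits[target] - splits[left]) / (splits[right] - splits[left])
        order.append((target, left, right, round(dt, 3)))
        idxes.insert(bisect.bisect_left(idxes, target), target)
    return order


for target, left, right, dt in schedule(3):
    print(f'frame {target} from frames {left} and {right}, dt={dt}')
# prints:
# frame 2 from frames 0 and 4, dt=0.5
# frame 1 from frames 0 and 2, dt=0.5
# frame 3 from frames 2 and 4, dt=0.5
```

When `--frames` plus one is a power of two (3, 7, 15, ...), every call lands on an exact midpoint and the procedure reduces to recursive bisection with `dt = 0.5`. With the default of 18 frames the gaps cannot always be halved exactly, so some calls use other `dt` values.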
89 | 90 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import os 3 | from tqdm import tqdm 4 | import torch 5 | import numpy as np 6 | import cv2 7 | 8 | from util import load_image 9 | 10 | 11 | def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half): 12 | model = torch.jit.load(model_path, map_location='cpu') 13 | model.eval() 14 | img_batch_1, crop_region_1 = load_image(img1) 15 | img_batch_2, crop_region_2 = load_image(img2) 16 | 17 | img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2) 18 | img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2) 19 | 20 | if not (half and gpu and torch.cuda.is_available()): 21 | model.float()  # FP16 is only used on a CUDA device; run in FP32 otherwise 22 | 23 | if gpu and torch.cuda.is_available(): 24 | if half: 25 | model = model.half() 26 | else: 27 | model.float() 28 | model = model.cuda() 29 | 30 | if save_path == 'img1 folder': 31 | save_path = os.path.join(os.path.split(img1)[0], 'output.mp4') 32 | 33 | results = [ 34 | img_batch_1, 35 | img_batch_2 36 | ] 37 | 38 | idxes = [0, inter_frames + 1] 39 | remains = list(range(1, inter_frames + 1)) 40 | 41 | splits = torch.linspace(0, 1, inter_frames + 2) 42 | 43 | for _ in tqdm(range(len(remains)), 'Generating in-between frames'): 44 | starts = splits[idxes[:-1]] 45 | ends = splits[idxes[1:]] 46 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs() 47 | matrix = torch.argmin(distances).item() 48 | start_i, step = np.unravel_index(matrix, distances.shape) 49 | end_i = start_i + 1 50 | 51 | x0 = results[start_i] 52 | x1 = results[end_i] 53 | 54 | if gpu and torch.cuda.is_available(): 55 | if half: 56 | x0 = x0.half() 57 | x1 = x1.half() 58 | x0 = x0.cuda() 59 | x1 = x1.cuda() 60 | 61 | dt = x0.new_full((1, 1), (splits[remains[step]] - 
splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) 62 | 63 | with
torch.no_grad(): 64 | prediction = model(x0, x1, dt) 65 | insert_position = bisect.bisect_left(idxes, remains[step]) 66 | idxes.insert(insert_position, remains[step]) 67 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float()) 68 | del remains[step] 69 | 70 | video_folder = os.path.split(save_path)[0] 71 | os.makedirs(video_folder, exist_ok=True) 72 | 73 | y1, x1, y2, x2 = crop_region_1 74 | frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy()[y1:y2, x1:x2].copy() for tensor in results] 75 | 76 | w, h = frames[0].shape[1::-1] 77 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 78 | writer = cv2.VideoWriter(save_path, fourcc, fps, (w, h)) 79 | for frame in frames: 80 | writer.write(frame) 81 | 82 | for frame in frames[1:][::-1]: 83 | writer.write(frame) 84 | 85 | writer.release() 86 | 87 | 88 | if __name__ == '__main__': 89 | import argparse 90 | 91 | parser = argparse.ArgumentParser(description='Test frame interpolator model') 92 | 93 | parser.add_argument('model_path', type=str, help='Path to the TorchScript model') 94 | parser.add_argument('img1', type=str, help='Path to the first image') 95 | parser.add_argument('img2', type=str, help='Path to the second image') 96 | 97 | parser.add_argument('--save_path', type=str, default='img1 folder', help='Path to save the interpolated frames') 98 | parser.add_argument('--gpu', action='store_true', help='Use GPU') 99 | parser.add_argument('--fp16', action='store_true', help='Use FP16') 100 | parser.add_argument('--frames', type=int, default=18, help='Number of frames to interpolate') 101 | parser.add_argument('--fps', type=int, default=10, help='FPS of the output video') 102 | 103 | args = parser.parse_args() 104 | 105 | inference(args.model_path, args.img1, args.img2, args.save_path, args.gpu, args.frames, args.fps, args.fp16) 106 | -------------------------------------------------------------------------------- /fusion.py: 
-------------------------------------------------------------------------------- 1 | """The final fusion stage for the film_net frame interpolator. 2 | 3 | The inputs to this module are the warped input images, image features and 4 | flow fields, all aligned to the target frame (often the midway point between the 5 | two original inputs). The output is the final image. FILM has no explicit 6 | occlusion handling; instead, using the above-mentioned information, this module 7 | automatically decides how to best blend the inputs together to produce content 8 | in areas where the pixels can only be borrowed from one of the inputs. 9 | 10 | Similarly, this module also decides how much to blend in each input in the case 11 | of a fractional timestep that is not at the halfway point. For example, if the two 12 | input images are at t=0 and t=1, and we were to synthesize a frame at t=0.1, 13 | it often makes most sense to favor the first input. However, this is not 14 | always the case, in particular for occluded pixels. 15 | 16 | The architecture of the Fusion module follows the decoder side of the U-net [1] 17 | architecture, i.e. each pyramid level consists of concatenation with the upsampled coarser 18 | level output, and two 3x3 convolutions. 19 | 20 | The upsampling is implemented as 'resize convolution', i.e. nearest neighbor 21 | upsampling followed by 2x2 convolution as explained in [2]. Transposed convolutions, 22 | another common choice for decoder upsampling, have a tendency to create checkerboard artifacts. 23 | 24 | [1] Ronneberger et al.
U-Net: Convolutional Networks for Biomedical Image 25 | Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf 26 | [2] https://distill.pub/2016/deconv-checkerboard/ 27 | """ 28 | from typing import List 29 | 30 | import torch 31 | from torch import nn 32 | from torch.nn import functional as F 33 | 34 | from util import Conv2d 35 | 36 | _NUMBER_OF_COLOR_CHANNELS = 3 37 | 38 | 39 | def get_channels_at_level(level, filters): 40 | n_images = 2 41 | channels = _NUMBER_OF_COLOR_CHANNELS 42 | flows = 2 43 | 44 | return (sum(filters << i for i in range(level)) + channels + flows) * n_images 45 | 46 | 47 | class Fusion(nn.Module): 48 | """The decoder.""" 49 | 50 | def __init__(self, n_layers=4, specialized_layers=3, filters=64): 51 | """ 52 | Args: 53 | n_layers: number of decoder stages (all pyramid levels except the coarsest); specialized_layers: the filter count doubles per level up to this level and stays constant beyond it; filters: number of filters at the finest level. 54 | """ 55 | super().__init__() 56 | 57 | # The final convolution that outputs RGB: 58 | self.output_conv = nn.Conv2d(filters, 3, kernel_size=1) 59 | 60 | # Each item 'convs[k]' will contain the list of convolutions to be applied 61 | # at one pyramid level (see the ordering note below). 62 | self.convs = nn.ModuleList() 63 | 64 | # Create the convolutions. Roughly following the feature extractor, we 65 | # double the number of filters when the resolution halves, but only up to 66 | # specialized_layers levels, after which we use the same number of filters on 67 | # all levels. 68 | # 69 | # The convs are created in coarse-to-fine order (the reverse of the TF original, to allow 70 | # TorchScript export; see export.py), so convs[k] is applied to pyramid level n_layers - 1 - k (0=finest level).
71 | # in_channels: tuple = (128, 202, 256, 522, 512, 1162, 1930, 2442) 72 | 73 | in_channels = get_channels_at_level(n_layers, filters) 74 | increase = 0 75 | for i in range(n_layers)[::-1]: 76 | num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) 77 | convs = nn.ModuleList([ 78 | Conv2d(in_channels, num_filters, size=2, activation=None), 79 | Conv2d(in_channels + (increase or num_filters), num_filters, size=3), 80 | Conv2d(num_filters, num_filters, size=3)] 81 | ) 82 | self.convs.append(convs) 83 | in_channels = num_filters 84 | increase = get_channels_at_level(i, filters) - num_filters // 2 85 | 86 | def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor: 87 | """Runs the fusion module. 88 | 89 | Args: 90 | pyramid: The input feature pyramid as a list of tensors, each 91 | in (B x C x H x W) format, with the finest level first. 92 | 93 | Returns: 94 | A batch of RGB images. 95 | Note: 96 | no length check is performed; pyramid[-1] is used as the coarsest input 97 | and the len(self.convs) finest levels are refined. 98 | """ 99 | 100 | # As a slight difference from a conventional decoder (e.g. U-net), we don't 101 | # apply any extra convolutions to the coarsest level, but just pass it 102 | # to finer levels for concatenation. This choice has not been thoroughly 103 | # evaluated, but is motivated by the educated guess that the fusion part 104 | # probably does not need large spatial context, because at this point the 105 | # features are spatially aligned by the preceding warp. 106 | net = pyramid[-1] 107 | 108 | # Loop starting from the 2nd coarsest level: 109 | # for i in reversed(range(0, len(pyramid) - 1)): 110 | for k, layers in enumerate(self.convs): 111 | i = len(self.convs) - 1 - k 112 | # Resize the tensor from coarser level to match for concatenation.
113 | level_size = pyramid[i].shape[2:4] 114 | net = F.interpolate(net, size=level_size, mode='nearest') 115 | net = layers[0](net) 116 | net = torch.cat([pyramid[i], net], dim=1) 117 | net = layers[1](net) 118 | net = layers[2](net) 119 | net = self.output_conv(net) 120 | return net 121 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import torch 6 | 7 | from interpolator import Interpolator 8 | 9 | 10 | def translate_state_dict(var_dict, state_dict): 11 | for name, (prev_name, weight) in zip(state_dict, var_dict.items()): 12 | print('Mapping', prev_name, '->', name) 13 | weight = torch.from_numpy(weight) 14 | if 'kernel' in prev_name: 15 | # Transpose the conv2d kernel weights, since TF uses (H, W, C, K) and PyTorch uses (K, C, H, W) 16 | weight = weight.permute(3, 2, 0, 1) 17 | 18 | assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}' 19 | 20 | state_dict[name] = weight 21 | 22 | 23 | def import_state_dict(interpolator: Interpolator, saved_model): 24 | variables = saved_model.keras_api.variables 25 | 26 | extract_dict = interpolator.extract.state_dict() 27 | flow_dict = interpolator.predict_flow.state_dict() 28 | fuse_dict = interpolator.fuse.state_dict() 29 | 30 | extract_vars = {} 31 | _flow_vars = {} 32 | _fuse_vars = {} 33 | 34 | for var in variables: 35 | name = var.name 36 | if name.startswith('feat_net'): 37 | extract_vars[name[9:]] = var.numpy() 38 | elif name.startswith('predict_flow'): 39 | _flow_vars[name[13:]] = var.numpy() 40 | elif name.startswith('fusion'): 41 | _fuse_vars[name[7:]] = var.numpy() 42 | 43 | # reverse order of modules to allow jit export 44 | # TODO: improve this hack 45 | flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True)) 46 | 
fuse_vars = dict(sorted(_fuse_vars.items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True)) 47 | 48 | assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}' 49 | assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}' 50 | assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}' 51 | 52 | for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)): 53 | translate_state_dict(var_dict, state_dict) 54 | 55 | interpolator.extract.load_state_dict(extract_dict) 56 | interpolator.predict_flow.load_state_dict(flow_dict) 57 | interpolator.fuse.load_state_dict(fuse_dict) 58 | 59 | 60 | def verify_debug_outputs(pt_outputs, tf_outputs): 61 | max_error = 0 62 | for name, predicted in pt_outputs.items(): 63 | if name == 'image': 64 | continue 65 | pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted] 66 | true_frfp = [f.numpy() for f in tf_outputs[name]] 67 | 68 | for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)): 69 | assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}' 70 | error = np.max(np.abs(pred - true)) 71 | max_error = max(max_error, error) 72 | assert error < 1, f'{name} {i} max error: {error}' 73 | print('Max intermediate error:', max_error) 74 | 75 | 76 | def test_model(interpolator, model, half=False, gpu=False): 77 | torch.manual_seed(0) 78 | time = torch.full((1, 1), .5) 79 | x0 = torch.rand(1, 3, 256, 256) 80 | x1 = torch.rand(1, 3, 256, 256) 81 | 82 | x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32) 83 | x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32) 84 | time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32) 85 | tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False) 86 | 87 | if half: 88 | x0 = x0.half() 89 | x1 = x1.half() 90 | time = 
time.half() 91 | 92 | if gpu and torch.cuda.is_available(): 93 | x0 = x0.cuda() 94 | x1 = x1.cuda() 95 | time = time.cuda() 96 | 97 | with torch.no_grad(): 98 | pt_outputs = interpolator.debug_forward(x0, x1, time) 99 | 100 | verify_debug_outputs(pt_outputs, tf_outputs) 101 | 102 | with torch.no_grad(): 103 | prediction = interpolator(x0, x1, time) 104 | output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy() 105 | true_color = tf_outputs['image'].numpy() 106 | error = np.abs(output_color - true_color).max() 107 | 108 | print('Color max error:', error) 109 | 110 | 111 | def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False): 112 | print(f'Exporting model to FP{["32", "16"][fp16]} {["state_dict", "torchscript"][export_to_torchscript]} ' 113 | f'using {"CG"[use_gpu]}PU') 114 | model = tf.compat.v2.saved_model.load(model_path) 115 | interpolator = Interpolator() 116 | interpolator.eval() 117 | import_state_dict(interpolator, model) 118 | 119 | if use_gpu and torch.cuda.is_available(): 120 | interpolator = interpolator.cuda() 121 | else: 122 | use_gpu = False 123 | 124 | if fp16: 125 | interpolator = interpolator.half() 126 | if export_to_torchscript: 127 | interpolator = torch.jit.script(interpolator) 128 | if export_to_torchscript: 129 | interpolator.save(save_path) 130 | else: 131 | torch.save(interpolator.state_dict(), save_path) 132 | 133 | if not skiptest: 134 | if not use_gpu and fp16: 135 | warnings.warn('Testing FP16 model on CPU is impossible, casting it back') 136 | interpolator = interpolator.float() 137 | fp16 = False 138 | test_model(interpolator, model, fp16, use_gpu) 139 | 140 | 141 | if __name__ == '__main__': 142 | import argparse 143 | 144 | parser = argparse.ArgumentParser(description='Export frame-interpolator model to PyTorch state dict') 145 | 146 | parser.add_argument('model_path', type=str, help='Path to the TF SavedModel') 147 | parser.add_argument('save_path', type=str, help='Path to 
save the PyTorch state dict') 148 | parser.add_argument('--statedict', action='store_true', help='Export to state dict instead of TorchScript') 149 | parser.add_argument('--fp32', action='store_true', help='Save at full precision') 150 | parser.add_argument('--skiptest', action='store_true', help='Skip testing and save model immediately instead') 151 | parser.add_argument('--gpu', action='store_true', help='Use GPU') 152 | 153 | args = parser.parse_args() 154 | 155 | main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest) 156 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | """Various utilities used in the film_net frame interpolator model.""" 2 | from typing import List, Optional 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | 11 | def pad_batch(batch, align): 12 | height, width = batch.shape[1:3] 13 | height_to_pad = (align - height % align) if height % align != 0 else 0 14 | width_to_pad = (align - width % align) if width % align != 0 else 0 15 | 16 | crop_region = [height_to_pad >> 1, width_to_pad >> 1, height + (height_to_pad >> 1), width + (width_to_pad >> 1)] 17 | batch = np.pad(batch, ((0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), 18 | (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0)), mode='constant') 19 | return batch, crop_region 20 | 21 | 22 | def load_image(path, align=64): 23 | image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255) 24 | image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align) 25 | return image_batch, crop_region 26 | 27 | 28 | def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]: 29 | """Builds an image pyramid from a given image. 
30 | 31 | The original image is included in the pyramid and the rest are generated by 32 | successively halving the resolution. 33 | 34 | Args: 35 | image: the input image (B x C x H x W). 36 | pyramid_levels: number of pyramid levels to build. 37 | 38 | Returns: 39 | A list of pyramid_levels images, starting from the finest. 40 | """ 41 | 42 | pyramid = [] 43 | for i in range(pyramid_levels): 44 | pyramid.append(image) 45 | if i < pyramid_levels - 1: 46 | image = F.avg_pool2d(image, 2, 2) 47 | return pyramid 48 | 49 | 50 | def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: 51 | """Backward warps the image using the given flow. 52 | 53 | Specifically, the output pixel in batch b, at position x, y will be computed 54 | as follows: 55 | (flowed_y, flowed_x) = (y+flow[b, 1, y, x], x+flow[b, 0, y, x]) 56 | output[b, :, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x) 57 | 58 | Note that the flow vectors are expected as [x, y], i.e. x in position 0 and 59 | y in position 1. 60 | 61 | Args: 62 | image: An image with shape BxCxHxW. 63 | flow: A flow with shape Bx2xHxW, with the two channels denoting the relative 64 | offset in order: (dx, dy). 65 | Returns: 66 | A warped image with the same shape as image.
67 | """ 68 | flow = -flow.flip(1) 69 | 70 | dtype = flow.dtype 71 | device = flow.device 72 | 73 | # TF original: warped = tfa_image.dense_image_warp(image, flow) 74 | # PyTorch equivalent: a grid of pixel centers in [-1, 1] offset by the normalized flow, sampled with grid_sample 75 | ls1 = 1 - 1 / flow.shape[3] 76 | ls2 = 1 - 1 / flow.shape[2] 77 | 78 | normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor( 79 | [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None] 80 | normalized_flow2 = torch.stack([ 81 | torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1], 82 | torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0], 83 | ], dim=3) 84 | 85 | warped = F.grid_sample(image, normalized_flow2, 86 | mode='bilinear', padding_mode='border', align_corners=False) 87 | return warped.reshape(image.shape) 88 | 89 | 90 | def multiply_pyramid(pyramid: List[torch.Tensor], 91 | scalar: torch.Tensor) -> List[torch.Tensor]: 92 | """Multiplies all image batches in the pyramid by a batch of scalars. 93 | 94 | Args: 95 | pyramid: Pyramid of image batches. 96 | scalar: Batch of scalars, shaped (B, 1). 97 | 98 | Returns: 99 | An image pyramid with all images multiplied by the scalar. 100 | """ 101 | # Indexing the (B, 1) scalar with [..., None, None] gives shape (B, 1, 1, 1), 102 | # which broadcasts against each BxCxHxW image so that every image in the 103 | # batch is multiplied by its own scalar. Unlike the BxHxWxC TF original, 104 | # no transposes are needed.
105 | return [image * scalar[..., None, None] for image in pyramid] 106 | 107 | 108 | def flow_pyramid_synthesis( 109 | residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: 110 | """Converts a residual flow pyramid into a flow pyramid.""" 111 | flow = residual_pyramid[-1] 112 | flow_pyramid: List[torch.Tensor] = [flow] 113 | for residual_flow in residual_pyramid[:-1][::-1]: 114 | level_size = residual_flow.shape[2:4] 115 | flow = F.interpolate(2 * flow, size=level_size, mode='bilinear') 116 | flow = residual_flow + flow 117 | flow_pyramid.insert(0, flow) 118 | return flow_pyramid 119 | 120 | 121 | def pyramid_warp(feature_pyramid: List[torch.Tensor], 122 | flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: 123 | """Warps the feature pyramid using the flow pyramid. 124 | 125 | Args: 126 | feature_pyramid: feature pyramid starting from the finest level. 127 | flow_pyramid: flow fields, starting from the finest level. 128 | 129 | Returns: 130 | Reverse warped feature pyramid. 
131 | """ 132 | warped_feature_pyramid = [] 133 | for features, flow in zip(feature_pyramid, flow_pyramid): 134 | warped_feature_pyramid.append(warp(features, flow)) 135 | return warped_feature_pyramid 136 | 137 | 138 | def concatenate_pyramids(pyramid1: List[torch.Tensor], 139 | pyramid2: List[torch.Tensor]) -> List[torch.Tensor]: 140 | """Concatenates each pyramid level together in the channel dimension.""" 141 | result = [] 142 | for features1, features2 in zip(pyramid1, pyramid2): 143 | result.append(torch.cat([features1, features2], dim=1)) 144 | return result 145 | 146 | 147 | class Conv2d(nn.Sequential): 148 | def __init__(self, in_channels, out_channels, size, activation: Optional[str] = 'relu'): 149 | assert activation in (None, 'relu') 150 | super().__init__( 151 | nn.Conv2d( 152 | in_channels=in_channels, 153 | out_channels=out_channels, 154 | kernel_size=size, 155 | padding='same' if size % 2 else 0) 156 | ) 157 | self.size = size 158 | self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None 159 | 160 | def forward(self, x): 161 | if not self.size % 2: 162 | x = F.pad(x, (0, 1, 0, 1)) 163 | y = self[0](x) 164 | if self.activation is not None: 165 | y = self.activation(y) 166 | return y 167 | -------------------------------------------------------------------------------- /pyramid_flow_estimator.py: -------------------------------------------------------------------------------- 1 | """PyTorch layer for estimating optical flow by a residual flow pyramid. 2 | 3 | This approach of estimating optical flow between two images can be traced back 4 | to [1], but is also used by later neural optical flow computation methods such 5 | as SpyNet [2] and PWC-Net [3]. 6 | 7 | The basic idea is that the optical flow is first estimated in a coarse 8 | resolution, then the flow is upsampled to warp the higher resolution image and 9 | then a residual correction is computed and added to the estimated flow. 
This 10 | process is repeated in a pyramid in coarse-to-fine order to successively 11 | increase the resolution of both optical flow and the warped image. 12 | 13 | Here, the optical flow predictor is used as an internal component of the 14 | film_net frame interpolator, to warp the two input images into the in-between 15 | target frame. 16 | 17 | [1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987. 18 | [2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid 19 | Network, 2016. 20 | [3] D. Sun, X. Yang, M.-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using 21 | Pyramid, Warping, and Cost Volume, 2017. 22 | """ 23 | from typing import List 24 | 25 | import torch 26 | from torch import nn 27 | from torch.nn import functional as F 28 | 29 | import util 30 | 31 | 32 | class FlowEstimator(nn.Module): 33 | """Small-receptive-field predictor for computing the flow between two images. 34 | 35 | This is used to compute the residual flow fields in PyramidFlowEstimator. 36 | 37 | Note that while the number of 3x3 convolutions & filters to apply is 38 | configurable, two extra 1x1 convolutions are appended to extract the flow in 39 | the end. 40 | 41 | Args: 42 | in_channels: Number of input channels (both feature maps concatenated) 43 | num_convs: Number of 3x3 convolutions to apply 44 | num_filters: Number of filters in each 3x3 convolution 45 | """ 46 | 47 | def __init__(self, in_channels: int, num_convs: int, num_filters: int): 48 | super(FlowEstimator, self).__init__() 49 | 50 | self._convs = nn.ModuleList() 51 | for i in range(num_convs): 52 | self._convs.append(util.Conv2d(in_channels=in_channels, out_channels=num_filters, size=3)) 53 | in_channels = num_filters 54 | self._convs.append(util.Conv2d(in_channels, num_filters // 2, size=1)) 55 | in_channels = num_filters // 2 56 | # For the final convolution, we want no activation at all to predict the 57 | # optical flow vector values.
We have done extensive testing on explicitly 58 | # bounding these values using sigmoid, but it turned out that having no 59 | # activation gives better results. 60 | self._convs.append(util.Conv2d(in_channels, 2, size=1, activation=None)) 61 | 62 | def forward(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor: 63 | """Estimates optical flow between two images. 64 | 65 | Args: 66 | features_a: per-pixel feature vectors for image A (B x C x H x W) 67 | features_b: per-pixel feature vectors for image B (B x C x H x W) 68 | 69 | Returns: 70 | A (B x 2 x H x W) tensor with the optical flow from A to B 71 | """ 72 | net = torch.cat([features_a, features_b], dim=1) 73 | for conv in self._convs: 74 | net = conv(net) 75 | return net 76 | 77 | 78 | class PyramidFlowEstimator(nn.Module): 79 | """Predicts optical flow by coarse-to-fine refinement. 80 | """ 81 | 82 | def __init__(self, filters: int = 64, 83 | flow_convs: tuple = (3, 3, 3, 3), 84 | flow_filters: tuple = (32, 64, 128, 256)): 85 | super(PyramidFlowEstimator, self).__init__() 86 | 87 | in_channels = filters << 1 88 | predictors = [] 89 | for i in range(len(flow_convs)): 90 | predictors.append( 91 | FlowEstimator( 92 | in_channels=in_channels, 93 | num_convs=flow_convs[i], 94 | num_filters=flow_filters[i])) 95 | in_channels += filters << (i + 2) 96 | self._predictor = predictors[-1] 97 | self._predictors = nn.ModuleList(predictors[:-1][::-1]) 98 | 99 | def forward(self, feature_pyramid_a: List[torch.Tensor], 100 | feature_pyramid_b: List[torch.Tensor]) -> List[torch.Tensor]: 101 | """Estimates residual flow pyramids between two image pyramids. 102 | 103 | Each image pyramid is represented as a list of tensors in fine-to-coarse 104 | order. Each individual image is represented as a tensor where each pixel is 105 | a vector of image features.
106 | 107 | util.flow_pyramid_synthesis can be used to convert the residual flow 108 | pyramid returned by this method into a flow pyramid, where each level 109 | encodes the flow instead of a residual correction. 110 | 111 | Args: 112 | feature_pyramid_a: feature pyramid as a list in fine-to-coarse order 113 | feature_pyramid_b: feature pyramid as a list in fine-to-coarse order 114 | 115 | Returns: 116 | List of flow tensors, in fine-to-coarse order, each level encoding the 117 | difference against the bilinearly upsampled version from the coarser 118 | level. The coarsest flow tensor, i.e. the last element in the array, is the 119 | 'DC-term', i.e. not a residual (alternatively you can think of it as a 120 | residual against zero). 121 | """ 122 | levels = len(feature_pyramid_a) 123 | v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) 124 | residuals = [v] 125 | for i in range(levels - 2, len(self._predictors) - 1, -1): 126 | # These coarse levels share self._predictor. Upsample the flow to match the current 127 | # pyramid level, scaling its magnitude by two to reflect the new size. 128 | level_size = feature_pyramid_a[i].shape[2:4] 129 | v = F.interpolate(2 * v, size=level_size, mode='bilinear') 130 | # Warp feature_pyramid_b[i] image based on the current flow estimate. 131 | warped = util.warp(feature_pyramid_b[i], v) 132 | # Estimate the residual flow between pyramid_a[i] and warped image: 133 | v_residual = self._predictor(feature_pyramid_a[i], warped) 134 | residuals.insert(0, v_residual) 135 | v = v_residual + v 136 | 137 | for k, predictor in enumerate(self._predictors): 138 | i = len(self._predictors) - 1 - k 139 | # Upsamples the flow to match the current pyramid level. Also, scales the 140 | # magnitude by two to reflect the new size. 141 | level_size = feature_pyramid_a[i].shape[2:4] 142 | v = F.interpolate(2 * v, size=level_size, mode='bilinear') 143 | # Warp feature_pyramid_b[i] image based on the current flow estimate.
144 | warped = util.warp(feature_pyramid_b[i], v) 145 | # Estimate the residual flow between pyramid_a[i] and warped image: 146 | v_residual = predictor(feature_pyramid_a[i], warped) 147 | residuals.insert(0, v_residual) 148 | v = v_residual + v 149 | return residuals 150 | -------------------------------------------------------------------------------- /feature_extractor.py: -------------------------------------------------------------------------------- 1 | """PyTorch layer for extracting image features for the film_net interpolator. 2 | 3 | The feature extractor implemented here converts an image pyramid into a pyramid 4 | of deep features. The feature pyramid serves a similar purpose to the encoder 5 | of a U-Net, but we use a special cascaded architecture described in 6 | Multi-view Image Fusion [1]. 7 | 8 | For completeness, below is a short description of the idea. While the 9 | description is a bit involved, the cascaded feature pyramid can be used just 10 | like any image feature pyramid. 11 | 12 | Why cascaded architecture? 13 | ========================== 14 | To understand the concept it is worth reviewing a traditional feature pyramid 15 | first: *A traditional feature pyramid* as in U-Net or in many optical flow 16 | networks is built by alternating between convolutions and pooling, starting 17 | from the input image. 18 | 19 | It is well known that early features of such an architecture correspond to 20 | low-level concepts such as edges in the image, whereas later layers extract 21 | semantically higher-level concepts such as object classes. In other words, 22 | the meaning of the filters in each resolution level is different. For problems 23 | such as semantic segmentation and many others this is a desirable property. 24 | 25 | However, the asymmetric features preclude sharing weights across resolution 26 | levels in the feature extractor itself and in any subsequent neural networks 27 | that follow.
This can be a downside, since optical flow prediction, for 28 | instance, is symmetric across resolution levels. The cascaded feature 29 | architecture addresses this shortcoming. 30 | 31 | How is it built? 32 | ================ 33 | The *cascaded* feature pyramid contains feature vectors that have constant 34 | length and meaning on each resolution level, except for a few of the finest ones. 35 | The advantage of this is that the subsequent optical flow layer can learn 36 | synergistically from many resolutions. This means that coarse level prediction can 37 | benefit from finer resolution training examples, which can be useful with 38 | moderately sized datasets to avoid overfitting. 39 | 40 | The cascaded feature pyramid is built by extracting shallower subtree pyramids, 41 | each one of them similar to the traditional architecture. Each subtree 42 | pyramid S_i is extracted starting from resolution level i: 43 | 44 | image resolution 0 -> S_0 45 | image resolution 1 -> S_1 46 | image resolution 2 -> S_2 47 | ... 48 | 49 | If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid 50 | is constructed by concatenating features as follows (assuming subtree depth=3): 51 | 52 | lvl 53 | feat_0 = concat( S_0_0 ) 54 | feat_1 = concat( S_1_0 S_0_1 ) 55 | feat_2 = concat( S_2_0 S_1_1 S_0_2 ) 56 | feat_3 = concat( S_3_0 S_2_1 S_1_2 ) 57 | feat_4 = concat( S_4_0 S_3_1 S_2_2 ) 58 | feat_5 = concat( S_5_0 S_4_1 S_3_2 ) 59 | .... 60 | 61 | In the above, all levels except feat_0 and feat_1 have the same number of features 62 | with similar semantic meaning. This enables training a single optical flow 63 | predictor module shared by levels 2, 3, 4, 5, ... For more details and evaluation 64 | see [1]. 65 | 66 | [1] Multi-view Image Fusion, Trinidad et al.
2019 67 | """ 68 | from typing import List 69 | 70 | import torch 71 | from torch import nn 72 | from torch.nn import functional as F 73 | 74 | from util import Conv2d 75 | 76 | 77 | class SubTreeExtractor(nn.Module): 78 | """Extracts a hierarchical set of features from an image. 79 | 80 | This is a conventional, hierarchical image feature extractor that extracts 81 | [k, k*2, k*4, ...] filters for the image pyramid, where k=channels. 82 | Each level except the last requested one is followed by average pooling. 83 | """ 84 | 85 | def __init__(self, in_channels=3, channels=64, n_layers=4): 86 | super().__init__() 87 | convs = [] 88 | for i in range(n_layers): 89 | convs.append(nn.Sequential( 90 | Conv2d(in_channels, (channels << i), 3), 91 | Conv2d((channels << i), (channels << i), 3) 92 | )) 93 | in_channels = channels << i 94 | self.convs = nn.ModuleList(convs) 95 | 96 | def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]: 97 | """Extracts a pyramid of features from the image. 98 | 99 | Args: 100 | image: torch.Tensor with shape BATCH_SIZE x CHANNELS x HEIGHT x WIDTH. 101 | n: number of pyramid levels to extract. This can be less than or equal to 102 | n_layers given in __init__. 103 | Returns: 104 | The pyramid of features, starting from the finest level. Each element 105 | contains the output after the last convolution on the corresponding 106 | pyramid level (all n_layers blocks always run; callers use the first n). 107 | """ 108 | head = image 109 | pyramid = [] 110 | for i, layer in enumerate(self.convs): 111 | head = layer(head) 112 | pyramid.append(head) 113 | if i < n - 1: 114 | head = F.avg_pool2d(head, kernel_size=2, stride=2) 115 | return pyramid 116 | 117 | 118 | class FeatureExtractor(nn.Module): 119 | """Extracts features from an image pyramid using a cascaded architecture.
120 | """ 121 | 122 | def __init__(self, in_channels=3, channels=64, sub_levels=4): 123 | super().__init__() 124 | self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels) 125 | self.sub_levels = sub_levels 126 | 127 | def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: 128 | """Extracts a cascaded feature pyramid. 129 | 130 | Args: 131 | image_pyramid: Image pyramid as a list, starting from the finest level. 132 | Returns: 133 | A pyramid of cascaded features. 134 | """ 135 | sub_pyramids: List[List[torch.Tensor]] = [] 136 | for i in range(len(image_pyramid)): 137 | # At each level of the image pyramid, creates a sub_pyramid of features 138 | # with 'sub_levels' pyramid levels, reusing the same SubTreeExtractor. 139 | # We use the same instance since we want to share the weights. 140 | # 141 | # However, we cap the depth of the sub_pyramid so we don't downsample 142 | # beyond the coarsest level of the cascaded feature pyramid we 143 | # want to generate. 144 | capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels) 145 | sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels)) 146 | # Below we generate the cascades of features on each level of the feature 147 | # pyramid. Assuming sub_levels=3, the layout of the features will be 148 | # as shown in the example in the module docstring above. 149 | feature_pyramid: List[torch.Tensor] = [] 150 | for i in range(len(image_pyramid)): 151 | features = sub_pyramids[i][0] 152 | for j in range(1, self.sub_levels): 153 | if j <= i: 154 | features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) 155 | feature_pyramid.append(features) 156 | return feature_pyramid 157 | -------------------------------------------------------------------------------- /interpolator.py: -------------------------------------------------------------------------------- 1 | """The film_net frame interpolator main model code.
2 | 3 | Basics 4 | ====== 5 | The film_net is an end-to-end learned neural frame interpolator implemented as 6 | a PyTorch model. It has the following inputs and outputs: 7 | 8 | Inputs: 9 | x0: image A. 10 | x1: image B. 11 | batch_dt: desired sub-frame time, one value per batch element. 12 | 13 | Outputs: 14 | image: the predicted in-between image at the chosen time, where time lies in [0, 1]. 15 | 16 | debug_forward additionally returns the forward and backward residual flow 17 | pyramids and flow pyramids, which can be visualized for debugging and analysis. 18 | 19 | Note that many training sets only contain triplets with ground truth at 20 | time=0.5. If a model has been trained with such a training set, it will only 21 | work well for synthesizing frames at time=0.5. Such models can only generate 22 | more in-between frames using recursion. 23 | 24 | Architecture 25 | ============ 26 | Inference consists of three main stages: 1) feature extraction, 2) warping and 27 | 3) fusion. At a high level, the architecture has similarities to Context-aware 28 | Synthesis for Video Frame Interpolation [1], but the exact architecture is 29 | closer to Multi-view Image Fusion [2] with some modifications for the frame 30 | interpolation use case. 31 | 32 | The feature extraction stage employs the cascaded multi-scale architecture 33 | described in [2]. The advantage of this architecture is that coarse-level flow 34 | prediction can be learned from finer-resolution image samples. This is 35 | especially useful to avoid overfitting with moderately sized datasets. 36 | 37 | The warping stage uses a residual flow prediction idea that is similar to 38 | PWC-Net [3], Multi-view Image Fusion [2] and many others. 39 | 40 | The fusion stage is similar to a U-Net decoder whose skip connections are 41 | connected to the warped image and feature pyramids. This is described in [2].
42 | 43 | Implementation Conventions 44 | ========================== 45 | Pyramids 46 | -------- 47 | Throughout the model, all image and feature pyramids are stored as Python lists 48 | with the finest level first, followed by downscaled versions obtained by 49 | successively halving the resolution. The depths of the image and feature pyramids 50 | are determined by the pyramid_levels argument. The only exception is internal to the feature 51 | extractor, where smaller feature pyramids are temporarily constructed with depth 52 | sub_levels. 53 | 54 | Color ranges & gamma 55 | -------------------- 56 | The model code makes no assumptions about whether the images are in gamma or 57 | linearized space, or about the range of RGB color values, so a model can be 58 | trained with different choices. This does not mean that all the choices lead to 59 | similar results. In practice, the model has been found to work well with RGB 60 | values in [0, 1] and gamma-space images (i.e. not linearized). 61 | 62 | [1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018 63 | [2] Multi-view Image Fusion, Trinidad et al., 2019 64 | [3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume, Sun et al., 2018 65 | """ 66 | from typing import Dict, List 67 | 68 | import torch 69 | from torch import nn 70 | 71 | import util 72 | from feature_extractor import FeatureExtractor 73 | from fusion import Fusion 74 | from pyramid_flow_estimator import PyramidFlowEstimator 75 | 76 | 77 | class Interpolator(nn.Module): 78 | def __init__( 79 | self, 80 | pyramid_levels=7, 81 | fusion_pyramid_levels=5, 82 | specialized_levels=3, 83 | sub_levels=4, 84 | filters=64, 85 | flow_convs=(3, 3, 3, 3), 86 | flow_filters=(32, 64, 128, 256), 87 | ): 88 | super().__init__() 89 | self.pyramid_levels = pyramid_levels 90 | self.fusion_pyramid_levels = fusion_pyramid_levels 91 | 92 | self.extract = FeatureExtractor(3, filters, sub_levels) 93 | self.predict_flow = PyramidFlowEstimator(filters, flow_convs,
flow_filters) 94 | self.fuse = Fusion(sub_levels, specialized_levels, filters) 95 | 96 | def shuffle_images(self, x0, x1): 97 | return [ 98 | util.build_image_pyramid(x0, self.pyramid_levels), 99 | util.build_image_pyramid(x1, self.pyramid_levels) 100 | ] 101 | 102 | def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]: 103 | image_pyramids = self.shuffle_images(x0, x1) 104 | 105 | # Siamese feature pyramids: 106 | feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])] 107 | 108 | # Predict forward flow. 109 | forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1]) 110 | 111 | # Predict backward flow. 112 | backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0]) 113 | 114 | # Convert the residual flow pyramids into flow pyramids. 115 | 116 | # Note that we keep up to 'fusion_pyramid_levels' levels as only those 117 | # are used by the fusion module. 118 | 119 | forward_flow_pyramid = util.flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels] 120 | 121 | backward_flow_pyramid = util.flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels] 122 | 123 | # We multiply the flows with t and 1-t to warp to the desired fractional time. 124 | # 125 | # Note: film_net is trained with time fixed at 0.5 and is invoked recursively 126 | # for multi-frame interpolation. Here the per-sample time t is passed in as 127 | # batch_dt, so no constant time tensor is created.
128 | backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt) 129 | forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt) 130 | 131 | pyramids_to_warp = [ 132 | util.concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels], 133 | feature_pyramids[0][:self.fusion_pyramid_levels]), 134 | util.concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels], 135 | feature_pyramids[1][:self.fusion_pyramid_levels]) 136 | ] 137 | 138 | # Warp features and images using the flow. Note that we use backward warping 139 | # and backward flow is used to read from image 0 and forward flow from 140 | # image 1. 141 | forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow) 142 | backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow) 143 | 144 | aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid, 145 | backward_warped_pyramid) 146 | aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow) 147 | aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow) 148 | 149 | return { 150 | 'image': [self.fuse(aligned_pyramid)], 151 | 'forward_residual_flow_pyramid': forward_residual_flow_pyramid, 152 | 'backward_residual_flow_pyramid': backward_residual_flow_pyramid, 153 | 'forward_flow_pyramid': forward_flow_pyramid, 154 | 'backward_flow_pyramid': backward_flow_pyramid, 155 | } 156 | 157 | def forward(self, x0, x1, batch_dt) -> torch.Tensor: 158 | return self.debug_forward(x0, x1, batch_dt)['image'][0] 159 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
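The channel bookkeeping described in the feature_extractor.py docstring, and the `in_channels` arithmetic in `PyramidFlowEstimator.__init__`, can be cross-checked in plain Python. This is a minimal sketch assuming the `Interpolator` defaults (`filters=64`, `sub_levels=4`, `pyramid_levels=7`, four entries in `flow_convs`); the helper names `cascaded_channels` and `flow_estimator_inputs` are hypothetical and not part of the repository.

```python
from typing import List


def cascaded_channels(num_levels: int, channels: int = 64, sub_levels: int = 4) -> List[int]:
    """Channel count of each level of the cascaded feature pyramid (hypothetical helper).

    Level i concatenates S_{i-j}_j for j = 0 .. min(i, sub_levels - 1), and
    level j of every subtree has (channels << j) feature maps.
    """
    return [sum(channels << j for j in range(min(i, sub_levels - 1) + 1))
            for i in range(num_levels)]


def flow_estimator_inputs(filters: int = 64, num_estimators: int = 4) -> List[int]:
    """Replays the in_channels arithmetic of PyramidFlowEstimator.__init__."""
    in_channels = filters << 1
    sizes = []
    for i in range(num_estimators):
        sizes.append(in_channels)
        in_channels += filters << (i + 2)
    return sizes


feats = cascaded_channels(7)
print(feats)  # [64, 192, 448, 960, 960, 960, 960]
# Each flow estimator sees features of image A concatenated with the warped
# features of image B, i.e. twice the per-level channel count.
print([2 * c for c in feats[:4]] == flow_estimator_inputs())  # True
```

With `sub_levels=4`, every level from 3 down to the coarsest has 960 channels, which is why `PyramidFlowEstimator` can reuse one shared predictor (`self._predictor`) for all of them and keeps specialized predictors only for the three finest levels. With the docstring's depth of 3, the sketch gives `[64, 192, 448, 448, ...]`, matching the statement that all levels from feat_2 on have the same size.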
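`PyramidFlowEstimator.forward` returns residual flows, and its docstring points to `util.flow_pyramid_synthesis` for turning them into absolute flows. `util.py` is not part of this excerpt, so the following is only a sketch of that accumulation, assuming it mirrors the upsample-by-two-and-add loop in `forward`; `flow_pyramid_synthesis_sketch` is a hypothetical name.

```python
from typing import List

import torch
from torch.nn import functional as F


def flow_pyramid_synthesis_sketch(residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Accumulates a fine-to-coarse residual flow pyramid into absolute flows.

    Each tensor is (B, 2, H, W); the last (coarsest) entry is the 'DC term'.
    """
    flow = residual_pyramid[-1]  # coarsest level: a residual against zero
    flow_pyramid = [flow]
    for residual in reversed(residual_pyramid[:-1]):
        level_size = residual.shape[2:4]
        # Upsample to the finer level and double the magnitudes, as in
        # PyramidFlowEstimator.forward, then add this level's correction.
        flow = F.interpolate(2 * flow, size=level_size, mode='bilinear',
                             align_corners=False)
        flow = residual + flow
        flow_pyramid.insert(0, flow)  # keep fine-to-coarse order
    return flow_pyramid


# Toy residual pyramid: 8x8 (finest), 4x4, 2x2 (coarsest), constant per level.
res = [torch.full((1, 2, 8, 8), 0.5), torch.zeros(1, 2, 4, 4), torch.ones(1, 2, 2, 2)]
flows = flow_pyramid_synthesis_sketch(res)
print([f.shape[-1] for f in flows])  # [8, 4, 2]
print([f[0, 0, 0, 0].item() for f in flows])  # [4.5, 2.0, 1.0]
```

The toy residuals are constant per level, so bilinear upsampling keeps them constant and the doubling is easy to follow: 1 at the coarsest level becomes 2, then 2 * 2 + 0.5 = 4.5 at the finest.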
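`debug_forward` scales the backward flow by `batch_dt` and the forward flow by `1 - batch_dt`, then backward-warps: each output pixel reads from the source image at its own position plus the flow. `util.warp` is outside this excerpt, so below is a sketch of bilinear backward warping built on `torch.nn.functional.grid_sample`. It assumes flow channel 0 holds the x displacement and channel 1 the y displacement in pixels, and it clamps out-of-range reads to the border; both conventions are assumptions, and `backward_warp_sketch` is a hypothetical name.

```python
import torch
from torch.nn import functional as F


def backward_warp_sketch(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Bilinear backward warp: output[b, :, y, x] ~ image[b, :, y + fy, x + fx].

    image: (B, C, H, W); flow: (B, 2, H, W), channel 0 = x, channel 1 = y
    (assumed convention), displacements in pixels. Requires H, W > 1.
    """
    _, _, h, w = image.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=image.dtype, device=image.device),
        torch.arange(w, dtype=image.dtype, device=image.device),
        indexing='ij')
    x = xs + flow[:, 0]  # (B, H, W) absolute sampling positions
    y = ys + flow[:, 1]
    # grid_sample wants (x, y) normalized to [-1, 1]; with align_corners=True,
    # -1 and 1 are the centers of the first and last pixels.
    grid = torch.stack([2 * x / (w - 1) - 1, 2 * y / (h - 1) - 1], dim=-1)
    return F.grid_sample(image, grid, mode='bilinear',
                         padding_mode='border', align_corners=True)


img = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
shift = torch.zeros(1, 2, 4, 4)
shift[:, 0] = 1.0  # every pixel reads one pixel to the right
out = backward_warp_sketch(img, shift)

# Time scaling as in debug_forward: one dt per batch element, broadcast over
# the flow channels and pixels, here reading half a pixel to the right.
dt = torch.full((1, 1), 0.5)
half = backward_warp_sketch(img, dt[..., None, None] * shift)
```

Scaling the flow before warping, rather than blending warped images afterwards, is what lets a single flow estimate serve any intermediate time.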
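The interpolator docstring notes that a model trained only on time=0.5 triplets generates additional in-between frames by recursion. A minimal sketch of that scheme follows; `recursive_midpoints` is a hypothetical helper, and `interpolate` stands in for a model call at `dt = 0.5` (for example, a thin wrapper around `model(img1, img3, dt)` from the README Quickstart). To keep it runnable without weights, the demo uses plain numbers as "frames" and averaging as the stand-in model.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def recursive_midpoints(a: T, b: T, depth: int,
                        interpolate: Callable[[T, T], T]) -> List[T]:
    """Returns the 2**depth - 1 in-between frames between a and b, in time order.

    `interpolate(x, y)` is a stand-in for one model call at dt = 0.5.
    """
    if depth == 0:
        return []
    mid = interpolate(a, b)
    return (recursive_midpoints(a, mid, depth - 1, interpolate)
            + [mid]
            + recursive_midpoints(mid, b, depth - 1, interpolate))


# Toy stand-in for the model: "frames" are timestamps, midpoint = average.
frames = [0.0] + recursive_midpoints(0.0, 1.0, 2, lambda x, y: (x + y) / 2) + [1.0]
print(frames)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Depth d yields 2**d - 1 in-between frames and makes exactly that many calls to `interpolate`, always at the midpoint time the model was trained on.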