├── .gitignore
├── requirements.txt
├── photos
│   ├── one.png
│   ├── two.png
│   └── output.gif
├── README.md
├── inference.py
├── fusion.py
├── export.py
├── util.py
├── pyramid_flow_estimator.py
├── feature_extractor.py
├── interpolator.py
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | /photos/output.mp4
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | torch
3 | tqdm
--------------------------------------------------------------------------------
/photos/one.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dajes/frame-interpolation-pytorch/HEAD/photos/one.png
--------------------------------------------------------------------------------
/photos/two.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dajes/frame-interpolation-pytorch/HEAD/photos/two.png
--------------------------------------------------------------------------------
/photos/output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dajes/frame-interpolation-pytorch/HEAD/photos/output.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Frame interpolation in PyTorch
3 |
4 | This is an unofficial PyTorch inference implementation
5 | of [FILM: Frame Interpolation for Large Motion, In ECCV 2022](https://film-net.github.io/).\
6 | [Original repository link](https://github.com/google-research/frame-interpolation)
7 |
8 | The project is focused on providing a simple, TorchScript-compilable inference interface for the original pretrained
9 | TF2 model.
10 |
11 | # Quickstart
12 |
13 | Download a compiled model from [the release](https://github.com/dajes/frame-interpolation-pytorch/releases)
14 | and specify the path to the file in the following snippet:
15 |
16 | ```python
17 | import torch
18 |
19 | device = torch.device('cuda')
20 | precision = torch.float16
21 |
22 | model = torch.jit.load(model_path, map_location='cpu')  # model_path: path to the downloaded checkpoint
23 | model.eval().to(device=device, dtype=precision)
24 |
25 | img1 = torch.rand(1, 3, 720, 1080).to(precision).to(device)
26 | img3 = torch.rand(1, 3, 720, 1080).to(precision).to(device)
27 | dt = img1.new_full((1, 1), .5)
28 |
29 | with torch.no_grad():
30 | img2 = model(img1, img3, dt) # Will be of the same shape as inputs (1, 3, 720, 1080)
31 |
32 | ```
33 |
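To run the model on real photographs instead of random tensors, the repository's ```util.load_image``` helper can be used; it pads each image to a multiple of 64 pixels and returns the crop region needed to undo the padding afterwards. A minimal sketch, mirroring what ```inference.py``` does:

```python
import torch

from util import load_image

# Returns an NHWC float32 batch in [0, 1] plus the crop region of the original image.
img_batch_1, crop_region = load_image('photos/one.png')
img_batch_2, _ = load_image('photos/two.png')

# The model expects NCHW tensors.
img1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
img3 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)
```

After interpolation, slice the prediction with ```crop_region``` (as done in ```inference.py```) to remove the padding.
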
34 | # Exporting model by yourself
35 |
36 | You will need to install TensorFlow of the version specified in
37 | the [original repo](https://github.com/google-research/frame-interpolation#installation) and download the SavedModel of the
38 | "Style" network from [there](https://github.com/google-research/frame-interpolation#pre-trained-models).
39 |
40 | After you have downloaded the SavedModel and can load it via ```tf.compat.v2.saved_model.load(path)```:
41 |
42 | * Clone the repository
43 |
44 | ```
45 | git clone https://github.com/dajes/frame-interpolation-pytorch
46 | cd frame-interpolation-pytorch
47 | ```
48 |
49 | * Install dependencies
50 |
51 | ```
52 | python -m pip install -r requirements.txt
53 | ```
54 |
55 | * Run ```export.py```:
56 |
57 | ```
58 | python export.py "model_path" "save_path" [--statedict] [--fp32] [--skiptest] [--gpu]
59 | ```
60 |
61 | Argument list:
62 |
63 | * ```model_path``` Path to the TF SavedModel
64 | * ```save_path``` Path to save the exported model (a TorchScript module by default, or a state dict with ```--statedict```)
65 | * ```--statedict``` Export to state dict instead of TorchScript
66 | * ```--fp32``` Save weights at full precision
67 | * ```--skiptest``` Skip testing and save model immediately instead
68 | * ```--gpu``` Whether to attempt to use GPU for testing
69 |
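For example, assuming the "Style" SavedModel was unpacked to ```film_net/Style/saved_model``` (your path may differ) and a half-precision TorchScript module is wanted:

```
python export.py film_net/Style/saved_model film_net_fp16.pt --gpu
```

Passing ```--statedict``` instead saves a plain state dict, which can be loaded into the ```Interpolator``` class from ```interpolator.py``` via ```load_state_dict```.
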
70 | # Testing exported model
71 | The following script creates an MP4 video of interpolated frames between 2 input images:
72 | ```
73 | python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--gpu] [--fp16] [--frames FRAMES] [--fps FPS]
74 | ```
75 | * ```model_path``` Path to the exported TorchScript checkpoint
76 | * ```img1``` Path to the first image
77 | * ```img2``` Path to the second image
78 | * ```--save_path SAVE_PATH``` Path to save the interpolated frames as a video; if omitted, the video is saved as ```output.mp4``` in the same directory as ```img1```
79 | * ```--gpu``` Whether to attempt to use GPU for predictions
80 | * ```--fp16``` Whether to use FP16 for calculations; speeds up inference on GPUs with tensor cores
81 | * ```--frames FRAMES``` Number of frames to interpolate between the input images
82 | * ```--fps FPS``` FPS of the output video
83 |
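For example, using the two photos bundled in this repository and an exported checkpoint named ```film_net_fp16.pt``` (any TorchScript export works):

```
python inference.py film_net_fp16.pt photos/one.png photos/two.png --gpu --fp16 --frames 30 --fps 30
```

Since ```--save_path``` is omitted here, the result is written to ```photos/output.mp4```.
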
84 | ### Results on the 2 example photos from the original repository:
85 |
86 |
87 | ![Interpolation result](photos/output.gif)
88 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import os
3 | from tqdm import tqdm
4 | import torch
5 | import numpy as np
6 | import cv2
7 |
8 | from util import load_image
9 |
10 |
11 | def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
12 | model = torch.jit.load(model_path, map_location='cpu')
13 | model.eval()
14 | img_batch_1, crop_region_1 = load_image(img1)
15 | img_batch_2, crop_region_2 = load_image(img2)
16 |
17 | img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
18 | img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)
19 |
20 | if not half:
21 | model.float()
22 |
23 | if gpu and torch.cuda.is_available():
24 | if half:
25 | model = model.half()
26 | else:
27 | model.float()
28 | model = model.cuda()
29 |
30 | if save_path == 'img1 folder':
31 | save_path = os.path.join(os.path.split(img1)[0], 'output.mp4')
32 |
33 | results = [
34 | img_batch_1,
35 | img_batch_2
36 | ]
37 |
38 | idxes = [0, inter_frames + 1]
39 | remains = list(range(1, inter_frames + 1))
40 |
41 | splits = torch.linspace(0, 1, inter_frames + 2)
42 |
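    # Frame generation strategy: 'splits' holds the normalized timestamps of all output
    # frames. Each iteration picks the still-missing timestamp whose relative position
    # between its two already-generated neighbours is closest to 0.5, queries the model
    # at that (near-midpoint) dt, and inserts the prediction back into 'results'.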
43 | for _ in tqdm(range(len(remains)), 'Generating in-between frames'):
44 | starts = splits[idxes[:-1]]
45 | ends = splits[idxes[1:]]
46 | distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
47 | matrix = torch.argmin(distances).item()
48 | start_i, step = np.unravel_index(matrix, distances.shape)
49 | end_i = start_i + 1
50 |
51 | x0 = results[start_i]
52 | x1 = results[end_i]
53 |
54 | if gpu and torch.cuda.is_available():
55 | if half:
56 | x0 = x0.half()
57 | x1 = x1.half()
58 | x0 = x0.cuda()
59 | x1 = x1.cuda()
60 |
61 | dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
62 |
63 | with torch.no_grad():
64 | prediction = model(x0, x1, dt)
65 | insert_position = bisect.bisect_left(idxes, remains[step])
66 | idxes.insert(insert_position, remains[step])
67 | results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
68 | del remains[step]
69 |
70 | video_folder = os.path.split(save_path)[0]
71 | os.makedirs(video_folder, exist_ok=True)
72 |
73 | y1, x1, y2, x2 = crop_region_1
74 | frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy()[y1:y2, x1:x2].copy() for tensor in results]
75 |
76 | w, h = frames[0].shape[1::-1]
77 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
78 | writer = cv2.VideoWriter(save_path, fourcc, fps, (w, h))
79 | for frame in frames:
80 | writer.write(frame)
81 |
82 | for frame in frames[1:][::-1]:
83 | writer.write(frame)
84 |
85 | writer.release()
86 |
87 |
88 | if __name__ == '__main__':
89 | import argparse
90 |
91 | parser = argparse.ArgumentParser(description='Test frame interpolator model')
92 |
93 | parser.add_argument('model_path', type=str, help='Path to the TorchScript model')
94 | parser.add_argument('img1', type=str, help='Path to the first image')
95 | parser.add_argument('img2', type=str, help='Path to the second image')
96 |
97 | parser.add_argument('--save_path', type=str, default='img1 folder', help='Path to save the interpolated frames')
98 | parser.add_argument('--gpu', action='store_true', help='Use GPU')
99 | parser.add_argument('--fp16', action='store_true', help='Use FP16')
100 | parser.add_argument('--frames', type=int, default=18, help='Number of frames to interpolate')
101 | parser.add_argument('--fps', type=int, default=10, help='FPS of the output video')
102 |
103 | args = parser.parse_args()
104 |
105 | inference(args.model_path, args.img1, args.img2, args.save_path, args.gpu, args.frames, args.fps, args.fp16)
106 |
--------------------------------------------------------------------------------
/fusion.py:
--------------------------------------------------------------------------------
1 | """The final fusion stage for the film_net frame interpolator.
2 |
3 | The inputs to this module are the warped input images, image features and
4 | flow fields, all aligned to the target frame (often midway point between the
5 | two original inputs). The output is the final image. FILM has no explicit
6 | occlusion handling -- instead using the abovementioned information this module
7 | automatically decides how to best blend the inputs together to produce content
8 | in areas where the pixels can only be borrowed from one of the inputs.
9 |
10 | Similarly, this module also decides on how much to blend in each input in case
11 | of fractional timestep that is not at the halfway point. For example, if the two
12 | inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
13 | it often makes most sense to favor the first input. However, this is not
14 | always the case -- in particular in occluded pixels.
15 |
16 | The architecture of the Fusion module follows U-net [1] architecture's decoder
17 | side, e.g. each pyramid level consists of concatenation with upsampled coarser
18 | level output, and two 3x3 convolutions.
19 |
20 | The upsampling is implemented as 'resize convolution', e.g. nearest neighbor
21 | upsampling followed by 2x2 convolution as explained in [2]. The classic U-net
22 | uses max-pooling which has a tendency to create checkerboard artifacts.
23 |
24 | [1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
25 | Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
26 | [2] https://distill.pub/2016/deconv-checkerboard/
27 | """
28 | from typing import List
29 |
30 | import torch
31 | from torch import nn
32 | from torch.nn import functional as F
33 |
34 | from util import Conv2d
35 |
36 | _NUMBER_OF_COLOR_CHANNELS = 3
37 |
38 |
39 | def get_channels_at_level(level, filters):
40 | n_images = 2
41 | channels = _NUMBER_OF_COLOR_CHANNELS
42 | flows = 2
43 |
44 | return (sum(filters << i for i in range(level)) + channels + flows) * n_images
45 |
46 |
47 | class Fusion(nn.Module):
48 | """The decoder."""
49 |
50 | def __init__(self, n_layers=4, specialized_layers=3, filters=64):
51 | """
52 | Args:
53 |           specialized_layers: number of specialized (finest) pyramid levels
54 | """
55 | super().__init__()
56 |
57 | # The final convolution that outputs RGB:
58 | self.output_conv = nn.Conv2d(filters, 3, kernel_size=1)
59 |
60 | # Each item 'convs[i]' will contain the list of convolutions to be applied
61 | # for pyramid level 'i'.
62 | self.convs = nn.ModuleList()
63 |
64 | # Create the convolutions. Roughly following the feature extractor, we
65 | # double the number of filters when the resolution halves, but only up to
66 | # the specialized_levels, after which we use the same number of filters on
67 | # all levels.
68 | #
69 | # We create the convs in fine-to-coarse order, so that the array index
70 | # for the convs will correspond to our normal indexing (0=finest level).
71 | # in_channels: tuple = (128, 202, 256, 522, 512, 1162, 1930, 2442)
72 |
73 | in_channels = get_channels_at_level(n_layers, filters)
74 | increase = 0
75 | for i in range(n_layers)[::-1]:
76 | num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
77 | convs = nn.ModuleList([
78 | Conv2d(in_channels, num_filters, size=2, activation=None),
79 | Conv2d(in_channels + (increase or num_filters), num_filters, size=3),
80 | Conv2d(num_filters, num_filters, size=3)]
81 | )
82 | self.convs.append(convs)
83 | in_channels = num_filters
84 | increase = get_channels_at_level(i, filters) - num_filters // 2
85 |
86 | def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
87 | """Runs the fusion module.
88 |
89 | Args:
90 |       pyramid: The input feature pyramid as a list of tensors. Each tensor is
91 |         in (B x C x H x W) format, with the finest level tensor first.
92 |
93 |     Returns:
94 |       A batch of RGB images.
98 | """
99 |
100 | # As a slight difference to a conventional decoder (e.g. U-net), we don't
101 | # apply any extra convolutions to the coarsest level, but just pass it
102 | # to finer levels for concatenation. This choice has not been thoroughly
103 | # evaluated, but is motivated by the educated guess that the fusion part
104 | # probably does not need large spatial context, because at this point the
105 | # features are spatially aligned by the preceding warp.
106 | net = pyramid[-1]
107 |
108 | # Loop starting from the 2nd coarsest level:
109 | # for i in reversed(range(0, len(pyramid) - 1)):
110 | for k, layers in enumerate(self.convs):
111 | i = len(self.convs) - 1 - k
112 | # Resize the tensor from coarser level to match for concatenation.
113 | level_size = pyramid[i].shape[2:4]
114 | net = F.interpolate(net, size=level_size, mode='nearest')
115 | net = layers[0](net)
116 | net = torch.cat([pyramid[i], net], dim=1)
117 | net = layers[1](net)
118 | net = layers[2](net)
119 | net = self.output_conv(net)
120 | return net
121 |
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | import torch
6 |
7 | from interpolator import Interpolator
8 |
9 |
10 | def translate_state_dict(var_dict, state_dict):
11 | for name, (prev_name, weight) in zip(state_dict, var_dict.items()):
12 | print('Mapping', prev_name, '->', name)
13 | weight = torch.from_numpy(weight)
14 | if 'kernel' in prev_name:
15 | # Transpose the conv2d kernel weights, since TF uses (H, W, C, K) and PyTorch uses (K, C, H, W)
16 | weight = weight.permute(3, 2, 0, 1)
17 |
18 | assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}'
19 |
20 | state_dict[name] = weight
21 |
22 |
23 | def import_state_dict(interpolator: Interpolator, saved_model):
24 | variables = saved_model.keras_api.variables
25 |
26 | extract_dict = interpolator.extract.state_dict()
27 | flow_dict = interpolator.predict_flow.state_dict()
28 | fuse_dict = interpolator.fuse.state_dict()
29 |
30 | extract_vars = {}
31 | _flow_vars = {}
32 | _fuse_vars = {}
33 |
34 | for var in variables:
35 | name = var.name
36 | if name.startswith('feat_net'):
37 | extract_vars[name[9:]] = var.numpy()
38 | elif name.startswith('predict_flow'):
39 | _flow_vars[name[13:]] = var.numpy()
40 | elif name.startswith('fusion'):
41 | _fuse_vars[name[7:]] = var.numpy()
42 |
43 | # reverse order of modules to allow jit export
44 | # TODO: improve this hack
45 | flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True))
46 | fuse_vars = dict(sorted(_fuse_vars.items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True))
47 |
48 | assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}'
49 | assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}'
50 | assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}'
51 |
52 | for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)):
53 | translate_state_dict(var_dict, state_dict)
54 |
55 | interpolator.extract.load_state_dict(extract_dict)
56 | interpolator.predict_flow.load_state_dict(flow_dict)
57 | interpolator.fuse.load_state_dict(fuse_dict)
58 |
59 |
60 | def verify_debug_outputs(pt_outputs, tf_outputs):
61 | max_error = 0
62 | for name, predicted in pt_outputs.items():
63 | if name == 'image':
64 | continue
65 | pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted]
66 | true_frfp = [f.numpy() for f in tf_outputs[name]]
67 |
68 | for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)):
69 | assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
70 | error = np.max(np.abs(pred - true))
71 | max_error = max(max_error, error)
72 | assert error < 1, f'{name} {i} max error: {error}'
73 | print('Max intermediate error:', max_error)
74 |
75 |
76 | def test_model(interpolator, model, half=False, gpu=False):
77 | torch.manual_seed(0)
78 | time = torch.full((1, 1), .5)
79 | x0 = torch.rand(1, 3, 256, 256)
80 | x1 = torch.rand(1, 3, 256, 256)
81 |
82 | x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
83 | x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
84 | time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32)
85 | tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False)
86 |
87 | if half:
88 | x0 = x0.half()
89 | x1 = x1.half()
90 | time = time.half()
91 |
92 | if gpu and torch.cuda.is_available():
93 | x0 = x0.cuda()
94 | x1 = x1.cuda()
95 | time = time.cuda()
96 |
97 | with torch.no_grad():
98 | pt_outputs = interpolator.debug_forward(x0, x1, time)
99 |
100 | verify_debug_outputs(pt_outputs, tf_outputs)
101 |
102 | with torch.no_grad():
103 | prediction = interpolator(x0, x1, time)
104 | output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
105 | true_color = tf_outputs['image'].numpy()
106 | error = np.abs(output_color - true_color).max()
107 |
108 | print('Color max error:', error)
109 |
110 |
111 | def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
112 | print(f'Exporting model to FP{["32", "16"][fp16]} {["state_dict", "torchscript"][export_to_torchscript]} '
113 | f'using {"CG"[use_gpu]}PU')
114 | model = tf.compat.v2.saved_model.load(model_path)
115 | interpolator = Interpolator()
116 | interpolator.eval()
117 | import_state_dict(interpolator, model)
118 |
119 | if use_gpu and torch.cuda.is_available():
120 | interpolator = interpolator.cuda()
121 | else:
122 | use_gpu = False
123 |
124 | if fp16:
125 | interpolator = interpolator.half()
126 | if export_to_torchscript:
127 | interpolator = torch.jit.script(interpolator)
128 | if export_to_torchscript:
129 | interpolator.save(save_path)
130 | else:
131 | torch.save(interpolator.state_dict(), save_path)
132 |
133 | if not skiptest:
134 | if not use_gpu and fp16:
135 | warnings.warn('Testing FP16 model on CPU is impossible, casting it back')
136 | interpolator = interpolator.float()
137 | fp16 = False
138 | test_model(interpolator, model, fp16, use_gpu)
139 |
140 |
141 | if __name__ == '__main__':
142 | import argparse
143 |
144 | parser = argparse.ArgumentParser(description='Export frame-interpolator model to PyTorch state dict')
145 |
146 | parser.add_argument('model_path', type=str, help='Path to the TF SavedModel')
147 | parser.add_argument('save_path', type=str, help='Path to save the PyTorch state dict')
148 | parser.add_argument('--statedict', action='store_true', help='Export to state dict instead of TorchScript')
149 | parser.add_argument('--fp32', action='store_true', help='Save at full precision')
150 | parser.add_argument('--skiptest', action='store_true', help='Skip testing and save model immediately instead')
151 | parser.add_argument('--gpu', action='store_true', help='Use GPU')
152 |
153 | args = parser.parse_args()
154 |
155 | main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest)
156 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | """Various utilities used in the film_net frame interpolator model."""
2 | from typing import List, Optional
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 |
10 |
11 | def pad_batch(batch, align):
12 | height, width = batch.shape[1:3]
13 | height_to_pad = (align - height % align) if height % align != 0 else 0
14 | width_to_pad = (align - width % align) if width % align != 0 else 0
15 |
16 | crop_region = [height_to_pad >> 1, width_to_pad >> 1, height + (height_to_pad >> 1), width + (width_to_pad >> 1)]
17 | batch = np.pad(batch, ((0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
18 | (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0)), mode='constant')
19 | return batch, crop_region
20 |
21 |
22 | def load_image(path, align=64):
23 | image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
24 | image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
25 | return image_batch, crop_region
26 |
27 |
28 | def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
29 | """Builds an image pyramid from a given image.
30 |
31 | The original image is included in the pyramid and the rest are generated by
32 | successively halving the resolution.
33 |
34 | Args:
35 | image: the input image.
36 |     pyramid_levels: the number of pyramid levels to build.
37 |
38 |   Returns:
39 |     A list of images starting from the finest, with pyramid_levels items.
40 | """
41 |
42 | pyramid = []
43 | for i in range(pyramid_levels):
44 | pyramid.append(image)
45 | if i < pyramid_levels - 1:
46 | image = F.avg_pool2d(image, 2, 2)
47 | return pyramid
48 |
49 |
50 | def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
51 | """Backward warps the image using the given flow.
52 |
53 | Specifically, the output pixel in batch b, at position x, y will be computed
54 | as follows:
55 |       (flowed_y, flowed_x) = (y+flow[b, 1, y, x], x+flow[b, 0, y, x])
56 |       output[b, :, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)
57 |
58 |     Note that the flow vectors are expected as [x, y], i.e. x in channel 0 and
59 |     y in channel 1.
60 |
61 |     Args:
62 |       image: An image with shape B x C x H x W.
63 |       flow: A flow with shape B x 2 x H x W, with the two channels denoting the
64 |         relative offset in order: (dx, dy).
65 | Returns:
66 | A warped image.
67 | """
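    # Implementation note: F.grid_sample expects a grid of absolute sampling positions
    # in normalized [-1, 1] (x, y) coordinates, so the pixel-space flow is reordered,
    # rescaled by half the spatial size and combined with a base coordinate grid below.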
68 | flow = -flow.flip(1)
69 |
70 | dtype = flow.dtype
71 | device = flow.device
72 |
73 | # warped = tfa_image.dense_image_warp(image, flow)
74 | # Same as above but with pytorch
75 | ls1 = 1 - 1 / flow.shape[3]
76 | ls2 = 1 - 1 / flow.shape[2]
77 |
78 | normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor(
79 | [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None]
80 | normalized_flow2 = torch.stack([
81 | torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1],
82 | torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0],
83 | ], dim=3)
84 |
85 | warped = F.grid_sample(image, normalized_flow2,
86 | mode='bilinear', padding_mode='border', align_corners=False)
87 | return warped.reshape(image.shape)
88 |
89 |
90 | def multiply_pyramid(pyramid: List[torch.Tensor],
91 | scalar: torch.Tensor) -> List[torch.Tensor]:
92 | """Multiplies all image batches in the pyramid by a batch of scalars.
93 |
94 | Args:
95 | pyramid: Pyramid of image batches.
96 | scalar: Batch of scalars.
97 |
98 | Returns:
99 | An image pyramid with all images multiplied by the scalar.
100 | """
101 |     # Each image in the pyramid has shape (B, C, H, W) while 'scalar' is a batch of
102 |     # shape (B, 1). Appending two singleton dimensions to the scalar lets broadcasting
103 |     # multiply each image in the batch by its corresponding scalar, so no transposes
104 |     # are needed here, unlike in the original TF implementation.
105 | return [image * scalar[..., None, None] for image in pyramid]
106 |
107 |
108 | def flow_pyramid_synthesis(
109 | residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
110 | """Converts a residual flow pyramid into a flow pyramid."""
111 | flow = residual_pyramid[-1]
112 | flow_pyramid: List[torch.Tensor] = [flow]
113 | for residual_flow in residual_pyramid[:-1][::-1]:
114 | level_size = residual_flow.shape[2:4]
115 | flow = F.interpolate(2 * flow, size=level_size, mode='bilinear')
116 | flow = residual_flow + flow
117 | flow_pyramid.insert(0, flow)
118 | return flow_pyramid
119 |
120 |
121 | def pyramid_warp(feature_pyramid: List[torch.Tensor],
122 | flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
123 | """Warps the feature pyramid using the flow pyramid.
124 |
125 | Args:
126 | feature_pyramid: feature pyramid starting from the finest level.
127 | flow_pyramid: flow fields, starting from the finest level.
128 |
129 | Returns:
130 | Reverse warped feature pyramid.
131 | """
132 | warped_feature_pyramid = []
133 | for features, flow in zip(feature_pyramid, flow_pyramid):
134 | warped_feature_pyramid.append(warp(features, flow))
135 | return warped_feature_pyramid
136 |
137 |
138 | def concatenate_pyramids(pyramid1: List[torch.Tensor],
139 | pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
140 | """Concatenates each pyramid level together in the channel dimension."""
141 | result = []
142 | for features1, features2 in zip(pyramid1, pyramid2):
143 | result.append(torch.cat([features1, features2], dim=1))
144 | return result
145 |
146 |
147 | class Conv2d(nn.Sequential):
148 | def __init__(self, in_channels, out_channels, size, activation: Optional[str] = 'relu'):
149 | assert activation in (None, 'relu')
150 | super().__init__(
151 | nn.Conv2d(
152 | in_channels=in_channels,
153 | out_channels=out_channels,
154 | kernel_size=size,
155 | padding='same' if size % 2 else 0)
156 | )
157 | self.size = size
158 | self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None
159 |
160 | def forward(self, x):
161 | if not self.size % 2:
162 | x = F.pad(x, (0, 1, 0, 1))
163 | y = self[0](x)
164 | if self.activation is not None:
165 | y = self.activation(y)
166 | return y
167 |
--------------------------------------------------------------------------------
/pyramid_flow_estimator.py:
--------------------------------------------------------------------------------
1 | """PyTorch layer for estimating optical flow by a residual flow pyramid.
2 |
3 | This approach of estimating optical flow between two images can be traced back
4 | to [1], but is also used by later neural optical flow computation methods such
5 | as SpyNet [2] and PWC-Net [3].
6 |
7 | The basic idea is that the optical flow is first estimated in a coarse
8 | resolution, then the flow is upsampled to warp the higher resolution image and
9 | then a residual correction is computed and added to the estimated flow. This
10 | process is repeated in a pyramid on coarse to fine order to successively
11 | increase the resolution of both optical flow and the warped image.
12 |
13 | In here, the optical flow predictor is used as an internal component for the
14 | film_net frame interpolator, to warp the two input images into the inbetween,
15 | target frame.
16 |
17 | [1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
18 | [2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
19 | Network. 2016
20 | [3] D. Sun, X. Yang, M-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
21 | Pyramid, Warping, and Cost Volume, 2017
22 | """
23 | from typing import List
24 |
25 | import torch
26 | from torch import nn
27 | from torch.nn import functional as F
28 |
29 | import util
30 |
31 |
32 | class FlowEstimator(nn.Module):
33 | """Small-receptive field predictor for computing the flow between two images.
34 |
35 | This is used to compute the residual flow fields in PyramidFlowEstimator.
36 |
37 | Note that while the number of 3x3 convolutions & filters to apply is
38 | configurable, two extra 1x1 convolutions are appended to extract the flow in
39 | the end.
40 |
41 | Attributes:
42 | name: The name of the layer
43 | num_convs: Number of 3x3 convolutions to apply
44 | num_filters: Number of filters in each 3x3 convolution
45 | """
46 |
47 | def __init__(self, in_channels: int, num_convs: int, num_filters: int):
48 | super(FlowEstimator, self).__init__()
49 |
50 | self._convs = nn.ModuleList()
51 | for i in range(num_convs):
52 | self._convs.append(util.Conv2d(in_channels=in_channels, out_channels=num_filters, size=3))
53 | in_channels = num_filters
54 | self._convs.append(util.Conv2d(in_channels, num_filters // 2, size=1))
55 | in_channels = num_filters // 2
56 | # For the final convolution, we want no activation at all to predict the
57 | # optical flow vector values. We have done extensive testing on explicitly
58 | # bounding these values using sigmoid, but it turned out that having no
59 | # activation gives better results.
60 | self._convs.append(util.Conv2d(in_channels, 2, size=1, activation=None))
61 |
62 | def forward(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor:
63 | """Estimates optical flow between two images.
64 |
65 | Args:
66 | features_a: per pixel feature vectors for image A (B x H x W x C)
67 | features_b: per pixel feature vectors for image B (B x H x W x C)
68 |
69 | Returns:
70 | A tensor with optical flow from A to B
71 | """
72 | net = torch.cat([features_a, features_b], dim=1)
73 | for conv in self._convs:
74 | net = conv(net)
75 | return net
76 |
77 |
78 | class PyramidFlowEstimator(nn.Module):
79 | """Predicts optical flow by coarse-to-fine refinement.
80 | """
81 |
82 | def __init__(self, filters: int = 64,
83 | flow_convs: tuple = (3, 3, 3, 3),
84 | flow_filters: tuple = (32, 64, 128, 256)):
85 | super(PyramidFlowEstimator, self).__init__()
86 |
87 | in_channels = filters << 1
88 | predictors = []
89 | for i in range(len(flow_convs)):
90 | predictors.append(
91 | FlowEstimator(
92 | in_channels=in_channels,
93 | num_convs=flow_convs[i],
94 | num_filters=flow_filters[i]))
95 | in_channels += filters << (i + 2)
96 | self._predictor = predictors[-1]
97 | self._predictors = nn.ModuleList(predictors[:-1][::-1])
98 |
99 | def forward(self, feature_pyramid_a: List[torch.Tensor],
100 | feature_pyramid_b: List[torch.Tensor]) -> List[torch.Tensor]:
101 | """Estimates residual flow pyramids between two image pyramids.
102 |
103 | Each image pyramid is represented as a list of tensors in fine-to-coarse
104 | order. Each individual image is represented as a tensor where each pixel is
105 | a vector of image features.
106 |
107 | util.flow_pyramid_synthesis can be used to convert the residual flow
108 | pyramid returned by this method into a flow pyramid, where each level
109 | encodes the flow instead of a residual correction.
110 |
111 | Args:
112 | feature_pyramid_a: image pyramid as a list in fine-to-coarse order
113 | feature_pyramid_b: image pyramid as a list in fine-to-coarse order
114 |
115 | Returns:
116 | List of flow tensors, in fine-to-coarse order, each level encoding the
117 | difference against the bilinearly upsampled version from the coarser
118 |       level. The coarsest flow tensor, i.e. the last element in the array, is the
119 |       'DC-term', i.e. not a residual (alternatively you can think of it as a
120 | residual against zero).
121 | """
122 | levels = len(feature_pyramid_a)
123 | v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
124 | residuals = [v]
125 | for i in range(levels - 2, len(self._predictors) - 1, -1):
126 | # Upsamples the flow to match the current pyramid level. Also, scales the
127 | # magnitude by two to reflect the new size.
128 | level_size = feature_pyramid_a[i].shape[2:4]
129 | v = F.interpolate(2 * v, size=level_size, mode='bilinear')
130 | # Warp feature_pyramid_b[i] image based on the current flow estimate.
131 | warped = util.warp(feature_pyramid_b[i], v)
132 | # Estimate the residual flow between pyramid_a[i] and warped image:
133 | v_residual = self._predictor(feature_pyramid_a[i], warped)
134 | residuals.insert(0, v_residual)
135 | v = v_residual + v
136 |
137 | for k, predictor in enumerate(self._predictors):
138 | i = len(self._predictors) - 1 - k
139 | # Upsamples the flow to match the current pyramid level. Also, scales the
140 | # magnitude by two to reflect the new size.
141 | level_size = feature_pyramid_a[i].shape[2:4]
142 | v = F.interpolate(2 * v, size=level_size, mode='bilinear')
143 | # Warp feature_pyramid_b[i] image based on the current flow estimate.
144 | warped = util.warp(feature_pyramid_b[i], v)
145 | # Estimate the residual flow between pyramid_a[i] and warped image:
146 | v_residual = predictor(feature_pyramid_a[i], warped)
147 | residuals.insert(0, v_residual)
148 | v = v_residual + v
149 | return residuals
150 |
--------------------------------------------------------------------------------
/feature_extractor.py:
--------------------------------------------------------------------------------
1 | """PyTorch layer for extracting image features for the film_net interpolator.
2 |
3 | The feature extractor implemented here converts an image pyramid into a pyramid
4 | of deep features. The feature pyramid serves a similar purpose as U-Net
5 | architecture's encoder, but we use a special cascaded architecture described in
6 | Multi-view Image Fusion [1].
7 |
8 | For comprehensiveness, below is a short description of the idea. While the
9 | description is a bit involved, the cascaded feature pyramid can be used just
10 | like any image feature pyramid.
11 |
12 | Why cascaded architecture?
13 | =========================
14 | To understand the concept it is worth reviewing a traditional feature pyramid
15 | first: *A traditional feature pyramid* as in U-net or in many optical flow
16 | networks is built by alternating between convolutions and pooling, starting
17 | from the input image.
18 |
19 | It is well known that early features of such architecture correspond to low
20 | level concepts such as edges in the image whereas later layers extract
21 | semantically higher level concepts such as object classes etc. In other words,
22 | the meaning of the filters in each resolution level is different. For problems
23 | such as semantic segmentation and many others this is a desirable property.
24 |
25 | However, the asymmetric features preclude sharing weights across resolution
26 | levels in the feature extractor itself and in any subsequent neural networks
27 | that follow. This can be a downside, since optical flow prediction, for
28 | instance is symmetric across resolution levels. The cascaded feature
29 | architecture addresses this shortcoming.
30 |
31 | How is it built?
32 | ================
33 | The *cascaded* feature pyramid contains feature vectors that have constant
34 | length and meaning on each resolution level, except few of the finest ones. The
35 | advantage of this is that the subsequent optical flow layer can learn
36 | synergically from many resolutions. This means that coarse level prediction can
37 | benefit from finer resolution training examples, which can be useful with
38 | moderately sized datasets to avoid overfitting.
39 |
40 | The cascaded feature pyramid is built by extracting shallower subtree pyramids,
41 | each one of them similar to the traditional architecture. Each subtree
42 | pyramid S_i is extracted starting from each resolution level:
43 |
44 | image resolution 0 -> S_0
45 | image resolution 1 -> S_1
46 | image resolution 2 -> S_2
47 | ...
48 |
49 | If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
50 | is constructed by concatenating features as follows (assuming subtree depth=3):
51 |
52 | lvl
53 | feat_0 = concat( S_0_0 )
54 | feat_1 = concat( S_1_0 S_0_1 )
55 | feat_2 = concat( S_2_0 S_1_1 S_0_2 )
56 | feat_3 = concat( S_3_0 S_2_1 S_1_2 )
57 | feat_4 = concat( S_4_0 S_3_1 S_2_2 )
58 | feat_5 = concat( S_5_0 S_4_1 S_3_2 )
59 | ....
60 |
61 | In above, all levels except feat_0 and feat_1 have the same number of features
62 | with similar semantic meaning. This enables training a single optical flow
63 | predictor module shared by levels 2,3,4,5... . For more details and evaluation
64 | see [1].
65 |
66 | [1] Multi-view Image Fusion, Trinidad et al. 2019
67 | """
68 | from typing import List
69 |
70 | import torch
71 | from torch import nn
72 | from torch.nn import functional as F
73 |
74 | from util import Conv2d
75 |
76 |
77 | class SubTreeExtractor(nn.Module):
78 | """Extracts a hierarchical set of features from an image.
79 |
80 | This is a conventional, hierarchical image feature extractor, that extracts
81 |     [k, k*2, k*4, ...] filters for the image pyramid, where k equals the 'channels' argument.
82 | Each level is followed by average pooling.
83 | """
84 |
85 | def __init__(self, in_channels=3, channels=64, n_layers=4):
86 | super().__init__()
87 | convs = []
88 | for i in range(n_layers):
89 | convs.append(nn.Sequential(
90 | Conv2d(in_channels, (channels << i), 3),
91 | Conv2d((channels << i), (channels << i), 3)
92 | ))
93 | in_channels = channels << i
94 | self.convs = nn.ModuleList(convs)
95 |
96 | def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
97 | """Extracts a pyramid of features from the image.
98 |
99 | Args:
100 |       image: torch.Tensor with shape B x C x H x W.
101 |       n: number of pyramid levels to extract. This can be less than or equal to
102 |         the n_layers given in __init__.
103 | Returns:
104 | The pyramid of features, starting from the finest level. Each element
105 | contains the output after the last convolution on the corresponding
106 | pyramid level.
107 | """
108 | head = image
109 | pyramid = []
110 | for i, layer in enumerate(self.convs):
111 | head = layer(head)
112 | pyramid.append(head)
113 | if i < n - 1:
114 | head = F.avg_pool2d(head, kernel_size=2, stride=2)
115 | return pyramid
116 |
117 |
118 | class FeatureExtractor(nn.Module):
119 | """Extracts features from an image pyramid using a cascaded architecture.
120 | """
121 |
122 | def __init__(self, in_channels=3, channels=64, sub_levels=4):
123 | super().__init__()
124 | self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
125 | self.sub_levels = sub_levels
126 |
127 | def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
128 | """Extracts a cascaded feature pyramid.
129 |
130 | Args:
131 | image_pyramid: Image pyramid as a list, starting from the finest level.
132 | Returns:
133 | A pyramid of cascaded features.
134 | """
135 | sub_pyramids: List[List[torch.Tensor]] = []
136 | for i in range(len(image_pyramid)):
137 | # At each level of the image pyramid, creates a sub_pyramid of features
138 | # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
139 | # We use the same instance since we want to share the weights.
140 | #
141 | # However, we cap the depth of the sub_pyramid so we don't create features
142 | # that are beyond the coarsest level of the cascaded feature pyramid we
143 | # want to generate.
144 | capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels)
145 | sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels))
146 | # Below we generate the cascades of features on each level of the feature
147 | # pyramid. Assuming sub_levels=3, The layout of the features will be
148 | # as shown in the example on file documentation above.
149 | feature_pyramid: List[torch.Tensor] = []
150 | for i in range(len(image_pyramid)):
151 | features = sub_pyramids[i][0]
152 | for j in range(1, self.sub_levels):
153 | if j <= i:
154 | features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
155 | feature_pyramid.append(features)
156 | return feature_pyramid
157 |
--------------------------------------------------------------------------------
/interpolator.py:
--------------------------------------------------------------------------------
1 | """The film_net frame interpolator main model code.
2 |
3 | Basics
4 | ======
5 | The film_net is an end-to-end learned neural frame interpolator implemented as
6 | a PyTorch model. It has the following inputs and outputs:
7 |
8 | Inputs:
9 | x0: image A.
10 | x1: image B.
11 | time: desired sub-frame time.
12 |
13 | Outputs:
14 | image: the predicted in-between image at the chosen time in range [0, 1].
15 |
16 | Additional outputs include forward and backward warped image pyramids, flow
17 | pyramids, etc., that can be visualized for debugging and analysis.
18 |
19 | Note that many training sets only contain triplets with ground truth at
20 | time=0.5. If a model has been trained with such training set, it will only work
21 | well for synthesizing frames at time=0.5. Such models can only generate more
22 | in-between frames using recursion.
23 |
24 | Architecture
25 | ============
26 | The inference consists of three main stages: 1) feature extraction 2) warping
27 | 3) fusion. On high-level, the architecture has similarities to Context-aware
28 | Synthesis for Video Frame Interpolation [1], but the exact architecture is
29 | closer to Multi-view Image Fusion [2] with some modifications for the frame
30 | interpolation use-case.
31 |
32 | Feature extraction stage employs the cascaded multi-scale architecture described
33 | in [2]. The advantage of this architecture is that coarse level flow prediction
34 | can be learned from finer resolution image samples. This is especially useful
35 | to avoid overfitting with moderately sized datasets.
36 |
37 | The warping stage uses a residual flow prediction idea that is similar to
38 | PWC-Net [3], Multi-view Image Fusion [2] and many others.
39 |
40 | The fusion stage is similar to U-Net's decoder where the skip connections are
41 | connected to warped image and feature pyramids. This is described in [2].
42 |
43 | Implementation Conventions
44 | ====================
45 | Pyramids
46 | --------
47 | Throughout the model, all image and feature pyramids are stored as Python lists
48 | with finest level first followed by downscaled versions obtained by successively
49 | halving the resolution. The depths of all pyramids are determined by
50 | options.pyramid_levels. The only exception to this is internal to the feature
51 | extractor, where smaller feature pyramids are temporarily constructed with depth
52 | options.sub_levels.
53 |
54 | Color ranges & gamma
55 | --------------------
56 | The model code makes no assumptions on whether the images are in gamma or
57 | linearized space or what is the range of RGB color values. So a model can be
58 | trained with different choices. This does not mean that all the choices lead to
59 | similar results. In practice the model has been proven to work well with RGB
60 | scale = [0,1] with gamma-space images (i.e. not linearized).
61 |
62 | [1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
63 | [2] Multi-view Image Fusion, Trinidad et al, 2019
64 | [3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
65 | """
66 | from typing import Dict, List
67 |
68 | import torch
69 | from torch import nn
70 |
71 | import util
72 | from feature_extractor import FeatureExtractor
73 | from fusion import Fusion
74 | from pyramid_flow_estimator import PyramidFlowEstimator
75 |
76 |
77 | class Interpolator(nn.Module):
78 | def __init__(
79 | self,
80 | pyramid_levels=7,
81 | fusion_pyramid_levels=5,
82 | specialized_levels=3,
83 | sub_levels=4,
84 | filters=64,
85 | flow_convs=(3, 3, 3, 3),
86 | flow_filters=(32, 64, 128, 256),
87 | ):
88 | super().__init__()
89 | self.pyramid_levels = pyramid_levels
90 | self.fusion_pyramid_levels = fusion_pyramid_levels
91 |
92 | self.extract = FeatureExtractor(3, filters, sub_levels)
93 | self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters)
94 | self.fuse = Fusion(sub_levels, specialized_levels, filters)
95 |
96 | def shuffle_images(self, x0, x1):
97 | return [
98 | util.build_image_pyramid(x0, self.pyramid_levels),
99 | util.build_image_pyramid(x1, self.pyramid_levels)
100 | ]
101 |
102 | def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
103 | image_pyramids = self.shuffle_images(x0, x1)
104 |
105 | # Siamese feature pyramids:
106 | feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])]
107 |
108 | # Predict forward flow.
109 | forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1])
110 |
111 | # Predict backward flow.
112 | backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0])
113 |
114 | # Concatenate features and images:
115 |
116 | # Note that we keep up to 'fusion_pyramid_levels' levels as only those
117 | # are used by the fusion module.
118 |
119 | forward_flow_pyramid = util.flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels]
120 |
121 | backward_flow_pyramid = util.flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels]
122 |
123 | # We multiply the flows with t and 1-t to warp to the desired fractional time.
124 | #
125 |         # Note: In the original film_net, time is fixed at 0.5 and the interpolator
126 |         # is invoked recursively for multi-frame interpolation; here the fractional
127 |         # time is simply passed in as `batch_dt`.
128 | backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt)
129 | forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt)
130 |
131 | pyramids_to_warp = [
132 | util.concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
133 | feature_pyramids[0][:self.fusion_pyramid_levels]),
134 | util.concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels],
135 | feature_pyramids[1][:self.fusion_pyramid_levels])
136 | ]
137 |
138 | # Warp features and images using the flow. Note that we use backward warping
139 | # and backward flow is used to read from image 0 and forward flow from
140 | # image 1.
141 | forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
142 | backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
143 |
144 | aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
145 | backward_warped_pyramid)
146 | aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
147 | aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
148 |
149 | return {
150 | 'image': [self.fuse(aligned_pyramid)],
151 | 'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
152 | 'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
153 | 'forward_flow_pyramid': forward_flow_pyramid,
154 | 'backward_flow_pyramid': backward_flow_pyramid,
155 | }
156 |
157 | def forward(self, x0, x1, batch_dt) -> torch.Tensor:
158 | return self.debug_forward(x0, x1, batch_dt)['image'][0]
159 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------