├── README.md
├── config.py
├── data_lib.py
├── film-pytorch.yaml
├── losses
│   ├── losses.py
│   ├── utils.py
│   └── vgg19_loss.py
├── models
│   ├── feature_extractor.py
│   ├── fusion.py
│   ├── interpolator.py
│   ├── options.py
│   ├── pyramid_flow_estimator.py
│   └── utils.py
├── train.py
├── train_lib.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Implementation of FILM: Frame Interpolation for Large Motion (from Google)
2 | - PyTorch implementation of FILM: Frame Interpolation for Large Motion (https://film-net.github.io/)
3 | - Easier to use, read, and debug than the original TF code
4 | - Achieves performance comparable to that reported in the original paper (PSNR ~ 34 on Vimeo90K)
5 | - TensorBoard logging for metrics (PSNR, SSIM) and generated images (x0, prediction, ground truth, x1)
6 | 
7 | ## Requirements
8 | - Python 3.11.0
9 | - PyTorch 1.13.1
10 | - CUDA 11.6
11 | 
12 | ## Installation
13 | 
14 | * Option 1) Recreate the provided conda environment
15 | ```
16 | git clone https://github.com/google-research/frame-interpolation
17 | cd frame-interpolation
18 | conda env create -f film-pytorch.yaml
19 | conda activate film-pytorch
20 | ```
21 | 
22 | * Option 2) Install the requirements yourself
23 | ```
24 | git clone https://github.com/google-research/frame-interpolation
25 | cd frame-interpolation
26 | conda create -n film-pytorch
27 | conda activate film-pytorch
28 | conda install -c conda-forge python
29 | pip install torch==1.13.1+cu116 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
30 | pip install scipy torchmetrics tensorboardX opencv-python tqdm
31 | ```
32 | 
33 | ## Training
34 | - Training expects a Vimeo-style triplet data directory; pass its path via `--train_data`
35 | 
36 | ```
37 | film-pytorch
38 | └── datasets
39 |     └── vimeo_triplet
40 |         ├── sequences
41 |         ├── readme.txt
42 |         ├── tri_testlist.txt
43 |         └── tri_trainlist.txt
44 | ```
45 | - For training, you can specify the training data path, batch size, number of epochs, a checkpoint to resume from, and an experiment name:
46 | ```
47 | python train.py --train_data datasets/vimeo_triplet --exp_name 230115_exp1 --batch_size 8 --epoch 100 --resume 'path to checkpoint'
48 | ```
49 | 
50 | ## Inference
51 | To be released very soon (in progress). Until then, a minimal stop-gap sketch is given at the end of this README.
52 | * 1) One mid-frame interpolation
53 | 
54 | To generate an intermediate photo from the input near-duplicate photos, simply run:
55 | 
56 | ```
57 | python inference.py --frame1 data/one.png --frame2 data/two.png --model_path pretrained/film_style --output temp/output.png
58 | ```
59 | 
60 | This will produce the sub-frame at `t=0.5` and save it to the path given by `--output` (here 'temp/output.png').
61 | 
62 | * 2) Many in-between frames interpolation
63 | 
64 | It takes in a set of directories identified by a glob pattern (`--data`). Each directory
65 | is expected to contain at least two input frames, with each contiguous frame
66 | pair treated as an input to generate in-between frames. Frames should be named so that sorting them naturally (with `natsort`) gives the desired temporal order.
67 | 
68 | 
69 | ```
70 | python inference.py --data "photos" --model_path pretrained/film_style --times_to_interpolate 3 --output_video
71 | ```
72 | 
73 | You will find the interpolated frames (including the input frames) in
74 | 'photos/interpolated_frames/', and the interpolated video at
75 | 'photos/interpolated.mp4'.
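
Until the official `inference.py` lands, the snippet below is an unofficial, minimal sketch of one mid-frame interpolation built only from the training-side code in this repo. It assumes a checkpoint written by `train.py` (the checkpoint and image paths are placeholders following `utils.save_checkpoint`'s layout), and frame height/width should be a multiple of 2**(pyramid_levels) as noted in `data_lib.py`.

```
# minimal_infer.py -- unofficial sketch, not the upcoming inference.py
import os
import cv2
import numpy as np
import torch
from torch.optim import Adam

from models import interpolator as film_net_interpolator
from models import options as film_net_options
from utils import load_checkpoint

def load_image(path):
    # BGR uint8 -> float32 in [-1, 1], NCHW; matches the normalization in data_lib.py
    img = cv2.imread(path).astype(np.float32)
    img = (img / 255.0 - 0.5) * 2.0
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).cuda()

x0 = load_image('data/one.png')
x1 = load_image('data/two.png')

model = film_net_interpolator.create_model(film_net_options.Options()).cuda().eval()
optimizer = Adam(model.parameters())  # load_checkpoint also restores the optimizer state
load_checkpoint('checkpoint_dir/230115_exp1/checkpoint_latest.pt', model, optimizer)  # placeholder path

# forward() unpacks 'y' even though it is unused at test time, so pass a dummy tensor
batch = {'x0': x0, 'x1': x1, 'y': torch.zeros_like(x0),
         'time': torch.full((1, 1), 0.5, device='cuda')}
with torch.no_grad():
    pred = model(batch)['image'][0]  # CxHxW in [-1, 1]

# denormalize and write as BGR, mirroring how the frames were read
out = ((pred.clamp(-1.0, 1.0) + 1.0) / 2.0 * 255.0).round().byte().cpu().numpy().transpose(1, 2, 0)
os.makedirs('temp', exist_ok=True)
cv2.imwrite('temp/output.png', out)
```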
76 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | project_name = 'film_style'
2 | 
3 | 
4 | data_params = {
5 |     'batch_size': 16
6 | }
7 | 
8 | 
9 | model_params = {
10 |     'pyramid_levels': 7,
11 |     'fusion_pyramid_levels': 5,
12 |     'specialized_levels': 3,
13 |     'sub_levels': 4,
14 |     'flow_convs': [3, 3, 3, 3],
15 |     'flow_filters': [32, 64, 128, 256],
16 |     'filters': 64
17 | }
18 | 
19 | train_params = {
20 |     'learning_rate': 0.0001*0.5,
21 |     'learning_rate_decay_steps': 750000,
22 |     'learning_rate_decay_rate': 0.464158,
23 |     'learning_rate_staircase': True,
24 |     'num_steps': 3000000,
25 |     'weight_decay': 1e-3
26 | }
27 | 
--------------------------------------------------------------------------------
/data_lib.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os, random, cv2, torch
3 | import numpy as np
4 | from glob import glob
5 | 
6 | # adapted from XVFI (https://github.com/JihyongOh/XVFI)
7 | 
8 | class Custom_Train(data.Dataset):
9 |     def __init__(self, args):
10 |         self.args = args
11 |         self.t = 0.5
12 |         self.framesPath = []
13 |         f = open(os.path.join(args.train_data, 'tri_trainlist.txt'),
14 |                  'r')  # './datasets/vimeo_triplet/tri_trainlist.txt'
15 |         while True:
16 |             scene_path = f.readline().split('\n')[0]
17 |             if not scene_path: break
18 |             frames_list = sorted(glob(os.path.join(args.train_data, 'sequences', scene_path, '*.png')))
19 |             self.framesPath.append(frames_list)
20 |         f.close()
21 |         self.nScenes = len(self.framesPath)
22 |         if self.nScenes == 0:
23 |             raise RuntimeError("Found 0 files in subfolders of: " + args.train_data + "\n")
24 |         print("nScenes of Vimeo train triplet : ", self.nScenes)
25 | 
26 |     def __getitem__(self, idx):
27 |         candidate_frames = self.framesPath[idx]
28 |         """ Randomly reverse frames """
29 |         if (random.randint(0, 1)):
30 |             frameRange = [0, 2, 1]
31 |         else:
32 |             frameRange = [2, 0, 1]
33 |         # frames : (C, T, H, W)
34 |         frames = frames_loader_train(self.args, candidate_frames, frameRange)  # including "np2Tensor [-1,1] normalized"
35 | 
36 |         outputs = {'x0': frames[:,0,:,:], 'x1': frames[:,1,:,:], 'y': frames[:,2,:,:], 'time': np.expand_dims(np.array(0.5, dtype=np.float32), 0)}
37 |         return outputs
38 | 
39 |     def __len__(self):
40 |         return self.nScenes
41 | 
42 | def frames_loader_train(args, candidate_frames, frameRange):
43 |     frames = []
44 |     for frameIndex in frameRange:
45 |         frame = cv2.imread(candidate_frames[frameIndex])
46 |         #frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
47 |         frames.append(frame)
48 |     (ih, iw, c) = frame.shape
49 |     frames = np.stack(frames, axis=0)  # (T, H, W, 3)
50 | 
51 |     if args.need_patch:  ## random crop
52 |         ps = args.patch_size
53 |         ix = random.randrange(0, iw - ps + 1)
54 |         iy = random.randrange(0, ih - ps + 1)
55 |         frames = frames[:, iy:iy + ps, ix:ix + ps, :]
56 | 
57 |     if random.random() < 0.5:  # random horizontal flip
58 |         frames = frames[:, :, ::-1, :]
59 | 
60 |     frames = frames[:, :, :, :]  # (T, H, W, 3) ; H and W should be divisible by 2**(pyramid_levels)
61 |     # No vertical flip
62 | 
63 |     """rot = random.randint(0, 3) # random rotate
64 |     frames = np.rot90(frames, rot, (1, 2))"""
65 | 
66 |     """ np2Tensor [-1,1] normalized """
67 |     frames = RGBframes_np2Tensor(frames, 3)  # (C, T, H, W)
68 | 
69 |     return frames
70 | 
71 | def RGBframes_np2Tensor(imgIn, channel):
72 |     ## input : T, H, W, C
73 |     if channel == 1:
74 |         # rgb --> 
Y (gray) 75 | imgIn = np.sum(imgIn * np.reshape([65.481, 128.553, 24.966], [1, 1, 1, 3]) / 255.0, axis=3, 76 | keepdims=True) + 16.0 77 | 78 | # to Tensor 79 | ts = (3, 0, 1, 2) # dimension order should be [C, T, H, W] 80 | imgIn = torch.Tensor(imgIn.transpose(ts).astype(float)).mul_(1.0) 81 | 82 | # normalization [-1,1] 83 | imgIn = (imgIn / 255.0 - 0.5) * 2 84 | 85 | return imgIn 86 | 87 | def create_training_dataset(args, augmentation_fns=None): 88 | data_train = Custom_Train(args) 89 | dataloader = data.DataLoader(data_train, batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=4, pin_memory=True) 90 | return dataloader 91 | -------------------------------------------------------------------------------- /film-pytorch.yaml: -------------------------------------------------------------------------------- 1 | name: film-pytorch 2 | channels: 3 | - conda-forge 4 | - soumith 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - bzip2=1.0.8=h7f98852_4 10 | - ca-certificates=2022.12.7=ha878542_0 11 | - ld_impl_linux-64=2.39=hcc3a1bd_1 12 | - libffi=3.4.2=h7f98852_5 13 | - libgcc-ng=12.2.0=h65d4601_19 14 | - libgomp=12.2.0=h65d4601_19 15 | - libnsl=2.0.0=h7f98852_0 16 | - libsqlite=3.40.0=h753d276_0 17 | - libuuid=2.32.1=h7f98852_1000 18 | - libzlib=1.2.13=h166bdaf_4 19 | - ncurses=6.3=h27087fc_1 20 | - openssl=3.0.7=h0b41bf4_1 21 | - pip=22.3.1=pyhd8ed1ab_0 22 | - python=3.11.0=ha86cf86_0_cpython 23 | - readline=8.1.2=h0f457ee_0 24 | - setuptools=65.6.3=pyhd8ed1ab_0 25 | - tk=8.6.12=h27826a3_0 26 | - tzdata=2022g=h191b570_0 27 | - wheel=0.38.4=pyhd8ed1ab_0 28 | - xz=5.2.6=h166bdaf_0 29 | - pip: 30 | - absl-py==1.3.0 31 | - cachetools==5.2.1 32 | - certifi==2022.12.7 33 | - charset-normalizer==2.1.1 34 | - google-auth==2.16.0 35 | - google-auth-oauthlib==0.4.6 36 | - grpcio==1.51.1 37 | - idna==3.4 38 | - markdown==3.4.1 39 | - markupsafe==2.1.1 40 | - numpy==1.24.1 41 | - oauthlib==3.2.2 42 | - opencv-python==4.7.0.68 43 | - packaging==23.0 44 | - pillow==9.4.0 45 | - protobuf==3.20.1 46 | - pyasn1==0.4.8 47 | - pyasn1-modules==0.2.8 48 | - requests==2.28.1 49 | - requests-oauthlib==1.3.1 50 | - rsa==4.9 51 | - scipy==1.10.0 52 | - six==1.16.0 53 | - tensorboard==2.11.0 54 | - tensorboard-data-server==0.6.1 55 | - tensorboard-plugin-wit==1.8.1 56 | - tensorboardx==2.5.1 57 | - torch==1.13.1+cu116 58 | - torchmetrics==0.11.0 59 | - torchvision==0.2.0 60 | - tqdm==4.64.1 61 | - typing-extensions==4.4.0 62 | - urllib3==1.26.13 63 | - werkzeug==2.2.2 64 | prefix: /home/nvadmin/anaconda3/envs/film-pytorch 65 | -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | 2 | from .vgg19_loss import PerceptualLoss 3 | import torch 4 | import numpy as np 5 | 6 | from torchmetrics import PeakSignalNoiseRatio 7 | from torchmetrics.functional import structural_similarity_index_measure as ssim 8 | 9 | PSNR = PeakSignalNoiseRatio().cuda() 10 | vgg = PerceptualLoss().cuda() 11 | 12 | 13 | def vgg_loss(example, prediction): 14 | return vgg(prediction['image'], example['y'])[0] 15 | 16 | def style_loss(example, prediction): 17 | return vgg(prediction['image'], example['y'])[1] 18 | 19 | def perceptual_loss(example, prediction): 20 | return sum(vgg(prediction['image'], example['y'])) 21 | 22 | def l1_loss(example, prediction): 23 | return torch.mean(torch.abs(prediction['image'] - example['y'])) 24 | 25 | def 
l1_warped_loss(example, prediction): 26 | loss = torch.zeros(1, dtype=torch.float32) 27 | if 'x0_warped' in prediction: 28 | loss += torch.mean(torch.abs(prediction['x0_warped'] - example['y'])) 29 | if 'x1_warped' in prediction: 30 | loss += torch.mean(torch.abs(prediction['x1_warped'] - example['y'])) 31 | return loss 32 | 33 | def l2_loss(example, prediction): 34 | return torch.mean(torch.square(prediction['image'] - example['y'])) 35 | 36 | def ssim_loss(example, prediction): 37 | return ssim(prediction['image'], example['y']) # to do : max_val=1.0 38 | 39 | def psnr_loss(example, prediction): 40 | return PSNR(prediction['image'], example['y']) 41 | 42 | def get_loss(loss_name): 43 | if loss_name == 'l1': 44 | return l1_loss 45 | elif loss_name == 'l2': 46 | return l2_loss 47 | elif loss_name == 'ssim': 48 | return ssim_loss 49 | elif loss_name == 'vgg': 50 | return vgg_loss 51 | elif loss_name == 'style': 52 | return style_loss 53 | elif loss_name == 'psnr': 54 | return psnr_loss 55 | elif loss_name == 'l1_warped': 56 | return l1_warped_loss 57 | elif loss_name == 'perceptual': 58 | return perceptual_loss 59 | else: 60 | raise ValueError('Invalid loss function %s' % loss_name) 61 | 62 | def get_loss_op(loss_name): 63 | loss = get_loss(loss_name) 64 | return lambda example, prediction: loss(example, prediction) 65 | 66 | def get_weight_op(weight_schedule): 67 | return lambda iterations: weight_schedule(iterations) 68 | 69 | def create_losses(loss_names, loss_weight=None): 70 | losses = dict() 71 | for name in (loss_names): # to do : loss_weight 72 | """#unique_values = np.unique(weight_schedule.values) 73 | #if len(unique_values) == 1: #and unique_values[0] == 1.0: 74 | # weighted_name = name 75 | #else: 76 | # weighted_name = 'k*' + name 77 | #losses[weighted_name] = (get_loss_op(name), get_weight_op(weight_schedule)) # to do 78 | #print(f"name {str(name)}")""" 79 | losses[name] = (get_loss_op(name)) 80 | return losses 81 | 82 | def training_losses(loss_names, loss_weights=None, loss_weight_schedules=None, loss_weight_parameters=None): 83 | weight_schedules = [] # to do 84 | """if not loss_weights: 85 | for weight_schedule, weight_parameters in zip(loss_weight_schedules, loss_weight_parameters): 86 | weight_schedules.append(weight_schedule(**weight_parameters)) 87 | else: 88 | for loss_weight in loss_weights: 89 | weight_parameters = { 90 | 'boundaries': [0], 91 | 'values': 2 * [loss_weight,] 92 | } 93 | weight_schedules.append(torch.optim.lr_scheduler.ConstantLR(optimizer)) # to do : lr parameter""" 94 | return create_losses(loss_names, weight_schedules) 95 | -------------------------------------------------------------------------------- /losses/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class AntiAliasInterpolation2d(nn.Module): 6 | """ 7 | Band-limited downsampling, for better preservation of the input signal. 8 | """ 9 | 10 | def __init__(self, channels, scale): 11 | super(AntiAliasInterpolation2d, self).__init__() 12 | sigma = (1 / scale - 1) / 2 13 | kernel_size = 2 * round(sigma * 4) + 1 14 | self.ka = kernel_size // 2 15 | self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka 16 | 17 | kernel_size = [kernel_size, kernel_size] 18 | sigma = [sigma, sigma] 19 | # The gaussian kernel is the product of the 20 | # gaussian function of each dimension. 
21 | kernel = 1 22 | meshgrids = torch.meshgrid( 23 | [ 24 | torch.arange(size, dtype=torch.float32) 25 | for size in kernel_size 26 | ] 27 | ) 28 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 29 | mean = (size - 1) / 2 30 | kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) 31 | 32 | # Make sure sum of values in gaussian kernel equals 1. 33 | kernel = kernel / torch.sum(kernel) 34 | # Reshape to depthwise convolutional weight 35 | kernel = kernel.view(1, 1, *kernel.size()) 36 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)).cuda() 37 | 38 | self.register_buffer('weight', kernel) 39 | self.groups = channels 40 | self.scale = scale 41 | inv_scale = 1 / scale 42 | self.int_inv_scale = int(inv_scale) 43 | 44 | def forward(self, input): 45 | if self.scale == 1.0: 46 | return input 47 | 48 | out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) 49 | out = F.conv2d(out, weight=self.weight, groups=self.groups) 50 | out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] 51 | 52 | return out 53 | -------------------------------------------------------------------------------- /losses/vgg19_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from .utils import AntiAliasInterpolation2d 7 | 8 | class Vgg19(nn.Module): 9 | def __init__(self, requires_grad=False): 10 | super(Vgg19, self).__init__() 11 | vgg_pretrained_features = models.vgg19(pretrained=True).features 12 | self.slice1 = torch.nn.Sequential() 13 | self.slice2 = torch.nn.Sequential() 14 | self.slice3 = torch.nn.Sequential() 15 | self.slice4 = torch.nn.Sequential() 16 | self.slice5 = torch.nn.Sequential() 17 | for x in range(2): 18 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 19 | for x in range(2, 7): 20 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(7, 12): 22 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(12, 21): 24 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(21, 30): 26 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 27 | 28 | self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), 29 | requires_grad=False) 30 | self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), 31 | requires_grad=False) 32 | 33 | if not requires_grad: 34 | for param in self.parameters(): 35 | param.requires_grad = False 36 | 37 | def forward(self, x): 38 | x = (x - self.mean) / self.std 39 | h_relu1 = self.slice1(x) 40 | h_relu2 = self.slice2(h_relu1) 41 | h_relu3 = self.slice3(h_relu2) 42 | h_relu4 = self.slice4(h_relu3) 43 | h_relu5 = self.slice5(h_relu4) 44 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 45 | return out 46 | 47 | 48 | class ImagePyramide(torch.nn.Module): 49 | def __init__(self, scales, num_channels): 50 | super(ImagePyramide, self).__init__() 51 | downs = [] 52 | for scale in scales: 53 | downs.append(AntiAliasInterpolation2d(num_channels, scale)) 54 | self.downs = downs 55 | 56 | def forward(self, x): 57 | out_dict = [] 58 | for down_module in self.downs: 59 | out_dict.append(down_module(x)) 60 | return out_dict 61 | 62 | 63 | class PerceptualLoss(nn.Module): 64 | def __init__(self, 65 | scales=[1, 0.5, 0.25, 0.125], 66 | loss_weights=[1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] 67 | ): 
68 | super(PerceptualLoss, self).__init__() 69 | self.pyramid = ImagePyramide(scales, 3) 70 | self.vgg = Vgg19() 71 | self.scales = scales 72 | self.loss_weights = loss_weights 73 | self.l1 = nn.L1Loss() 74 | self.l2 = nn.MSELoss() 75 | 76 | def perceptual(self, p_vgg, g_vgg): 77 | loss = 0 78 | for p, g in zip(p_vgg, g_vgg): 79 | for i, weight in enumerate(self.loss_weights): 80 | loss += weight * (self.l1(p, g).mean()) 81 | return loss 82 | def gram(self, p_vgg, g_vgg): 83 | loss = 0 84 | for p, g in zip(p_vgg, g_vgg): 85 | for i, weight in enumerate(self.loss_weights): 86 | loss += weight * (self.l2(p, g).mean()) 87 | return loss 88 | def forward(self, pred, gt): 89 | preds = self.pyramid(pred) 90 | gts = self.pyramid(gt) 91 | 92 | perceptual_loss, style_loss = 0, 0 93 | for p, g in zip(preds, gts): 94 | p_vgg = self.vgg(p) 95 | g_vgg = self.vgg(g) 96 | 97 | perceptual_loss += self.perceptual(p_vgg, g_vgg) 98 | style_loss += self.gram(self.compute_gram(p_vgg), self.compute_gram(g_vgg)) 99 | 100 | return (perceptual_loss , style_loss) 101 | 102 | def compute_gram(self, feature_pyramid): 103 | gram = [] 104 | for x in feature_pyramid: 105 | #print(f"feature {x.shape}") 106 | b, c, h, w = x.shape 107 | f = x.view(b, c, w * h) 108 | f_T = f.transpose(1,2) 109 | G = f.bmm(f_T) / (h * w * c) 110 | gram.append(G) 111 | return gram 112 | -------------------------------------------------------------------------------- /models/feature_extractor.py: -------------------------------------------------------------------------------- 1 | 2 | from .options import Options 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SubTreeExtractor(nn.Module): 9 | def __init__(self, config): 10 | super(SubTreeExtractor, self).__init__() 11 | k = config.filters # 64 filters 12 | n = config.sub_levels # 4 13 | self.convs = nn.ModuleList() 14 | self.convs.append(nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size=3, padding='same')) 15 | self.convs.append(nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size=3, padding='same')) 16 | self.convs.append(nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size=3, padding='same')) 17 | self.convs.append(nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size=3, padding='same')) 18 | 19 | for i in range(0, n-2): # todo : need to find in_channels[i] 20 | self.convs.append(nn.Conv2d(in_channels = (k << (i)), out_channels = (k << (i+1)), kernel_size=3, padding='same')) 21 | self.convs.append(nn.Conv2d(in_channels = (k << (i+1)), out_channels = (k << (i+1)), kernel_size=3, padding='same')) 22 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2) 23 | self.leaky_relu = nn.LeakyReLU(0.2) 24 | 25 | def forward(self, image, n): 26 | 27 | head = image 28 | pyramid = [] 29 | for i in range(n): 30 | head = self.leaky_relu(self.convs[2*i](head)) 31 | head = self.leaky_relu(self.convs[2*i+1](head)) 32 | pyramid.append(head) 33 | if i < n-1: 34 | head = self.pool(head) 35 | return pyramid 36 | 37 | 38 | class FeatureExtractor(nn.Module): 39 | def __init__(self, config): 40 | super(FeatureExtractor, self).__init__() 41 | self.extract_sublevels = SubTreeExtractor(config) 42 | self.options = config 43 | 44 | def forward(self, image_pyramid): 45 | sub_pyramids = [] 46 | for i in range(len(image_pyramid)): 47 | capped_sub_levels = min(len(image_pyramid) - i, self.options.sub_levels) # (4, 4, 4, 4, 3, 2, 1) 48 | sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels)) 49 | feature_pyramid = [] 50 | for i in 
range(len(image_pyramid)): 51 | features = sub_pyramids[i][0] 52 | for j in range(1, self.options.sub_levels): 53 | if j <= i: 54 | features = torch.concat([features, sub_pyramids[i-j][j]], axis=1) # we concat feature pyramid along channel axis 55 | feature_pyramid.append(features) 56 | return feature_pyramid 57 | -------------------------------------------------------------------------------- /models/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Fusion(nn.Module): 6 | def __init__(self, config): 7 | super(Fusion, self).__init__() 8 | 9 | self.convs = nn.ModuleList() 10 | self.levels = config.fusion_pyramid_levels # 5 11 | 12 | for i in range(config.fusion_pyramid_levels - 1): # (0, 1, 2, 3) 13 | m = config.specialized_levels # 3 14 | k = config.filters # 64 15 | num_filters = (k << i) if i < m else (k << m) 16 | fusion_in_channels=[128, 256, 512, 970] 17 | fusion_middle_channels=[138, 330, 714,1482] 18 | convs = nn.ModuleList() 19 | convs.append(nn.Conv2d(in_channels = fusion_in_channels[i], out_channels = num_filters, kernel_size=[2, 2], padding='same')) 20 | convs.append(nn.Conv2d(in_channels = fusion_middle_channels[i], out_channels = num_filters, kernel_size=[3, 3], padding='same')) 21 | convs.append(nn.Conv2d(in_channels = num_filters, out_channels = num_filters, kernel_size=[3, 3], padding='same')) 22 | self.convs.append(convs) 23 | 24 | self.output_conv = nn.Conv2d(in_channels = 64, out_channels = 3, kernel_size=1) 25 | self.leaky_relu = nn.LeakyReLU(0.2) 26 | 27 | def forward(self, pyramid): 28 | if len(pyramid) != self.levels: 29 | raise ValueError( 30 | 'Fusion called with different number of pyramid levels ' 31 | f'{len(pyramid)} than it was configured for, {self.levels}.') 32 | 33 | net = pyramid[-1] 34 | for i in reversed(range(0, self.levels-1)): 35 | level_size = pyramid[i].shape[1:3] 36 | net = F.interpolate(net, scale_factor=2, mode='nearest') 37 | net = self.convs[i][0](net) 38 | net = torch.concat([pyramid[i], net], axis=1) 39 | net = self.leaky_relu(self.convs[i][1](net)) 40 | net = self.leaky_relu(self.convs[i][2](net)) 41 | net = self.output_conv(net) 42 | return net 43 | -------------------------------------------------------------------------------- /models/interpolator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .feature_extractor import FeatureExtractor 5 | from .fusion import Fusion 6 | from .pyramid_flow_estimator import PyramidFlowEstimator 7 | from . 
import utils 8 | 9 | def create_model(config): 10 | return FILM_interpolator(config) 11 | 12 | class FILM_interpolator(nn.Module): 13 | 14 | def __init__(self, config): 15 | super(FILM_interpolator, self).__init__() 16 | self.config = config 17 | self.feature_extractor = FeatureExtractor(self.config) 18 | self.flow_estimator = PyramidFlowEstimator(self.config) 19 | self.fuse = Fusion(self.config) 20 | 21 | def forward(self, batch): 22 | 23 | x0, x1, time, y = batch['x0'], batch['x1'], batch['time'], batch['y'] 24 | 25 | image_pyramids = [ 26 | utils.build_image_pyramid(x0, self.config), 27 | utils.build_image_pyramid(x1, self.config) 28 | ] 29 | 30 | feature_pyramids = [self.feature_extractor(image_pyramids[0]), self.feature_extractor(image_pyramids[1])] 31 | 32 | forward_residual_flow = self.flow_estimator(feature_pyramids[0], feature_pyramids[1]) 33 | backward_residual_flow = self.flow_estimator(feature_pyramids[1], feature_pyramids[0]) 34 | 35 | fusion_pyramid_levels = self.config.fusion_pyramid_levels 36 | forward_flow_pyramid = utils.flow_pyramid_synthesis(forward_residual_flow)[:fusion_pyramid_levels] 37 | backward_flow_pyramid = utils.flow_pyramid_synthesis(backward_residual_flow)[:fusion_pyramid_levels] 38 | 39 | mid_time = torch.ones_like(time) * 0.5 40 | backward_flow = utils.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0]) 41 | forward_flow = utils.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0]) 42 | 43 | pyramids_to_warp = [ # fusion_pyramid_levels: 5 44 | utils.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels], feature_pyramids[0][:fusion_pyramid_levels]), 45 | utils.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels], feature_pyramids[1][:fusion_pyramid_levels]) 46 | ] 47 | 48 | forward_warped_pyramid = utils.pyramid_warp(pyramids_to_warp[0], backward_flow) 49 | backward_warped_pyramid = utils.pyramid_warp(pyramids_to_warp[1], forward_flow) 50 | 51 | aligned_pyramid = utils.concatenate_pyramids(forward_warped_pyramid, backward_warped_pyramid) 52 | aligned_pyramid = utils.concatenate_pyramids(aligned_pyramid, backward_flow) 53 | aligned_pyramid = utils.concatenate_pyramids(aligned_pyramid, forward_flow) 54 | 55 | prediction = self.fuse(aligned_pyramid) 56 | 57 | output_color = prediction[...,:] 58 | outputs = {'image': output_color} 59 | 60 | if self.config.use_aux_outputs: 61 | outputs.update({ 62 | 'x0_warped': forward_warped_pyramid[0][..., 0:3], 63 | 'x1_warped': backward_warped_pyramid[0][..., 0:3], 64 | 'forward_residual_flow_pyramid': forward_residual_flow, 65 | 'backward_residual_flow_pyramid': backward_residual_flow, 66 | 'forward_flow_pyramid': forward_flow_pyramid, 67 | 'backward_flow_pyramid': backward_flow_pyramid, 68 | }) 69 | 70 | return outputs 71 | 72 | 73 | -------------------------------------------------------------------------------- /models/options.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Options(nn.Module): 4 | def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, flow_convs=[3, 3, 3, 3], fusion_in_channels=[32, 64, 128, 256, 512], flow_filters=[32, 64, 128, 256], sub_levels=4, filters=64, use_aux_outputs=True): 5 | super(Options, self).__init__() 6 | self.pyramid_levels = pyramid_levels 7 | self.fusion_pyramid_levels = fusion_pyramid_levels 8 | self.specialized_levels = specialized_levels 9 | self.flow_convs = flow_convs #or [4, 4, 4, 4] 10 | self.flow_filters = flow_filters #or [64, 128, 256, 256] 
11 | self.sub_levels = sub_levels 12 | self.filters = filters 13 | self.use_aux_outputs = use_aux_outputs 14 | self.fusion_in_channels = fusion_in_channels 15 | 16 | -------------------------------------------------------------------------------- /models/pyramid_flow_estimator.py: -------------------------------------------------------------------------------- 1 | 2 | from .utils import warp 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class FlowEstimator(nn.Module): 8 | def __init__(self, num_convs, feature_levels, num_filters): 9 | super(FlowEstimator, self).__init__() 10 | #print(f"num_convs {num_convs} num_filters {num_filters}") 11 | feature_pyramids = [64, 192, 448, 960] 12 | self._convs = nn.ModuleList() 13 | self._convs.append(nn.Conv2d(in_channels = feature_pyramids[feature_levels], out_channels = num_filters, kernel_size=3, padding='same')) 14 | for i in range(1, num_convs): 15 | self._convs.append(nn.Conv2d(in_channels = num_filters, out_channels = num_filters, kernel_size=3, padding='same')) 16 | self._convs.append(nn.Conv2d(in_channels = num_filters, out_channels = num_filters//2, kernel_size=1, padding='same')) 17 | self._convs.append(nn.Conv2d(in_channels = num_filters//2, out_channels=2, kernel_size=1, padding='same')) 18 | self.leaky_relu = nn.LeakyReLU(0.2) 19 | 20 | def forward(self, features_a, features_b): 21 | net = torch.concat([features_a, features_b], axis=1) 22 | for i in range(len(self._convs)-1): 23 | conv_ = self._convs[i] 24 | net = self.leaky_relu(conv_(net)) 25 | conv_ = self._convs[-1] 26 | net = conv_(net) 27 | return net 28 | 29 | """ for your convenience """ 30 | # flow_convs=[3, 3, 3, 3] 31 | # specialized_levels=3 32 | # flow_filters=[32, 64, 128, 256] 33 | # 'pyramid_levels': 7 34 | 35 | class PyramidFlowEstimator(nn.Module): 36 | def __init__(self, config): 37 | super(PyramidFlowEstimator, self).__init__() 38 | self._predictors = nn.ModuleList() 39 | for i in range(config.specialized_levels): # 3 (0, 1, 2) 40 | self._predictors.append(FlowEstimator(num_convs=config.flow_convs[i], feature_levels=i, num_filters=config.flow_filters[i])) 41 | shared_predictor = FlowEstimator(num_convs=config.flow_convs[-1], feature_levels=config.specialized_levels, num_filters=config.flow_filters[-1]) 42 | for i in range(config.specialized_levels, config.pyramid_levels): 43 | self._predictors.append(shared_predictor) 44 | 45 | def forward(self, feature_pyramid_a, feature_pyramid_b): 46 | 47 | levels = len(feature_pyramid_a) 48 | v = self._predictors[-1](feature_pyramid_a[-1], feature_pyramid_b[-1]) 49 | residuals = [v] 50 | for i in reversed(range(0, levels-1)): 51 | v = F.interpolate(v, scale_factor=2) 52 | warped = warp(feature_pyramid_b[i], v) 53 | v_residual = self._predictors[i](feature_pyramid_a[i], warped) 54 | residuals.append(v_residual) 55 | v = v_residual + v 56 | return list(reversed(residuals)) 57 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .options import Options 7 | 8 | def build_image_pyramid(image, options): 9 | 10 | levels = options.pyramid_levels 11 | pyramid = [] 12 | pool = nn.AvgPool2d(kernel_size=2, stride=2) 13 | for i in range(0, levels): 14 | pyramid.append(image) 15 | if i < levels-1: 16 | image = pool(image) 17 | return pyramid 18 | 19 | def warp(image, flow): 20 | warped 
= F.grid_sample(image, torch.permute(flow, (0, 2, 3, 1)), align_corners=False) 21 | return warped 22 | 23 | def multiply_pyramid(pyramid, scalar): 24 | return [ torch.permute(torch.permute(image, (1, 2, 3, 0)) * scalar, [3, 0, 1, 2]) for image in pyramid] 25 | 26 | def flow_pyramid_synthesis(residual_pyramid): 27 | flow = residual_pyramid[-1] 28 | flow_pyramid = [flow] 29 | for residual_flow in reversed(residual_pyramid[:-1]): 30 | level_size = residual_flow.shape[1:3] 31 | flow = F.interpolate(flow, scale_factor=2) 32 | flow = residual_flow + flow 33 | flow_pyramid.append(flow) 34 | return list(reversed(flow_pyramid)) 35 | 36 | def pyramid_warp(feature_pyramid, flow_pyramid): 37 | warped_feature_pyramid = [] 38 | for features, flow in zip(feature_pyramid, flow_pyramid): 39 | warped_feature_pyramid.append(warp(features, flow)) 40 | return warped_feature_pyramid 41 | 42 | def concatenate_pyramids(pyramid1, pyramid2): 43 | result = [] 44 | for features1, features2 in zip(pyramid1, pyramid2): 45 | result.append(torch.concat([features1, features2], axis=1)) 46 | return result 47 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os, argparse, torch 3 | from torch.optim.lr_scheduler import ExponentialLR 4 | from torch.optim import Adam 5 | from tensorboardX import SummaryWriter 6 | from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 7 | 8 | import data_lib 9 | import train_lib 10 | from losses import losses 11 | from config import * 12 | from models import interpolator as film_net_interpolator 13 | from models.interpolator import FILM_interpolator 14 | from models import options as film_net_options 15 | from utils import load_checkpoint 16 | 17 | def create_model(): 18 | 19 | options = film_net_options.Options() 20 | return film_net_interpolator.create_model(options) 21 | 22 | def parse_args(): 23 | desc = "Pytorch implementation for FILM" 24 | parser = argparse.ArgumentParser(description=desc) 25 | parser.add_argument('--exp_name', type=str, default='230111', help='experiment name') 26 | parser.add_argument('--resume', type=str, default=None, help='checkpoint path to resume training') 27 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint_dir', help='path to save checkpoint') 28 | parser.add_argument('--log_dir', type=str, default='log_dir', help='path to save tensorboard log') 29 | parser.add_argument('--train_data', type=str, default='./datasets/vimeo_triplet', help='path to train data') 30 | parser.add_argument('--batch_size', type=int, default=8, help='batch size') 31 | parser.add_argument('--epoch', type=int, default=100, help='batch size') 32 | parser.add_argument('--log_img', type=str, default='log_img', help='path to save image') 33 | parser.add_argument('--need_patch', type=bool, default=False, help='whether to use patch or full resol. 
image') 34 | parser.add_argument('--patch_size', type=int, default=256, help='patch size') 35 | 36 | return parser.parse_args() 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | summary_writer = SummaryWriter(os.path.join(args.log_dir, args.exp_name)) 41 | 42 | options = film_net_options.Options() 43 | model = FILM_interpolator(options).cuda() 44 | 45 | optimizer = Adam(model.parameters(), lr = train_params['learning_rate'], betas=(0.9, 0.999), weight_decay = train_params['weight_decay']) 46 | 47 | if args.resume: 48 | global_step = load_checkpoint(args.resume, model, optimizer) 49 | 50 | train_lib.train( 51 | args, 52 | model=model, 53 | summary=summary_writer, 54 | optimizer=optimizer, 55 | create_losses_fn=losses.training_losses, 56 | #create_metrics_fn=metrics_lib.create_metrics_fn, 57 | dataloader=data_lib.create_training_dataset(args, augmentation_fns=None), 58 | #eval_datasets=data_lib.create_eval_datasets(), 59 | #resume=args.resume 60 | ) 61 | 62 | -------------------------------------------------------------------------------- /train_lib.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from torchmetrics.functional import structural_similarity_index_measure 3 | from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 4 | from torchvision.utils import save_image 5 | import torch 6 | 7 | from utils import to_gpu, save_checkpoint, load_checkpoint, metrics, log_image 8 | 9 | def train(args, model, summary, optimizer, create_losses_fn, dataloader, eval_loop_fn=None, eval_datasets=None, resume=None): 10 | 11 | loss_functions = create_losses_fn(['l1', 'perceptual']) 12 | PSNR = PeakSignalNoiseRatio().cuda() 13 | SSIM = StructuralSimilarityIndexMeasure().cuda() 14 | 15 | 16 | model.train() 17 | global_step = 0 18 | for epoch in range(args.epoch): 19 | 20 | save_checkpoint(args, model, optimizer, epoch) 21 | metric_psnr, metric_ssim = 0, 0 22 | 23 | for i, batch in enumerate(tqdm(dataloader)): 24 | batch = to_gpu(batch) 25 | model.zero_grad() 26 | predictions = model(to_gpu(batch)) 27 | optimizer.zero_grad() 28 | losses = [] 29 | for (loss_function) in loss_functions: 30 | loss = loss_functions[loss_function](batch, predictions) 31 | losses.append(loss) 32 | loss = sum(losses) 33 | loss.backward() 34 | optimizer.step() 35 | global_step+=1 36 | 37 | summary.add_scalar('train/loss', float(loss), global_step=global_step) 38 | 39 | if i % 100==0: 40 | log_image(batch, predictions, args, summary, epoch, i, global_step) 41 | with torch.no_grad(): 42 | psnr, ssim = metrics(predictions, batch, summary, PSNR, SSIM, global_step) 43 | metric_psnr += psnr 44 | metric_ssim += ssim 45 | 46 | summary.add_scalar('train/psnr_epoch', float(metric_psnr/(len(dataloader))*100), global_step=epoch) 47 | summary.add_scalar('train/ssim_epoch', float(metric_ssim/(len(dataloader))*100), global_step=epoch) 48 | 49 | 50 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | import numpy as np 3 | from torchmetrics.functional import structural_similarity_index_measure 4 | from torchmetrics import PeakSignalNoiseRatio 5 | from torchvision.utils import save_image 6 | import cv2 7 | 8 | def to_gpu(batch): 9 | batch = {'x0': batch['x0'].cuda(non_blocking=True), 'x1': batch['x1'].cuda(non_blocking=True), 'y': batch['y'].cuda(non_blocking=True), 'time': 
batch['time'].cuda(non_blocking=True)}
10 |     return batch
11 | 
12 | def save_checkpoint(args, model, optimizer, step):
13 |     save_dir = os.path.join(args.checkpoint_dir, args.exp_name)
14 |     os.makedirs(save_dir, exist_ok=True)
15 |     checkpoint = {'step': step,
16 |                   'state_dict': model.state_dict(),
17 |                   'optimizer': optimizer.state_dict()}
18 |     torch.save(checkpoint, os.path.join(save_dir, 'checkpoint_latest.pt'))
19 |     if step % 10 == 0:  # 'step' is the epoch index here (see train_lib.train)
20 |         torch.save(checkpoint, os.path.join(save_dir, f'checkpoint_{step}.pt'))
21 | 
22 | def load_checkpoint(path, model, optimizer):
23 |     checkpoint = torch.load(path)
24 |     for key in list(checkpoint['state_dict'].keys()):
25 |         checkpoint['state_dict'][key.replace('module.', '')] = checkpoint['state_dict'].pop(key)
26 |     model.load_state_dict(checkpoint['state_dict'])
27 |     optimizer.load_state_dict(checkpoint['optimizer'])
28 |     return checkpoint['step']
29 | 
30 | def metrics(predictions, batch, summary, PSNR, SSIM, global_step):
31 |     psnr = PSNR(predictions['image'], batch['y'])
32 |     ssim = structural_similarity_index_measure(predictions['image'], batch['y'])
33 |     summary.add_scalar('train/psnr', float(psnr), global_step=global_step)
34 |     summary.add_scalar('train/ssim', float(ssim), global_step=global_step)
35 |     return psnr, ssim
36 | 
37 | def log_image(batch, predictions, args, summary, epoch, i, global_step):
38 |     b, c, h, w = batch['x0'].shape
39 |     img = np.zeros((h, w*4, c), dtype=np.uint8)
40 |     x0 = denorm255(batch['x0'][0])
41 |     prediction = denorm255(predictions['image'][0])
42 |     ground_truth = denorm255(batch['y'][0])
43 |     x1 = denorm255(batch['x1'][0])
44 | 
45 |     img[:, :w, :] = np.transpose(x0.detach().cpu().numpy(), (1, 2, 0)).astype(np.uint8)
46 |     img[:, w:2*w, :] = np.transpose(prediction.detach().cpu().numpy(), (1, 2, 0)).astype(np.uint8)
47 |     img[:, 2*w:3*w, :] = np.transpose(ground_truth.detach().cpu().numpy(), (1, 2, 0)).astype(np.uint8)
48 |     img[:, 3*w:4*w, :] = np.transpose(x1.detach().cpu().numpy(), (1, 2, 0)).astype(np.uint8)
49 |     summary.add_image('x0, prediction, ground truth, x1', img[:, :, ::-1], global_step=global_step, dataformats='HWC')
50 |     save_dir = os.path.join(args.log_img, args.exp_name)
51 |     os.makedirs(save_dir, exist_ok=True)
52 |     cv2.imwrite(os.path.join(save_dir, f'epoch_{epoch}_iter_{i}.png'), img)
53 | 
54 | def denorm255(x):
55 |     out = (x + 1.0) / 2.0
56 |     return out.clamp_(0.0, 1.0) * 255.0
57 | 
--------------------------------------------------------------------------------
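
Note on the learning-rate schedule: `config.py` defines a staircase exponential decay (`learning_rate_decay_steps`, `learning_rate_decay_rate`, `learning_rate_staircase`), and `train.py` imports `ExponentialLR`, but no scheduler is actually stepped in `train_lib.train`. The snippet below is a hedged sketch of one possible wiring, not part of the current code; the dummy parameter list only exists so the sketch runs standalone, and in `train.py` the optimizer would be the existing Adam over `model.parameters()`.

```
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

from config import train_params

# Stand-in parameter so the sketch runs on its own; in train.py this is model.parameters().
params = [torch.zeros(1, requires_grad=True)]
optimizer = Adam(params, lr=train_params['learning_rate'],
                 weight_decay=train_params['weight_decay'])

# Stepped once per iteration, StepLR yields lr * decay_rate ** floor(step / decay_steps),
# i.e. the staircase schedule described by train_params.
scheduler = StepLR(optimizer,
                   step_size=train_params['learning_rate_decay_steps'],  # 750000
                   gamma=train_params['learning_rate_decay_rate'])       # 0.464158

for step in range(3):       # in train_lib.train this would follow optimizer.step()
    optimizer.step()
    scheduler.step()
```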