├── LICENSE
├── README.md
├── benchmarks
│   ├── Middlebury_Other.py
│   ├── SNU_FILM.py
│   ├── UCF101.py
│   ├── Vimeo90K.py
│   └── speed_parameters.py
├── datasets.py
├── demo_2x.py
├── demo_8x.py
├── figures
│   ├── 8x_interpolation.png
│   ├── benchmarks.png
│   ├── fig1_1.gif
│   ├── fig1_2.gif
│   ├── fig1_3.gif
│   ├── fig2_1.gif
│   ├── fig2_2.gif
│   ├── img0.png
│   ├── img1.png
│   ├── img_overlaid.png
│   ├── middlebury.png
│   ├── middlebury_other.png
│   ├── out_2x.gif
│   ├── out_8x.gif
│   └── vimeo90k.png
├── generate_flow.py
├── liteflownet
│   ├── README.md
│   ├── correlation
│   │   ├── README.md
│   │   └── correlation.py
│   ├── requirements.txt
│   └── run.py
├── loss.py
├── metric.py
├── models
│   ├── IFRNet.py
│   ├── IFRNet_L.py
│   └── IFRNet_S.py
├── train_gopro.py
├── train_vimeo90k.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License 2 | 3 | Copyright (c) 2022 Lingtong Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation 2 | The official PyTorch implementation of [IFRNet](https://arxiv.org/abs/2205.14620) (CVPR 2022). 3 | 4 | Authors: [Lingtong Kong](https://scholar.google.com.hk/citations?user=KKzKc_8AAAAJ&hl=zh-CN), [Boyuan Jiang](https://byjiang.com/), Donghao Luo, Wenqing Chu, Xiaoming Huang, [Ying Tai](https://tyshiwo.github.io/), Chengjie Wang, [Jie Yang](http://www.pami.sjtu.edu.cn/jieyang) 5 | 6 | ## Highlights 7 | Almost all existing flow-based frame interpolation methods first estimate or model intermediate optical flow, and then use flow-warped context features to synthesize the target frame. However, they ignore the mutual promotion between intermediate optical flow and intermediate context features, and their cascaded architectures substantially increase inference delay and model parameters, blocking them from many mobile and real-time applications. For the first time, we merge the above separate flow estimation and context feature refinement into a single encoder-decoder based IFRNet for compactness and fast inference, where these two crucial elements can benefit from each other. Moreover, a task-oriented flow distillation loss and a feature space geometry consistency loss are newly proposed to promote intermediate motion estimation and intermediate feature reconstruction of IFRNet, respectively. 
Benchmark results demonstrate that our IFRNet not only achieves state-of-the-art VFI accuracy, but also enjoys fast inference speed and lightweight model size. 8 | 9 | ![](./figures/vimeo90k.png) 10 | 11 | ## YouTube Demos 12 | [[4K60p] うたわれるもの 偽りの仮面 OP フレーム補間+超解像 (IFRnetとReal-CUGAN)](https://www.youtube.com/watch?v=tV2imgGS-5Q) 13 | 14 | [[4K60p] 天神乱漫 -LUCKY or UNLUCKY!?- OP (IFRnetとReal-CUGAN)](https://www.youtube.com/watch?v=NtpJqDZaM-4) 15 | 16 | [RIFE IFRnet 比較](https://www.youtube.com/watch?v=lHqnOQgpZHQ) 17 | 18 | [IFRNet frame interpolation](https://www.youtube.com/watch?v=ygSdCCZCsZU) 19 | 20 | ## Preparation 21 | 1. PyTorch >= 1.3.0 (We have verified that this repository supports Python 3.6/3.7, PyTorch 1.3.0/1.9.1). 22 | 2. Download training and test datasets: [Vimeo90K](http://toflow.csail.mit.edu/), [UCF101](https://liuziwei7.github.io/projects/VoxelFlow), [SNU-FILM](https://myungsub.github.io/CAIN/), [Middlebury](https://vision.middlebury.edu/flow/data/), [GoPro](https://seungjunnah.github.io/Datasets/gopro.html) and [Adobe240](http://www.cs.ubc.ca/labs/imager/tr/2017/DeepVideoDeblurring/). 23 | 3. Set the right dataset path on your machine. 24 | 25 | ## Download Pre-trained Models and Play with Demos 26 | Figures from left to right are overlaid input frames, 2x and 8x video interpolation results respectively. 27 |

28 | ![](./figures/img_overlaid.png) 29 | ![](./figures/out_2x.gif) 30 | ![](./figures/out_8x.gif) 31 | 32 | 33 | 1. Download our pre-trained models from this [link](https://www.dropbox.com/sh/hrewbpedd2cgdp3/AADbEivu0-CKDQcHtKdMNJPJa?dl=0), and then put the checkpoints folder into the root directory. 34 | 35 | 2. Run the following scripts to generate the 2x and 8x frame interpolation demos (a condensed sketch of the 2x script follows the commands): 36 | 
$ python demo_2x.py
 37 | $ python demo_8x.py
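Both demo scripts reduce to loading a checkpoint and calling `model.inference` with a time embedding; below is a condensed sketch of what demo_2x.py (reproduced in full later in this repository) does. The tensor `embt` encodes the normalized target time t in (0, 1), so 1/2 requests the midpoint frame.

```
import torch
from models.IFRNet import Model
from utils import read

model = Model().cuda().eval()
model.load_state_dict(torch.load('./checkpoints/IFRNet/IFRNet_Vimeo90K.pth'))

# read() returns an HxWx3 uint8 array; convert to a [1, 3, H, W] float tensor in [0, 1]
to_tensor = lambda x: (torch.tensor(x.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda()
img0 = to_tensor(read('./figures/img0.png'))
img1 = to_tensor(read('./figures/img1.png'))
embt = torch.tensor(1/2).view(1, 1, 1, 1).float().cuda()  # t = 0.5 -> middle frame

imgt_pred = model.inference(img0, img1, embt)  # [1, 3, H, W], values in [0, 1]
```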
38 | 39 | 40 | ## Training on Vimeo90K Triplet Dataset for 2x Frame Interpolation 41 | 1. First, run this script to generate the optical flow pseudo labels (the expected on-disk layout is sketched after the training commands below): 42 | 
$ python generate_flow.py
43 | 44 | 2. Then, start training by executing one of the following commands with the selected model: 45 | 
$ python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name 'IFRNet' --epochs 300 --batch_size 6 --lr_start 1e-4 --lr_end 1e-5
 46 | $ python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name 'IFRNet_L' --epochs 300 --batch_size 6 --lr_start 1e-4 --lr_end 1e-5
 47 | $ python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name 'IFRNet_S' --epochs 300 --batch_size 6 --lr_start 1e-4 --lr_end 1e-5
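Each launcher command above spawns `--nproc_per_node` processes, and `--batch_size` is presumably per process, giving an effective batch of 4 x 6 = 24. The pseudo labels generated in step 1 are consumed by `Vimeo90K_Train_Dataset` in datasets.py (included below), which expects a `flow` directory parallel to `sequences`; a sketch of the layout, where sequence naming such as `00001/0001` follows the standard Vimeo90K release:

```
vimeo_triplet/
├── sequences/
│   └── 00001/0001/{im1.png, im2.png, im3.png}
└── flow/
    └── 00001/0001/{flow_t0.flo, flow_t1.flo}
```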
48 | 49 | ## Benchmarks for 2x Frame Interpolation 50 | To test running time and model parameters, you can run 51 |
$ python benchmarks/speed_parameters.py
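speed_parameters.py (below) warms up for 100 iterations and then averages 100 timed forward passes on a 256x448 input, calling `torch.cuda.synchronize()` before reading the clock because CUDA kernels launch asynchronously. The same pattern generalizes to timing any model; a minimal helper sketch (the name `benchmark` is ours, not part of this repo):

```
import time
import torch

@torch.no_grad()
def benchmark(fn, n_warmup=100, n_runs=100):
    # warm-up lets cuDNN autotune and amortizes lazy CUDA initialization
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()  # drain queued kernels before starting the clock
    start = time.time()
    for _ in range(n_runs):
        fn()
    torch.cuda.synchronize()  # wait for the last kernel before stopping the clock
    return (time.time() - start) / n_runs

# usage: benchmark(lambda: model.inference(img0, img1, embt))
```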
52 | 53 | To test frame interpolation accuracy on Vimeo90K, UCF101 and SNU-FILM datasets, you can run 54 |
$ python benchmarks/Vimeo90K.py
 55 | $ python benchmarks/UCF101.py
 56 | $ python benchmarks/SNU_FILM.py
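All three scripts load the Vimeo90K-trained checkpoints from ./checkpoints and hard-code an absolute dataset `path` near the top, so edit it before running. The metrics come from metric.py, which is not reproduced in this section; for orientation, PSNR over [0, 1]-normalized tensors is conventionally computed as below (a sketch of the standard definition, not necessarily the exact implementation in metric.py):

```
import torch

def psnr(pred, gt, eps=1e-8):
    # peak signal-to-noise ratio; the peak value is 1.0 for [0, 1]-scaled images
    mse = torch.mean((pred - gt) ** 2)
    return 10.0 * torch.log10(1.0 / (mse + eps))
```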
57 | 58 | ## Quantitative Comparison for 2x Frame Interpolation 59 | The proposed IFRNet achieves state-of-the-art frame interpolation accuracy with less inference time and lower computational complexity. We expect the proposed single encoder-decoder, joint-refinement-based IFRNet to be a useful component for many frame rate up-conversion, video compression and intermediate view synthesis systems. Time and FLOPs are measured at 1280 x 720 resolution. 60 | 61 | ![](./figures/benchmarks.png) 62 | 63 | 64 | ## Qualitative Comparison for 2x Frame Interpolation 65 | Video comparison for 2x interpolation among methods that use 2 input frames, on the SNU-FILM dataset. 66 | 67 | ![](./figures/fig2_1.gif) 68 | 69 | ![](./figures/fig2_2.gif) 70 | 71 | 72 | ## Middlebury Benchmark 73 | Results on the [Middlebury](https://vision.middlebury.edu/flow/eval/results/results-i1.php) online benchmark. 74 | 75 | ![](./figures/middlebury.png) 76 | 77 | Results on the Middlebury Other dataset. 78 | 79 | ![](./figures/middlebury_other.png) 80 | 81 | 82 | ## Training on GoPro Dataset for 8x Frame Interpolation 83 | 1. Start training by executing one of the following commands with the selected model: 84 | 
$ python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name 'IFRNet' --epochs 600 --batch_size 2 --lr_start 1e-4 --lr_end 1e-5
 85 | $ python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name 'IFRNet_L' --epochs 600 --batch_size 2 --lr_start 1e-4 --lr_end 1e-5
 86 | $ python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name 'IFRNet_S' --epochs 600 --batch_size 2 --lr_start 1e-4 --lr_end 1e-5
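A GoPro-trained checkpoint reaches arbitrary intermediate time steps through the `embt` input; demo_8x.py (later in this repository) evaluates all seven targets t = 1/8, ..., 7/8 in one batched forward pass by stacking copies of the two input frames. Written compactly (`repeat` is equivalent to the explicit `torch.cat` of copies used in the demo):

```
import torch

# img0, img8: [1, 3, H, W] input frames on the GPU; model: a GoPro-trained IFRNet
embt = torch.cat([torch.full((1, 1, 1, 1), i / 8) for i in range(1, 8)], 0).cuda()
imgt_pred = model.inference(img0.repeat(7, 1, 1, 1), img8.repeat(7, 1, 1, 1), embt)
# imgt_pred[i] is the interpolated frame at t = (i + 1) / 8
```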
87 | 88 | Since inter-frame motion in the 8x interpolation setting is relatively small, the task-oriented flow distillation loss is omitted here. Because the GoPro training set is relatively small, we suggest training on your own datasets for better slow-motion generation results. 89 | 90 | ## Quantitative Comparison for 8x Frame Interpolation 91 | 92 | ![](./figures/8x_interpolation.png) 93 | 94 | ## Qualitative Results on GoPro and Adobe240 Datasets for 8x Frame Interpolation 95 | Each video has 9 frames, where the first and the last frames are inputs, and the middle 7 frames are predicted by IFRNet. 96 | 97 | 

98 | ![](./figures/fig1_1.gif) 99 | ![](./figures/fig1_2.gif) 100 | ![](./figures/fig1_3.gif) 101 | 102 | 103 | ## ncnn Implementation of IFRNet 104 | 105 | [ifrnet-ncnn-vulkan](https://github.com/nihui/ifrnet-ncnn-vulkan) uses the [ncnn project](https://github.com/Tencent/ncnn) as its universal neural network inference framework. The package includes all required binaries and models, and it is portable: no CUDA or PyTorch runtime environment is needed. 106 | 107 | ## Citation 108 | When using any parts of the Software or the Paper in your work, please cite the following paper: 109 | 
@InProceedings{Kong_2022_CVPR, 
110 |   author = {Kong, Lingtong and Jiang, Boyuan and Luo, Donghao and Chu, Wenqing and Huang, Xiaoming and Tai, Ying and Wang, Chengjie and Yang, Jie}, 
111 |   title = {IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation}, 
112 |   booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 
113 |   year = {2022}
114 | }
115 | -------------------------------------------------------------------------------- /benchmarks/Middlebury_Other.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import torch 5 | import numpy as np 6 | from utils import read 7 | from metric import calculate_psnr, calculate_ssim, calculate_ie 8 | from models.IFRNet import Model 9 | # from models.IFRNet_L import Model 10 | # from models.IFRNet_S import Model 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | model = Model() 16 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth')) 17 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth')) 18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth')) 19 | model.eval() 20 | model.cuda() 21 | 22 | # Replace the 'path' with your Middlebury dataset absolute path. 23 | path = '/home/ltkong/Datasets/Middlebury/' 24 | sequence = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'] 25 | 26 | psnr_list = [] 27 | ssim_list = [] 28 | ie_list = [] 29 | for i in sequence: 30 | I0 = read(path + 'other-data/{}/frame10.png'.format(i)) 31 | I1 = read(path + 'other-gt-interp/{}/frame10i11.png'.format(i)) 32 | I2 = read(path + 'other-data/{}/frame11.png'.format(i)) 33 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 34 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 35 | I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 36 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 37 | 38 | I0_pad = torch.zeros([1, 3, 480, 640]).to(device) 39 | I2_pad = torch.zeros([1, 3, 480, 640]).to(device) 40 | h, w = I0.shape[-2:] 41 | I0_pad[:, :, :h, :w] = I0 42 | I2_pad[:, :, :h, :w] = I2 43 | 44 | I1_pred_pad = model.inference(I0_pad, I2_pad, embt) 45 | I1_pred = I1_pred_pad[:, :, :h, :w] 46 | 47 | psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy() 48 | ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy() 49 | ie = calculate_ie(I1_pred, I1).detach().cpu().numpy() 50 | 51 | psnr_list.append(psnr) 52 | ssim_list.append(ssim) 53 | ie_list.append(ie) 54 | 55 | print('Avg PSNR: {} SSIM: {} IE: {}'.format(np.mean(psnr_list), np.mean(ssim_list), np.mean(ie_list))) 56 | -------------------------------------------------------------------------------- /benchmarks/SNU_FILM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from utils import read 8 | from metric import calculate_psnr, calculate_ssim 9 | from models.IFRNet import Model 10 | # from models.IFRNet_L import Model 11 | # from models.IFRNet_S import Model 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | model = Model() 17 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth')) 18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth')) 19 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth')) 20 | model.eval() 21 | model.cuda() 22 | 23 | divisor = 20 24 | scale_factor = 0.8 25 | 26 | class InputPadder: 27 | """ Pads images such that dimensions are 
divisible by divisor """ 28 | def __init__(self, dims, divisor=divisor): 29 | self.ht, self.wd = dims[-2:] 30 | pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor 31 | pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor 32 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 33 | 34 | def pad(self, *inputs): 35 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 36 | 37 | def unpad(self,x): 38 | ht, wd = x.shape[-2:] 39 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 40 | return x[..., c[0]:c[1], c[2]:c[3]] 41 | 42 | # Replace the 'path' with your SNU-FILM dataset absolute path. 43 | path = '/home/ltkong/Datasets/SNU-FILM/' 44 | 45 | psnr_list = [] 46 | ssim_list = [] 47 | file_list = [] 48 | test_file = "test-hard.txt" # test-easy.txt, test-medium.txt, test-hard.txt, test-extreme.txt 49 | with open(os.path.join(path, test_file), "r") as f: 50 | for line in f: 51 | line = line.strip() 52 | file_list.append(line.split(' ')) 53 | 54 | for line in file_list: 55 | print(os.path.join(path, line[0])) 56 | I0_path = os.path.join(path, line[0]) 57 | I1_path = os.path.join(path, line[1]) 58 | I2_path = os.path.join(path, line[2]) 59 | I0 = read(I0_path) 60 | I1 = read(I1_path) 61 | I2 = read(I2_path) 62 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 63 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 64 | I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 65 | padder = InputPadder(I0.shape) 66 | I0, I2 = padder.pad(I0, I2) 67 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 68 | 69 | I1_pred = model.inference(I0, I2, embt, scale_factor=scale_factor) 70 | I1_pred = padder.unpad(I1_pred) 71 | 72 | psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy() 73 | ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy() 74 | 75 | psnr_list.append(psnr) 76 | ssim_list.append(ssim) 77 | 78 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list))) 79 | -------------------------------------------------------------------------------- /benchmarks/UCF101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import torch 5 | import numpy as np 6 | from utils import read 7 | from metric import calculate_psnr, calculate_ssim 8 | from models.IFRNet import Model 9 | # from models.IFRNet_L import Model 10 | # from models.IFRNet_S import Model 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | model = Model() 16 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth')) 17 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth')) 18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth')) 19 | model.eval() 20 | model.cuda() 21 | 22 | # Replace the 'path' with your UCF101 dataset absolute path. 
23 | path = '/home/ltkong/Datasets/UCF101/ucf101_interp_ours/' 24 | dirs = sorted(os.listdir(path)) 25 | 26 | psnr_list = [] 27 | ssim_list = [] 28 | for d in dirs: 29 | print(path + d + '/frame_00.png') 30 | I0 = read(path + d + '/frame_00.png') 31 | I1 = read(path + d + '/frame_01_gt.png') 32 | I2 = read(path + d + '/frame_02.png') 33 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 34 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 35 | I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 36 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 37 | 38 | I1_pred = model.inference(I0, I2, embt) 39 | 40 | psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy() 41 | ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy() 42 | 43 | psnr_list.append(psnr) 44 | ssim_list.append(ssim) 45 | 46 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list))) 47 | -------------------------------------------------------------------------------- /benchmarks/Vimeo90K.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import torch 5 | import numpy as np 6 | from utils import read 7 | from metric import calculate_psnr, calculate_ssim 8 | from models.IFRNet import Model 9 | # from models.IFRNet_L import Model 10 | # from models.IFRNet_S import Model 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | model = Model() 16 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth')) 17 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth')) 18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth')) 19 | model.eval() 20 | model.cuda() 21 | 22 | # Replace the 'path' with your Vimeo90K dataset absolute path. 
23 | path = '/home/ltkong/Datasets/Vimeo90K/vimeo_triplet/' 24 | f = open(path + 'tri_testlist.txt', 'r') 25 | 26 | psnr_list = [] 27 | ssim_list = [] 28 | for i in f: 29 | name = str(i).strip() 30 | if(len(name) <= 1): 31 | continue 32 | print(path + 'sequences/' + name + '/im1.png') 33 | I0 = read(path + 'sequences/' + name + '/im1.png') 34 | I1 = read(path + 'sequences/' + name + '/im2.png') 35 | I2 = read(path + 'sequences/' + name + '/im3.png') 36 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 37 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 38 | I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device) 39 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 40 | 41 | I1_pred = model.inference(I0, I2, embt) 42 | 43 | psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy() 44 | ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy() 45 | 46 | psnr_list.append(psnr) 47 | ssim_list.append(ssim) 48 | 49 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list))) 50 | -------------------------------------------------------------------------------- /benchmarks/speed_parameters.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from models.IFRNet import Model 9 | # from models.IFRNet_L import Model 10 | # from models.IFRNet_S import Model 11 | 12 | if torch.cuda.is_available(): 13 | torch.backends.cudnn.enabled = True 14 | torch.backends.cudnn.benchmark = True 15 | 16 | img0 = torch.randn(1, 3, 256, 448).cuda() 17 | img1 = torch.randn(1, 3, 256, 448).cuda() 18 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).cuda() 19 | 20 | model = Model().cuda().eval() 21 | 22 | with torch.no_grad(): 23 | for i in range(100): 24 | out = model.inference(img0, img1, embt) 25 | if torch.cuda.is_available(): 26 | torch.cuda.synchronize() 27 | time_stamp = time.time() 28 | for i in range(100): 29 | out = model.inference(img0, img1, embt) 30 | if torch.cuda.is_available(): 31 | torch.cuda.synchronize() 32 | print('Time: {:.3f}s'.format((time.time() - time_stamp) / 100)) 33 | 34 | total = sum([param.nelement() for param in model.parameters()]) 35 | print('Parameters: {:.2f}M'.format(total / 1e6)) 36 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import cv2 5 | import torch 6 | from torch.utils.data import Dataset 7 | from utils import read 8 | 9 | 10 | def random_resize(img0, imgt, img1, flow, p=0.1): 11 | if random.uniform(0, 1) < p: 12 | img0 = cv2.resize(img0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 13 | imgt = cv2.resize(imgt, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 14 | img1 = cv2.resize(img1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 15 | flow = cv2.resize(flow, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) * 2.0 16 | return img0, imgt, img1, flow 17 | 18 | 19 | def random_crop(img0, imgt, img1, flow, crop_size=(224, 224)): 20 | h, w = crop_size[0], crop_size[1] 21 | ih, iw, _ = img0.shape 22 | x = np.random.randint(0, ih-h+1) 23 | y = np.random.randint(0, iw-w+1) 24 | img0 = img0[x:x+h, y:y+w, :] 25 | imgt = imgt[x:x+h, y:y+w, :] 
26 | img1 = img1[x:x+h, y:y+w, :] 27 | flow = flow[x:x+h, y:y+w, :] 28 | return img0, imgt, img1, flow 29 | 30 | 31 | def random_reverse_channel(img0, imgt, img1, flow, p=0.5): 32 | if random.uniform(0, 1) < p: 33 | img0 = img0[:, :, ::-1] 34 | imgt = imgt[:, :, ::-1] 35 | img1 = img1[:, :, ::-1] 36 | return img0, imgt, img1, flow 37 | 38 | 39 | def random_vertical_flip(img0, imgt, img1, flow, p=0.3): 40 | if random.uniform(0, 1) < p: 41 | img0 = img0[::-1] 42 | imgt = imgt[::-1] 43 | img1 = img1[::-1] 44 | flow = flow[::-1] 45 | flow = np.concatenate((flow[:, :, 0:1], -flow[:, :, 1:2], flow[:, :, 2:3], -flow[:, :, 3:4]), 2) 46 | return img0, imgt, img1, flow 47 | 48 | 49 | def random_horizontal_flip(img0, imgt, img1, flow, p=0.5): 50 | if random.uniform(0, 1) < p: 51 | img0 = img0[:, ::-1] 52 | imgt = imgt[:, ::-1] 53 | img1 = img1[:, ::-1] 54 | flow = flow[:, ::-1] 55 | flow = np.concatenate((-flow[:, :, 0:1], flow[:, :, 1:2], -flow[:, :, 2:3], flow[:, :, 3:4]), 2) 56 | return img0, imgt, img1, flow 57 | 58 | 59 | def random_rotate(img0, imgt, img1, flow, p=0.05): 60 | if random.uniform(0, 1) < p: 61 | img0 = img0.transpose((1, 0, 2)) 62 | imgt = imgt.transpose((1, 0, 2)) 63 | img1 = img1.transpose((1, 0, 2)) 64 | flow = flow.transpose((1, 0, 2)) 65 | flow = np.concatenate((flow[:, :, 1:2], flow[:, :, 0:1], flow[:, :, 3:4], flow[:, :, 2:3]), 2) 66 | return img0, imgt, img1, flow 67 | 68 | 69 | def random_reverse_time(img0, imgt, img1, flow, p=0.5): 70 | if random.uniform(0, 1) < p: 71 | tmp = img1 72 | img1 = img0 73 | img0 = tmp 74 | flow = np.concatenate((flow[:, :, 2:4], flow[:, :, 0:2]), 2) 75 | return img0, imgt, img1, flow 76 | 77 | 78 | class Vimeo90K_Train_Dataset(Dataset): 79 | def __init__(self, dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet', augment=True): 80 | self.dataset_dir = dataset_dir 81 | self.augment = augment 82 | self.img0_list = [] 83 | self.imgt_list = [] 84 | self.img1_list = [] 85 | self.flow_t0_list = [] 86 | self.flow_t1_list = [] 87 | with open(os.path.join(dataset_dir, 'tri_trainlist.txt'), 'r') as f: 88 | for i in f: 89 | name = str(i).strip() 90 | if(len(name) <= 1): 91 | continue 92 | self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png')) 93 | self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png')) 94 | self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png')) 95 | self.flow_t0_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t0.flo')) 96 | self.flow_t1_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t1.flo')) 97 | 98 | def __len__(self): 99 | return len(self.imgt_list) 100 | 101 | def __getitem__(self, idx): 102 | img0 = read(self.img0_list[idx]) 103 | imgt = read(self.imgt_list[idx]) 104 | img1 = read(self.img1_list[idx]) 105 | flow_t0 = read(self.flow_t0_list[idx]) 106 | flow_t1 = read(self.flow_t1_list[idx]) 107 | flow = np.concatenate((flow_t0, flow_t1), 2).astype(np.float64) 108 | 109 | if self.augment == True: 110 | img0, imgt, img1, flow = random_resize(img0, imgt, img1, flow, p=0.1) 111 | img0, imgt, img1, flow = random_crop(img0, imgt, img1, flow, crop_size=(224, 224)) 112 | img0, imgt, img1, flow = random_reverse_channel(img0, imgt, img1, flow, p=0.5) 113 | img0, imgt, img1, flow = random_vertical_flip(img0, imgt, img1, flow, p=0.3) 114 | img0, imgt, img1, flow = random_horizontal_flip(img0, imgt, img1, flow, p=0.5) 115 | img0, imgt, img1, flow = random_rotate(img0, imgt, img1, flow, p=0.05) 116 | img0, imgt, img1, flow = random_reverse_time(img0, 
imgt, img1, flow, p=0.5) 117 | 118 | img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0) 119 | imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0) 120 | img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0) 121 | flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32)) 122 | embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32)) 123 | 124 | return img0, imgt, img1, flow, embt 125 | 126 | 127 | class Vimeo90K_Test_Dataset(Dataset): 128 | def __init__(self, dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet'): 129 | self.dataset_dir = dataset_dir 130 | self.img0_list = [] 131 | self.imgt_list = [] 132 | self.img1_list = [] 133 | self.flow_t0_list = [] 134 | self.flow_t1_list = [] 135 | with open(os.path.join(dataset_dir, 'tri_testlist.txt'), 'r') as f: 136 | for i in f: 137 | name = str(i).strip() 138 | if(len(name) <= 1): 139 | continue 140 | self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png')) 141 | self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png')) 142 | self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png')) 143 | self.flow_t0_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t0.flo')) 144 | self.flow_t1_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t1.flo')) 145 | 146 | def __len__(self): 147 | return len(self.imgt_list) 148 | 149 | def __getitem__(self, idx): 150 | img0 = read(self.img0_list[idx]) 151 | imgt = read(self.imgt_list[idx]) 152 | img1 = read(self.img1_list[idx]) 153 | flow_t0 = read(self.flow_t0_list[idx]) 154 | flow_t1 = read(self.flow_t1_list[idx]) 155 | flow = np.concatenate((flow_t0, flow_t1), 2) 156 | 157 | img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0) 158 | imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0) 159 | img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0) 160 | flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32)) 161 | embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32)) 162 | 163 | return img0, imgt, img1, flow, embt 164 | 165 | 166 | 167 | def random_resize_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.1): 168 | if random.uniform(0, 1) < p: 169 | img0 = cv2.resize(img0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 170 | img1 = cv2.resize(img1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 171 | img2 = cv2.resize(img2, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 172 | img3 = cv2.resize(img3, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 173 | img4 = cv2.resize(img4, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 174 | img5 = cv2.resize(img5, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 175 | img6 = cv2.resize(img6, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 176 | img7 = cv2.resize(img7, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 177 | img8 = cv2.resize(img8, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) 178 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 179 | 180 | 181 | def random_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(224, 224)): 182 | h, w = crop_size[0], crop_size[1] 183 | ih, iw, _ = img0.shape 184 | x = np.random.randint(0, ih-h+1) 185 | y = np.random.randint(0, iw-w+1) 186 | img0 = img0[x:x+h, y:y+w, :] 187 | img1 = 
img1[x:x+h, y:y+w, :] 188 | img2 = img2[x:x+h, y:y+w, :] 189 | img3 = img3[x:x+h, y:y+w, :] 190 | img4 = img4[x:x+h, y:y+w, :] 191 | img5 = img5[x:x+h, y:y+w, :] 192 | img6 = img6[x:x+h, y:y+w, :] 193 | img7 = img7[x:x+h, y:y+w, :] 194 | img8 = img8[x:x+h, y:y+w, :] 195 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 196 | 197 | 198 | def center_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(512, 512)): 199 | h, w = crop_size[0], crop_size[1] 200 | ih, iw, _ = img0.shape 201 | img0 = img0[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 202 | img1 = img1[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 203 | img2 = img2[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 204 | img3 = img3[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 205 | img4 = img4[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 206 | img5 = img5[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 207 | img6 = img6[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 208 | img7 = img7[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 209 | img8 = img8[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :] 210 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 211 | 212 | 213 | def random_reverse_channel_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5): 214 | if random.uniform(0, 1) < p: 215 | img0 = img0[:, :, ::-1] 216 | img1 = img1[:, :, ::-1] 217 | img2 = img2[:, :, ::-1] 218 | img3 = img3[:, :, ::-1] 219 | img4 = img4[:, :, ::-1] 220 | img5 = img5[:, :, ::-1] 221 | img6 = img6[:, :, ::-1] 222 | img7 = img7[:, :, ::-1] 223 | img8 = img8[:, :, ::-1] 224 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 225 | 226 | 227 | def random_vertical_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.3): 228 | if random.uniform(0, 1) < p: 229 | img0 = img0[::-1] 230 | img1 = img1[::-1] 231 | img2 = img2[::-1] 232 | img3 = img3[::-1] 233 | img4 = img4[::-1] 234 | img5 = img5[::-1] 235 | img6 = img6[::-1] 236 | img7 = img7[::-1] 237 | img8 = img8[::-1] 238 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 239 | 240 | 241 | def random_horizontal_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5): 242 | if random.uniform(0, 1) < p: 243 | img0 = img0[:, ::-1] 244 | img1 = img1[:, ::-1] 245 | img2 = img2[:, ::-1] 246 | img3 = img3[:, ::-1] 247 | img4 = img4[:, ::-1] 248 | img5 = img5[:, ::-1] 249 | img6 = img6[:, ::-1] 250 | img7 = img7[:, ::-1] 251 | img8 = img8[:, ::-1] 252 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 253 | 254 | 255 | def random_rotate_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.05): 256 | if random.uniform(0, 1) < p: 257 | img0 = img0.transpose((1, 0, 2)) 258 | img1 = img1.transpose((1, 0, 2)) 259 | img2 = img2.transpose((1, 0, 2)) 260 | img3 = img3.transpose((1, 0, 2)) 261 | img4 = img4.transpose((1, 0, 2)) 262 | img5 = img5.transpose((1, 0, 2)) 263 | img6 = img6.transpose((1, 0, 2)) 264 | img7 = img7.transpose((1, 0, 2)) 265 | img8 = img8.transpose((1, 0, 2)) 266 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 267 | 268 | 269 | def random_reverse_time_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5): 270 | if random.uniform(0, 1) < p: 271 | return img8, img7, img6, img5, img4, img3, img2, img1, img0 272 | else: 273 | return img0, img1, img2, img3, img4, img5, img6, img7, img8 274 | 275 | 276 | class GoPro_Train_Dataset(Dataset): 277 | def __init__(self, 
dataset_dir='/home/ltkong/Datasets/GOPRO', interFrames=7, n_inputs=2, augment=True): 278 | self.dataset_dir = dataset_dir 279 | self.interFrames = interFrames 280 | self.n_inputs = n_inputs 281 | self.augment = augment 282 | self.setLength = (n_inputs-1)*(interFrames+1)+1 283 | video_list = [ 284 | 'GOPR0372_07_00', 'GOPR0374_11_01', 'GOPR0378_13_00', 'GOPR0384_11_01', 'GOPR0384_11_04', 'GOPR0477_11_00', 'GOPR0868_11_02', 'GOPR0884_11_00', 285 | 'GOPR0372_07_01', 'GOPR0374_11_02', 'GOPR0379_11_00', 'GOPR0384_11_02', 'GOPR0385_11_00', 'GOPR0857_11_00', 'GOPR0871_11_01', 'GOPR0374_11_00', 286 | 'GOPR0374_11_03', 'GOPR0380_11_00', 'GOPR0384_11_03', 'GOPR0386_11_00', 'GOPR0868_11_01', 'GOPR0881_11_00'] 287 | self.frames_list = [] 288 | self.file_list = [] 289 | for video in video_list: 290 | frames = sorted(os.listdir(os.path.join(self.dataset_dir, video))) 291 | n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1 292 | videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength] for i in range(n_sets)] 293 | videoInputs = [[os.path.join(video, f) for f in group] for group in videoInputs] 294 | self.file_list.extend(videoInputs) 295 | 296 | def __len__(self): 297 | return len(self.file_list) 298 | 299 | def __getitem__(self, idx): 300 | imgpaths = [os.path.join(self.dataset_dir, fp) for fp in self.file_list[idx]] 301 | pick_idxs = list(range(0, self.setLength, self.interFrames+1)) 302 | rem = self.interFrames%2 303 | gt_idx = list(range(self.setLength//2-self.interFrames//2, self.setLength//2+self.interFrames//2+rem)) 304 | input_paths = [imgpaths[idx] for idx in pick_idxs] 305 | gt_paths = [imgpaths[idx] for idx in gt_idx] 306 | img0 = np.array(read(input_paths[0])) 307 | img1 = np.array(read(gt_paths[0])) 308 | img2 = np.array(read(gt_paths[1])) 309 | img3 = np.array(read(gt_paths[2])) 310 | img4 = np.array(read(gt_paths[3])) 311 | img5 = np.array(read(gt_paths[4])) 312 | img6 = np.array(read(gt_paths[5])) 313 | img7 = np.array(read(gt_paths[6])) 314 | img8 = np.array(read(input_paths[1])) 315 | 316 | if self.augment == True: 317 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_resize_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.1) 318 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(224, 224)) 319 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_reverse_channel_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5) 320 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_vertical_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.3) 321 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_horizontal_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5) 322 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_rotate_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.05) 323 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_reverse_time_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5) 324 | else: 325 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = center_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(512, 512)) 326 | 327 | img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0) 328 | img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0) 329 | img2 = torch.from_numpy(img2.transpose((2, 0, 1)).astype(np.float32) / 255.0) 330 | img3 = 
torch.from_numpy(img3.transpose((2, 0, 1)).astype(np.float32) / 255.0) 331 | img4 = torch.from_numpy(img4.transpose((2, 0, 1)).astype(np.float32) / 255.0) 332 | img5 = torch.from_numpy(img5.transpose((2, 0, 1)).astype(np.float32) / 255.0) 333 | img6 = torch.from_numpy(img6.transpose((2, 0, 1)).astype(np.float32) / 255.0) 334 | img7 = torch.from_numpy(img7.transpose((2, 0, 1)).astype(np.float32) / 255.0) 335 | img8 = torch.from_numpy(img8.transpose((2, 0, 1)).astype(np.float32) / 255.0) 336 | 337 | emb1 = torch.from_numpy(np.array(1/8).reshape(1, 1, 1).astype(np.float32)) 338 | emb2 = torch.from_numpy(np.array(2/8).reshape(1, 1, 1).astype(np.float32)) 339 | emb3 = torch.from_numpy(np.array(3/8).reshape(1, 1, 1).astype(np.float32)) 340 | emb4 = torch.from_numpy(np.array(4/8).reshape(1, 1, 1).astype(np.float32)) 341 | emb5 = torch.from_numpy(np.array(5/8).reshape(1, 1, 1).astype(np.float32)) 342 | emb6 = torch.from_numpy(np.array(6/8).reshape(1, 1, 1).astype(np.float32)) 343 | emb7 = torch.from_numpy(np.array(7/8).reshape(1, 1, 1).astype(np.float32)) 344 | 345 | return img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7 346 | 347 | 348 | class GoPro_Test_Dataset(Dataset): 349 | def __init__(self, dataset_dir='/home/ltkong/Datasets/GOPRO', interFrames=7, n_inputs=2): 350 | self.dataset_dir = dataset_dir 351 | self.interFrames = interFrames 352 | self.n_inputs = n_inputs 353 | self.setLength = (n_inputs-1)*(interFrames+1)+1 354 | video_list = [ 355 | 'GOPR0384_11_00', 'GOPR0385_11_01', 'GOPR0410_11_00', 'GOPR0862_11_00', 'GOPR0869_11_00', 'GOPR0881_11_01', 'GOPR0384_11_05', 'GOPR0396_11_00', 356 | 'GOPR0854_11_00', 'GOPR0868_11_00', 'GOPR0871_11_00'] 357 | self.frames_list = [] 358 | self.file_list = [] 359 | for video in video_list: 360 | frames = sorted(os.listdir(os.path.join(self.dataset_dir, video))) 361 | n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1 362 | videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength] for i in range(n_sets)] 363 | videoInputs = [[os.path.join(video, f) for f in group] for group in videoInputs] 364 | self.file_list.extend(videoInputs) 365 | 366 | def __len__(self): 367 | return len(self.file_list) 368 | 369 | def __getitem__(self, idx): 370 | imgpaths = [os.path.join(self.dataset_dir, fp) for fp in self.file_list[idx]] 371 | pick_idxs = list(range(0, self.setLength, self.interFrames+1)) 372 | rem = self.interFrames%2 373 | gt_idx = list(range(self.setLength//2-self.interFrames//2, self.setLength//2+self.interFrames//2+rem)) 374 | input_paths = [imgpaths[idx] for idx in pick_idxs] 375 | gt_paths = [imgpaths[idx] for idx in gt_idx] 376 | img0 = np.array(read(input_paths[0])) 377 | img1 = np.array(read(gt_paths[0])) 378 | img2 = np.array(read(gt_paths[1])) 379 | img3 = np.array(read(gt_paths[2])) 380 | img4 = np.array(read(gt_paths[3])) 381 | img5 = np.array(read(gt_paths[4])) 382 | img6 = np.array(read(gt_paths[5])) 383 | img7 = np.array(read(gt_paths[6])) 384 | img8 = np.array(read(input_paths[1])) 385 | 386 | img0, img1, img2, img3, img4, img5, img6, img7, img8 = center_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(512, 512)) 387 | 388 | img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0) 389 | img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0) 390 | img2 = torch.from_numpy(img2.transpose((2, 0, 1)).astype(np.float32) / 255.0) 391 | img3 = torch.from_numpy(img3.transpose((2, 0, 1)).astype(np.float32) / 255.0) 
392 | img4 = torch.from_numpy(img4.transpose((2, 0, 1)).astype(np.float32) / 255.0) 393 | img5 = torch.from_numpy(img5.transpose((2, 0, 1)).astype(np.float32) / 255.0) 394 | img6 = torch.from_numpy(img6.transpose((2, 0, 1)).astype(np.float32) / 255.0) 395 | img7 = torch.from_numpy(img7.transpose((2, 0, 1)).astype(np.float32) / 255.0) 396 | img8 = torch.from_numpy(img8.transpose((2, 0, 1)).astype(np.float32) / 255.0) 397 | 398 | emb1 = torch.from_numpy(np.array(1/8).reshape(1, 1, 1).astype(np.float32)) 399 | emb2 = torch.from_numpy(np.array(2/8).reshape(1, 1, 1).astype(np.float32)) 400 | emb3 = torch.from_numpy(np.array(3/8).reshape(1, 1, 1).astype(np.float32)) 401 | emb4 = torch.from_numpy(np.array(4/8).reshape(1, 1, 1).astype(np.float32)) 402 | emb5 = torch.from_numpy(np.array(5/8).reshape(1, 1, 1).astype(np.float32)) 403 | emb6 = torch.from_numpy(np.array(6/8).reshape(1, 1, 1).astype(np.float32)) 404 | emb7 = torch.from_numpy(np.array(7/8).reshape(1, 1, 1).astype(np.float32)) 405 | 406 | return img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7 407 | -------------------------------------------------------------------------------- /demo_2x.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from models.IFRNet import Model 5 | from utils import read 6 | from imageio import mimsave 7 | 8 | 9 | model = Model().cuda().eval() 10 | model.load_state_dict(torch.load('./checkpoints/IFRNet/IFRNet_Vimeo90K.pth')) 11 | 12 | img0_np = read('./figures/img0.png') 13 | img1_np = read('./figures/img1.png') 14 | 15 | img0 = (torch.tensor(img0_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 16 | img1 = (torch.tensor(img1_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 17 | embt = torch.tensor(1/2).view(1, 1, 1, 1).float().cuda() 18 | 19 | imgt_pred = model.inference(img0, img1, embt) 20 | 21 | imgt_pred_np = (imgt_pred[0].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 22 | 23 | images = [img0_np, imgt_pred_np, img1_np] 24 | mimsave('./figures/out_2x.gif', images, fps=3) 25 | -------------------------------------------------------------------------------- /demo_8x.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from models.IFRNet import Model 5 | from utils import read 6 | from imageio import mimsave 7 | 8 | 9 | model = Model().cuda().eval() 10 | model.load_state_dict(torch.load('./checkpoints/IFRNet/IFRNet_GoPro.pth')) 11 | 12 | img0_np = read('./figures/img0.png') 13 | img8_np = read('./figures/img1.png') 14 | 15 | img0 = (torch.tensor(img0_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 16 | img8 = (torch.tensor(img8_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 17 | 18 | emb1 = torch.tensor(1/8).view(1, 1, 1, 1).float().cuda() 19 | emb2 = torch.tensor(2/8).view(1, 1, 1, 1).float().cuda() 20 | emb3 = torch.tensor(3/8).view(1, 1, 1, 1).float().cuda() 21 | emb4 = torch.tensor(4/8).view(1, 1, 1, 1).float().cuda() 22 | emb5 = torch.tensor(5/8).view(1, 1, 1, 1).float().cuda() 23 | emb6 = torch.tensor(6/8).view(1, 1, 1, 1).float().cuda() 24 | emb7 = torch.tensor(7/8).view(1, 1, 1, 1).float().cuda() 25 | 26 | img0_ = torch.cat([img0, img0, img0, img0, img0, img0, img0], 0) 27 | img8_ = torch.cat([img8, img8, img8, img8, img8, img8, img8], 0) 28 | embt = torch.cat([emb1, emb2, emb3, emb4, emb5, emb6, emb7], 0) 29 | 30 | 
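# All seven time steps are stacked along the batch dimension, so the single
# forward pass below predicts every intermediate frame at t = 1/8 ... 7/8 at once.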
imgt_pred = model.inference(img0_, img8_, embt) 31 | 32 | img1_pred_np = (imgt_pred[0].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 33 | img2_pred_np = (imgt_pred[1].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 34 | img3_pred_np = (imgt_pred[2].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 35 | img4_pred_np = (imgt_pred[3].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 36 | img5_pred_np = (imgt_pred[4].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 37 | img6_pred_np = (imgt_pred[5].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 38 | img7_pred_np = (imgt_pred[6].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8) 39 | 40 | images = [img0_np, img1_pred_np, img2_pred_np, img3_pred_np, img4_pred_np, img5_pred_np, img6_pred_np, img7_pred_np, img8_np] 41 | mimsave('./figures/out_8x.gif', images, fps=9) 42 | -------------------------------------------------------------------------------- /figures/8x_interpolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/8x_interpolation.png -------------------------------------------------------------------------------- /figures/benchmarks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/benchmarks.png -------------------------------------------------------------------------------- /figures/fig1_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig1_1.gif -------------------------------------------------------------------------------- /figures/fig1_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig1_2.gif -------------------------------------------------------------------------------- /figures/fig1_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig1_3.gif -------------------------------------------------------------------------------- /figures/fig2_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig2_1.gif -------------------------------------------------------------------------------- /figures/fig2_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig2_2.gif -------------------------------------------------------------------------------- /figures/img0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/img0.png -------------------------------------------------------------------------------- /figures/img1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/img1.png -------------------------------------------------------------------------------- /figures/img_overlaid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/img_overlaid.png -------------------------------------------------------------------------------- /figures/middlebury.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/middlebury.png -------------------------------------------------------------------------------- /figures/middlebury_other.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/middlebury_other.png -------------------------------------------------------------------------------- /figures/out_2x.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/out_2x.gif -------------------------------------------------------------------------------- /figures/out_8x.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/out_8x.gif -------------------------------------------------------------------------------- /figures/vimeo90k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/vimeo90k.png -------------------------------------------------------------------------------- /generate_flow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from liteflownet.run import estimate 6 | from utils import read, write 7 | 8 | 9 | # set vimeo90k_dir with your Vimeo90K triplet dataset path, like '/.../vimeo_triplet' 10 | vimeo90k_dir = '/home/ltkong/Datasets/Vimeo90K/vimeo_triplet' 11 | 12 | vimeo90k_sequences_dir = os.path.join(vimeo90k_dir, 'sequences') 13 | vimeo90k_flow_dir = os.path.join(vimeo90k_dir, 'flow') 14 | 15 | if not os.path.exists(vimeo90k_flow_dir): 16 | os.makedirs(vimeo90k_flow_dir) 17 | 18 | for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)): 19 | vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path) 20 | vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path) 21 | if not os.path.exists(vimeo90k_flow_path_dir): 22 | os.mkdir(vimeo90k_flow_path_dir) 23 | 24 | for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): 25 | vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id) 26 | if not os.path.exists(vimeo90k_flow_id_dir): 27 | os.mkdir(vimeo90k_flow_id_dir) 28 | 29 | print('Built Flow Path') 30 | 31 | 32 | def pred_flow(img1, img2): 33 | img1 = torch.from_numpy(img1).float().permute(2, 0, 1) / 255.0 34 | img2 = torch.from_numpy(img2).float().permute(2, 0, 1) / 255.0 35 | 36 | flow = estimate(img1, img2) 37 | 38 | flow = flow.permute(1, 2, 0).cpu().numpy() 39 | return flow 40 | 
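# Second pass over the dataset: estimate optical flow from the middle frame to
# each endpoint with LiteFlowNet and cache the results as .flo pseudo labels.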
41 | for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)): 42 | vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path) 43 | vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path) 44 | 45 | for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): 46 | vimeo90k_sequences_id_dir = os.path.join(vimeo90k_sequences_path_dir, sequences_id) 47 | vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id) 48 | 49 | img0_path = vimeo90k_sequences_id_dir + '/im1.png' 50 | imgt_path = vimeo90k_sequences_id_dir + '/im2.png' 51 | img1_path = vimeo90k_sequences_id_dir + '/im3.png' 52 | flow_t0_path = vimeo90k_flow_id_dir + '/flow_t0.flo' 53 | flow_t1_path = vimeo90k_flow_id_dir + '/flow_t1.flo' 54 | 55 | img0 = read(img0_path) 56 | imgt = read(imgt_path) 57 | img1 = read(img1_path) 58 | 59 | flow_t0 = pred_flow(imgt, img0) 60 | flow_t1 = pred_flow(imgt, img1) 61 | 62 | write(flow_t0_path, flow_t0) 63 | write(flow_t1_path, flow_t1) 64 | 65 | print('Written Sequences {}'.format(sequences_path)) 66 | -------------------------------------------------------------------------------- /liteflownet/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-liteflownet 2 | This is a personal reimplementation of LiteFlowNet [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2]. 3 | 4 | 5 | 6 | For the original Caffe version of this work, please see: https://github.com/twhui/LiteFlowNet 7 | 
8 | Other optical flow implementations from me: [pytorch-pwc](https://github.com/sniklaus/pytorch-pwc), [pytorch-unflow](https://github.com/sniklaus/pytorch-unflow), [pytorch-spynet](https://github.com/sniklaus/pytorch-spynet) 9 | 10 | ## setup 11 | The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository. If you would like to use Docker, you can take a look at [this](https://github.com/sniklaus/pytorch-liteflownet/pull/43) pull request to get started. 12 | 13 | ## usage 14 | To run it on your own pair of images, use the following command. You can choose between three models, please make sure to see their paper / the code for more details. 15 | 16 | ``` 17 | python run.py --model default --one ./images/one.png --two ./images/two.png --out ./out.flo 18 | ``` 19 | 20 | I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results pretty much identical to the implementation of the original authors in the examples that I tried. There are some numerical deviations that stem from differences in the `DownsampleLayer` of Caffe and the `torch.nn.functional.interpolate` function of PyTorch. Please feel free to contribute to this repository by submitting issues and pull requests. 21 | 22 | ## comparison 23 |


24 | 25 | ## license 26 | As stated in the licensing terms of the authors of the paper, their material is provided for research purposes only. Please make sure to further consult their licensing terms. 27 | 28 | ## references 29 | ``` 30 | [1] @inproceedings{Hui_CVPR_2018, 31 | author = {Tak-Wai Hui and Xiaoou Tang and Chen Change Loy}, 32 | title = {{LiteFlowNet}: A Lightweight Convolutional Neural Network for Optical Flow Estimation}, 33 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, 34 | year = {2018} 35 | } 36 | ``` 37 | 38 | ``` 39 | [2] @misc{pytorch-liteflownet, 40 | author = {Simon Niklaus}, 41 | title = {A Reimplementation of {LiteFlowNet} Using {PyTorch}}, 42 | year = {2019}, 43 | howpublished = {\url{https://github.com/sniklaus/pytorch-liteflownet}} 44 | } 45 | ``` -------------------------------------------------------------------------------- /liteflownet/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. -------------------------------------------------------------------------------- /liteflownet/correlation/correlation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import cupy 4 | import math 5 | import re 6 | import torch 7 | 8 | kernel_Correlation_rearrange = ''' 9 | extern "C" __global__ void kernel_Correlation_rearrange( 10 | const int n, 11 | const float* input, 12 | float* output 13 | ) { 14 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 15 | 16 | if (intIndex >= n) { 17 | return; 18 | } 19 | 20 | int intSample = blockIdx.z; 21 | int intChannel = blockIdx.y; 22 | 23 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 24 | 25 | __syncthreads(); 26 | 27 | int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}}; 28 | int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}}; 29 | int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX; 30 | 31 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; 32 | } 33 | ''' 34 | 35 | kernel_Correlation_updateOutput = ''' 36 | extern "C" __global__ void kernel_Correlation_updateOutput( 37 | const int n, 38 | const float* rbot0, 39 | const float* rbot1, 40 | float* top 41 | ) { 42 | extern __shared__ char patch_data_char[]; 43 | 44 | float *patch_data = (float *)patch_data_char; 45 | 46 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 47 | int x1 = (blockIdx.x + 3) * {{intStride}}; 48 | int y1 = (blockIdx.y + 3) * {{intStride}}; 49 | int item = blockIdx.z; 50 | int ch_off = threadIdx.x; 51 | 52 | // Load 3D patch into shared shared memory 53 | for (int j = 0; j < 1; j++) { // HEIGHT 54 | for (int i = 0; i < 1; i++) { // WIDTH 55 | int ji_off = (j + i) * SIZE_3(rbot0); 56 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 57 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 58 | int idxPatchData = ji_off + ch; 59 | patch_data[idxPatchData] = rbot0[idx1]; 60 | } 61 | } 62 | } 63 | 64 | 
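// Each block handles one output position (blockIdx.x/y) of one sample (blockIdx.z);
// its 32 threads cooperatively staged the shared-memory patch above, so
// synchronize before any thread reads it back in the correlation loop.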
__syncthreads(); 65 | 66 | __shared__ float sum[32]; 67 | 68 | // Compute correlation 69 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 70 | sum[ch_off] = 0; 71 | 72 | int s2o = (top_channel % 7 - 3) * {{intStride}}; 73 | int s2p = (top_channel / 7 - 3) * {{intStride}}; 74 | 75 | for (int j = 0; j < 1; j++) { // HEIGHT 76 | for (int i = 0; i < 1; i++) { // WIDTH 77 | int ji_off = (j + i) * SIZE_3(rbot0); 78 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 79 | int x2 = x1 + s2o; 80 | int y2 = y1 + s2p; 81 | 82 | int idxPatchData = ji_off + ch; 83 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 84 | 85 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 86 | } 87 | } 88 | } 89 | 90 | __syncthreads(); 91 | 92 | if (ch_off == 0) { 93 | float total_sum = 0; 94 | for (int idx = 0; idx < 32; idx++) { 95 | total_sum += sum[idx]; 96 | } 97 | const int sumelems = SIZE_3(rbot0); 98 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 99 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 100 | } 101 | } 102 | } 103 | ''' 104 | 105 | kernel_Correlation_updateGradOne = ''' 106 | #define ROUND_OFF 50000 107 | 108 | extern "C" __global__ void kernel_Correlation_updateGradOne( 109 | const int n, 110 | const int intSample, 111 | const float* rbot0, 112 | const float* rbot1, 113 | const float* gradOutput, 114 | float* gradOne, 115 | float* gradTwo 116 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 117 | int n = intIndex % SIZE_1(gradOne); // channels 118 | int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 3*{{intStride}}; // w-pos 119 | int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 3*{{intStride}}; // h-pos 120 | 121 | // round_off is a trick to enable integer division with ceil, even for negative numbers 122 | // We use a large offset, for the inner part not to become negative. 
123 | const int round_off = ROUND_OFF; 124 | const int round_off_s1 = {{intStride}} * round_off; 125 | 126 | // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 127 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 128 | int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (m - 3*{{intStride}}) / {{intStride}} 129 | 130 | // Same here: 131 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}} 132 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}} 133 | 134 | float sum = 0; 135 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 136 | xmin = max(0,xmin); 137 | xmax = min(SIZE_3(gradOutput)-1,xmax); 138 | 139 | ymin = max(0,ymin); 140 | ymax = min(SIZE_2(gradOutput)-1,ymax); 141 | 142 | for (int p = -3; p <= 3; p++) { 143 | for (int o = -3; o <= 3; o++) { 144 | // Get rbot1 data: 145 | int s2o = {{intStride}} * o; 146 | int s2p = {{intStride}} * p; 147 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; 148 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] 149 | 150 | // Index offset for gradOutput in following loops: 151 | int op = (p+3) * 7 + (o+3); // index[o,p] 152 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 153 | 154 | for (int y = ymin; y <= ymax; y++) { 155 | for (int x = xmin; x <= xmax; x++) { 156 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 157 | sum += gradOutput[idxgradOutput] * bot1tmp; 158 | } 159 | } 160 | } 161 | } 162 | } 163 | const int sumelems = SIZE_1(gradOne); 164 | const int bot0index = ((n * SIZE_2(gradOne)) + (m-3*{{intStride}})) * SIZE_3(gradOne) + (l-3*{{intStride}}); 165 | gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems; 166 | } } 167 | ''' 168 | 169 | kernel_Correlation_updateGradTwo = ''' 170 | #define ROUND_OFF 50000 171 | 172 | extern "C" __global__ void kernel_Correlation_updateGradTwo( 173 | const int n, 174 | const int intSample, 175 | const float* rbot0, 176 | const float* rbot1, 177 | const float* gradOutput, 178 | float* gradOne, 179 | float* gradTwo 180 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 181 | int n = intIndex % SIZE_1(gradTwo); // channels 182 | int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 3*{{intStride}}; // w-pos 183 | int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 3*{{intStride}}; // h-pos 184 | 185 | // round_off is a trick to enable integer division with ceil, even for negative numbers 186 | // We use a large offset, for the inner part not to become negative. 
187 | const int round_off = ROUND_OFF; 188 | const int round_off_s1 = {{intStride}} * round_off; 189 | 190 | float sum = 0; 191 | for (int p = -3; p <= 3; p++) { 192 | for (int o = -3; o <= 3; o++) { 193 | int s2o = {{intStride}} * o; 194 | int s2p = {{intStride}} * p; 195 | 196 | // Get X,Y ranges and clamp 197 | // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 198 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 199 | int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (m - 3*{{intStride}} - s2p) / {{intStride}} 200 | 201 | // Same here: 202 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}} 203 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}} 204 | 205 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 206 | xmin = max(0,xmin); 207 | xmax = min(SIZE_3(gradOutput)-1,xmax); 208 | 209 | ymin = max(0,ymin); 210 | ymax = min(SIZE_2(gradOutput)-1,ymax); 211 | 212 | // Get rbot0 data: 213 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; 214 | float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n] 215 | 216 | // Index offset for gradOutput in following loops: 217 | int op = (p+3) * 7 + (o+3); // index[o,p] 218 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 219 | 220 | for (int y = ymin; y <= ymax; y++) { 221 | for (int x = xmin; x <= xmax; x++) { 222 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 223 | sum += gradOutput[idxgradOutput] * bot0tmp; 224 | } 225 | } 226 | } 227 | } 228 | } 229 | const int sumelems = SIZE_1(gradTwo); 230 | const int bot1index = ((n * SIZE_2(gradTwo)) + (m-3*{{intStride}})) * SIZE_3(gradTwo) + (l-3*{{intStride}}); 231 | gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems; 232 | } } 233 | ''' 234 | 235 | def cupy_kernel(strFunction, objVariables): 236 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride'])) 237 | 238 | while True: 239 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 240 | 241 | if objMatch is None: 242 | break 243 | # end 244 | 245 | intArg = int(objMatch.group(2)) 246 | 247 | strTensor = objMatch.group(4) 248 | intSizes = objVariables[strTensor].size() 249 | 250 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) 251 | # end 252 | 253 | while True: 254 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 255 | 256 | if objMatch is None: 257 | break 258 | # end 259 | 260 | intArgs = int(objMatch.group(2)) 261 | strArgs = objMatch.group(4).split(',') 262 | 263 | strTensor = strArgs[0] 264 | intStrides = objVariables[strTensor].stride() 265 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ] 266 | 267 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 268 | 
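# Every SIZE_n(tensor) macro above is textually replaced by the concrete dimension, and each VALUE_n(tensor, ...) by a direct strided index into the tensor's storage, so the kernel source handed to the compiler is plain CUDA C.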
# end 269 | 270 | return strKernel 271 | # end 272 | 273 | @cupy.memoize(for_each_device=True) 274 | def cupy_launch(strFunction, strKernel): 275 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 276 | # end 277 | 278 | class _FunctionCorrelation(torch.autograd.Function): 279 | @staticmethod 280 | def forward(self, one, two, intStride): 281 | rbot0 = one.new_zeros([ one.shape[0], one.shape[2] + (6 * intStride), one.shape[3] + (6 * intStride), one.shape[1] ]) 282 | rbot1 = one.new_zeros([ one.shape[0], one.shape[2] + (6 * intStride), one.shape[3] + (6 * intStride), one.shape[1] ]) 283 | 284 | self.intStride = intStride 285 | 286 | one = one.contiguous(); assert(one.is_cuda == True) 287 | two = two.contiguous(); assert(two.is_cuda == True) 288 | 289 | output = one.new_zeros([ one.shape[0], 49, int(math.ceil(one.shape[2] / intStride)), int(math.ceil(one.shape[3] / intStride)) ]) 290 | 291 | if one.is_cuda == True: 292 | n = one.shape[2] * one.shape[3] 293 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 294 | 'intStride': self.intStride, 295 | 'input': one, 296 | 'output': rbot0 297 | }))( 298 | grid=tuple([ int((n + 16 - 1) / 16), one.shape[1], one.shape[0] ]), 299 | block=tuple([ 16, 1, 1 ]), 300 | args=[ cupy.int32(n), one.data_ptr(), rbot0.data_ptr() ] 301 | ) 302 | 303 | n = two.shape[2] * two.shape[3] 304 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 305 | 'intStride': self.intStride, 306 | 'input': two, 307 | 'output': rbot1 308 | }))( 309 | grid=tuple([ int((n + 16 - 1) / 16), two.shape[1], two.shape[0] ]), 310 | block=tuple([ 16, 1, 1 ]), 311 | args=[ cupy.int32(n), two.data_ptr(), rbot1.data_ptr() ] 312 | ) 313 | 314 | n = output.shape[1] * output.shape[2] * output.shape[3] 315 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 316 | 'intStride': self.intStride, 317 | 'rbot0': rbot0, 318 | 'rbot1': rbot1, 319 | 'top': output 320 | }))( 321 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), 322 | block=tuple([ 32, 1, 1 ]), 323 | shared_mem=one.shape[1] * 4, 324 | args=[ cupy.int32(n), rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] 325 | ) 326 | 327 | elif one.is_cuda == False: 328 | raise NotImplementedError() 329 | 330 | # end 331 | 332 | self.save_for_backward(one, two, rbot0, rbot1) 333 | 334 | return output 335 | # end 336 | 337 | @staticmethod 338 | def backward(self, gradOutput): 339 | one, two, rbot0, rbot1 = self.saved_tensors 340 | 341 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) 342 | 343 | gradOne = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[0] == True else None 344 | gradTwo = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[1] == True else None 345 | 346 | if one.is_cuda == True: 347 | if gradOne is not None: 348 | for intSample in range(one.shape[0]): 349 | n = one.shape[1] * one.shape[2] * one.shape[3] 350 | cupy_launch('kernel_Correlation_updateGradOne', cupy_kernel('kernel_Correlation_updateGradOne', { 351 | 'intStride': self.intStride, 352 | 'rbot0': rbot0, 353 | 'rbot1': rbot1, 354 | 'gradOutput': gradOutput, 355 | 'gradOne': gradOne, 356 | 'gradTwo': None 357 | }))( 358 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 359 | block=tuple([ 512, 1, 1 ]), 360 | args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), 
gradOutput.data_ptr(), gradOne.data_ptr(), None ] 361 | ) 362 | # end 363 | # end 364 | 365 | if gradTwo is not None: 366 | for intSample in range(one.shape[0]): 367 | n = one.shape[1] * one.shape[2] * one.shape[3] 368 | cupy_launch('kernel_Correlation_updateGradTwo', cupy_kernel('kernel_Correlation_updateGradTwo', { 369 | 'intStride': self.intStride, 370 | 'rbot0': rbot0, 371 | 'rbot1': rbot1, 372 | 'gradOutput': gradOutput, 373 | 'gradOne': None, 374 | 'gradTwo': gradTwo 375 | }))( 376 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 377 | block=tuple([ 512, 1, 1 ]), 378 | args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradTwo.data_ptr() ] 379 | ) 380 | # end 381 | # end 382 | 383 | elif one.is_cuda == False: 384 | raise NotImplementedError() 385 | 386 | # end 387 | 388 | return gradOne, gradTwo, None 389 | # end 390 | # end 391 | 392 | def FunctionCorrelation(tenOne, tenTwo, intStride): 393 | return _FunctionCorrelation.apply(tenOne, tenTwo, intStride) 394 | # end 395 | 396 | class ModuleCorrelation(torch.nn.Module): 397 | def __init__(self): 398 | super().__init__() 399 | # end 400 | 401 | def forward(self, tenOne, tenTwo, intStride): 402 | return _FunctionCorrelation.apply(tenOne, tenTwo, intStride) 403 | # end 404 | # end -------------------------------------------------------------------------------- /liteflownet/requirements.txt: -------------------------------------------------------------------------------- 1 | cupy>=5.0.0 2 | numpy>=1.15.0 3 | Pillow>=5.0.0 4 | torch>=1.3.0 -------------------------------------------------------------------------------- /liteflownet/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import getopt 4 | import math 5 | import numpy 6 | import PIL 7 | import PIL.Image 8 | import sys 9 | import torch 10 | 11 | try: 12 | from .correlation import correlation # the custom cost volume layer 13 | except: 14 | sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python 15 | # end 16 | 17 | ########################################################## 18 | 19 | assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 20 | 21 | torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance 22 | 23 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance 24 | 25 | ########################################################## 26 | 27 | arguments_strModel = 'default' # 'default', or 'kitti', or 'sintel' 28 | arguments_strOne = './images/one.png' 29 | arguments_strTwo = './images/two.png' 30 | arguments_strOut = './out.flo' 31 | 32 | for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: 33 | if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use 34 | if strOption == '--one' and strArgument != '': arguments_strOne = strArgument # path to the first frame 35 | if strOption == '--two' and strArgument != '': arguments_strTwo = strArgument # path to the second frame 36 | if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored 37 | # end 38 | 39 | ########################################################## 40 | 41 | backwarp_tenGrid = {} 42 | 43 | def backwarp(tenInput, tenFlow): 44 | if str(tenFlow.shape) not in 
backwarp_tenGrid: 45 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) 46 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) 47 | 48 | backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() 49 | # end 50 | 51 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 52 | 53 | return torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) 54 | # end 55 | 56 | ########################################################## 57 | 58 | class Network(torch.nn.Module): 59 | def __init__(self): 60 | super().__init__() 61 | 62 | class Features(torch.nn.Module): 63 | def __init__(self): 64 | super().__init__() 65 | 66 | self.netOne = torch.nn.Sequential( 67 | torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=7, stride=1, padding=3), 68 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 69 | ) 70 | 71 | self.netTwo = torch.nn.Sequential( 72 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1), 73 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 74 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 75 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 76 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 77 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 78 | ) 79 | 80 | self.netThr = torch.nn.Sequential( 81 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 82 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 83 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 84 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 85 | ) 86 | 87 | self.netFou = torch.nn.Sequential( 88 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), 89 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 90 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 91 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 92 | ) 93 | 94 | self.netFiv = torch.nn.Sequential( 95 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), 96 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 97 | ) 98 | 99 | self.netSix = torch.nn.Sequential( 100 | torch.nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=2, padding=1), 101 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 102 | ) 103 | # end 104 | 105 | def forward(self, tenInput): 106 | tenOne = self.netOne(tenInput) 107 | tenTwo = self.netTwo(tenOne) 108 | tenThr = self.netThr(tenTwo) 109 | tenFou = self.netFou(tenThr) 110 | tenFiv = self.netFiv(tenFou) 111 | tenSix = self.netSix(tenFiv) 112 | 113 | return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] 114 | # end 115 | # end 116 | 117 | class Matching(torch.nn.Module): 118 | def __init__(self, intLevel): 119 | super().__init__() 120 | 121 | self.fltBackwarp = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] 122 | 123 | if intLevel != 2: 124 | self.netFeat = torch.nn.Sequential() 125 | 126 | elif intLevel == 2: 127 | 
self.netFeat = torch.nn.Sequential( 128 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0), 129 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 130 | ) 131 | 132 | # end 133 | 134 | if intLevel == 6: 135 | self.netUpflow = None 136 | 137 | elif intLevel != 6: 138 | self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1, bias=False, groups=2) 139 | 140 | # end 141 | 142 | if intLevel >= 4: 143 | self.netUpcorr = None 144 | 145 | elif intLevel < 4: 146 | self.netUpcorr = torch.nn.ConvTranspose2d(in_channels=49, out_channels=49, kernel_size=4, stride=2, padding=1, bias=False, groups=49) 147 | 148 | # end 149 | 150 | self.netMain = torch.nn.Sequential( 151 | torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1), 152 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 153 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 154 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 155 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 156 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 157 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) 158 | ) 159 | # end 160 | 161 | def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow): 162 | tenFeaturesOne = self.netFeat(tenFeaturesOne) 163 | tenFeaturesTwo = self.netFeat(tenFeaturesTwo) 164 | 165 | if tenFlow is not None: 166 | tenFlow = self.netUpflow(tenFlow) 167 | # end 168 | 169 | if tenFlow is not None: 170 | tenFeaturesTwo = backwarp(tenInput=tenFeaturesTwo, tenFlow=tenFlow * self.fltBackwarp) 171 | # end 172 | 173 | if self.netUpcorr is None: 174 | tenCorrelation = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenFeaturesOne, tenTwo=tenFeaturesTwo, intStride=1), negative_slope=0.1, inplace=False) 175 | 176 | elif self.netUpcorr is not None: 177 | tenCorrelation = self.netUpcorr(torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenFeaturesOne, tenTwo=tenFeaturesTwo, intStride=2), negative_slope=0.1, inplace=False)) 178 | 179 | # end 180 | 181 | return (tenFlow if tenFlow is not None else 0.0) + self.netMain(tenCorrelation) 182 | # end 183 | # end 184 | 185 | class Subpixel(torch.nn.Module): 186 | def __init__(self, intLevel): 187 | super().__init__() 188 | 189 | self.fltBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] 190 | 191 | if intLevel != 2: 192 | self.netFeat = torch.nn.Sequential() 193 | 194 | elif intLevel == 2: 195 | self.netFeat = torch.nn.Sequential( 196 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0), 197 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 198 | ) 199 | 200 | # end 201 | 202 | self.netMain = torch.nn.Sequential( 203 | torch.nn.Conv2d(in_channels=[ 0, 0, 130, 130, 194, 258, 386 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1), 204 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 205 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 206 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 207 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 208 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 209 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 
0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) 210 | ) 211 | # end 212 | 213 | def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow): 214 | tenFeaturesOne = self.netFeat(tenFeaturesOne) 215 | tenFeaturesTwo = self.netFeat(tenFeaturesTwo) 216 | 217 | if tenFlow is not None: 218 | tenFeaturesTwo = backwarp(tenInput=tenFeaturesTwo, tenFlow=tenFlow * self.fltBackward) 219 | # end 220 | 221 | return (tenFlow if tenFlow is not None else 0.0) + self.netMain(torch.cat([ tenFeaturesOne, tenFeaturesTwo, tenFlow ], 1)) 222 | # end 223 | # end 224 | 225 | class Regularization(torch.nn.Module): 226 | def __init__(self, intLevel): 227 | super().__init__() 228 | 229 | self.fltBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] 230 | 231 | self.intUnfold = [ 0, 0, 7, 5, 5, 3, 3 ][intLevel] 232 | 233 | if intLevel >= 5: 234 | self.netFeat = torch.nn.Sequential() 235 | 236 | elif intLevel < 5: 237 | self.netFeat = torch.nn.Sequential( 238 | torch.nn.Conv2d(in_channels=[ 0, 0, 32, 64, 96, 128, 192 ][intLevel], out_channels=128, kernel_size=1, stride=1, padding=0), 239 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 240 | ) 241 | 242 | # end 243 | 244 | self.netMain = torch.nn.Sequential( 245 | torch.nn.Conv2d(in_channels=[ 0, 0, 131, 131, 131, 131, 195 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1), 246 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 247 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 248 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 249 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 250 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 251 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 252 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 253 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 254 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 255 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 256 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 257 | ) 258 | 259 | if intLevel >= 5: 260 | self.netDist = torch.nn.Sequential( 261 | torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) 262 | ) 263 | 264 | elif intLevel < 5: 265 | self.netDist = torch.nn.Sequential( 266 | torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=([ 0, 0, 7, 5, 5, 3, 3 ][intLevel], 1), stride=1, padding=([ 0, 0, 3, 2, 2, 1, 1 ][intLevel], 0)), 267 | torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=(1, [ 0, 0, 7, 5, 5, 3, 3 ][intLevel]), stride=1, padding=(0, [ 0, 0, 3, 2, 2, 1, 1 ][intLevel])) 268 | ) 269 | 270 | # end 271 | 272 | self.netScaleX = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0) 273 | self.netScaleY = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0) 274 | # end 275 | 276 | def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow): 277 | tenDifference = ((tenOne - backwarp(tenInput=tenTwo, tenFlow=tenFlow * self.fltBackward)) ** 2).sum(1, True).sqrt().detach() 
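# tenDifference is the per-pixel brightness error between frame one and the flow-warped frame two; it is detached so the distance-based weighting computed below does not backpropagate through it.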
278 | 279 | tenDist = self.netDist(self.netMain(torch.cat([ tenDifference, tenFlow - tenFlow.view(tenFlow.shape[0], 2, -1).mean(2, True).view(tenFlow.shape[0], 2, 1, 1), self.netFeat(tenFeaturesOne) ], 1))) 280 | tenDist = (tenDist ** 2).neg() 281 | tenDist = (tenDist - tenDist.max(1, True)[0]).exp() 282 | 283 | tenDivisor = tenDist.sum(1, True).reciprocal() 284 | 285 | tenScaleX = self.netScaleX(tenDist * torch.nn.functional.unfold(input=tenFlow[:, 0:1, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tenDist)) * tenDivisor 286 | tenScaleY = self.netScaleY(tenDist * torch.nn.functional.unfold(input=tenFlow[:, 1:2, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tenDist)) * tenDivisor 287 | 288 | return torch.cat([ tenScaleX, tenScaleY ], 1) 289 | # end 290 | # end 291 | 292 | self.netFeatures = Features() 293 | self.netMatching = torch.nn.ModuleList([ Matching(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) 294 | self.netSubpixel = torch.nn.ModuleList([ Subpixel(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) 295 | self.netRegularization = torch.nn.ModuleList([ Regularization(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) 296 | 297 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-liteflownet/network-' + arguments_strModel + '.pytorch').items() }) 298 | # self.load_state_dict(torch.load('./liteflownet/network-default.pth')) 299 | # end 300 | 301 | def forward(self, tenOne, tenTwo): 302 | tenOne[:, 0, :, :] = tenOne[:, 0, :, :] - 0.411618 303 | tenOne[:, 1, :, :] = tenOne[:, 1, :, :] - 0.434631 304 | tenOne[:, 2, :, :] = tenOne[:, 2, :, :] - 0.454253 305 | 306 | tenTwo[:, 0, :, :] = tenTwo[:, 0, :, :] - 0.410782 307 | tenTwo[:, 1, :, :] = tenTwo[:, 1, :, :] - 0.433645 308 | tenTwo[:, 2, :, :] = tenTwo[:, 2, :, :] - 0.452793 309 | 310 | tenFeaturesOne = self.netFeatures(tenOne) 311 | tenFeaturesTwo = self.netFeatures(tenTwo) 312 | 313 | tenOne = [ tenOne ] 314 | tenTwo = [ tenTwo ] 315 | 316 | for intLevel in [ 1, 2, 3, 4, 5 ]: 317 | tenOne.append(torch.nn.functional.interpolate(input=tenOne[-1], size=(tenFeaturesOne[intLevel].shape[2], tenFeaturesOne[intLevel].shape[3]), mode='bilinear', align_corners=False)) 318 | tenTwo.append(torch.nn.functional.interpolate(input=tenTwo[-1], size=(tenFeaturesTwo[intLevel].shape[2], tenFeaturesTwo[intLevel].shape[3]), mode='bilinear', align_corners=False)) 319 | # end 320 | 321 | tenFlow = None 322 | 323 | for intLevel in [ -1, -2, -3, -4, -5 ]: 324 | tenFlow = self.netMatching[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow) 325 | tenFlow = self.netSubpixel[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow) 326 | tenFlow = self.netRegularization[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow) 327 | # end 328 | 329 | return tenFlow * 20.0 330 | # end 331 | # end 332 | 333 | netNetwork = None 334 | 335 | ########################################################## 336 | 337 | def estimate(tenOne, tenTwo): 338 | global netNetwork 339 | 340 | if netNetwork is None: 341 | netNetwork = Network().cuda().eval() 342 | # end 343 | 344 | assert(tenOne.shape[1] == tenTwo.shape[1]) 345 | assert(tenOne.shape[2] == tenTwo.shape[2]) 346 | 347 | intWidth = tenOne.shape[2] 348 | intHeight = 
tenOne.shape[1] 349 | 350 | # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 351 | # assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 352 | 353 | tenPreprocessedOne = tenOne.cuda().view(1, 3, intHeight, intWidth) 354 | tenPreprocessedTwo = tenTwo.cuda().view(1, 3, intHeight, intWidth) 355 | 356 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0)) 357 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0)) 358 | 359 | tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 360 | tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 361 | 362 | tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False) 363 | 364 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 365 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 366 | 367 | return tenFlow[0, :, :, :].cpu() 368 | # end 369 | 370 | ########################################################## 371 | 372 | if __name__ == '__main__': 373 | tenOne = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strOne))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) 374 | tenTwo = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strTwo))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) 375 | 376 | tenOutput = estimate(tenOne, tenTwo) 377 | 378 | objOutput = open(arguments_strOut, 'wb') 379 | 380 | numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) 381 | numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput) 382 | numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput) 383 | 384 | objOutput.close() 385 | # end 386 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | 8 | 9 | class Ternary(nn.Module): 10 | def __init__(self, patch_size=7): 11 | super(Ternary, self).__init__() 12 | self.patch_size = patch_size 13 | out_channels = patch_size * patch_size 14 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) 15 | self.w = np.transpose(self.w, (3, 2, 0, 1)) 16 | self.w = torch.tensor(self.w).float().to(device) 17 | 18 | def transform(self, tensor): 19 | tensor_ = tensor.mean(dim=1, keepdim=True) 20 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size//2, bias=None) 21 | loc_diff = patches - tensor_ 22 | loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff ** 2) 23 | return loc_diff_norm 24 | 25 | def valid_mask(self, tensor): 26 | padding = self.patch_size//2 27 | b, c, h, w = tensor.size() 28 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor) 29 | mask = F.pad(inner, [padding] * 4) 30 | return mask 31 | 32 | def 
forward(self, x, y): 33 | loc_diff_x = self.transform(x) 34 | loc_diff_y = self.transform(y) 35 | diff = loc_diff_x - loc_diff_y.detach() 36 | dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True) 37 | mask = self.valid_mask(x) 38 | loss = (dist * mask).mean() 39 | return loss 40 | 41 | 42 | class Geometry(nn.Module): 43 | def __init__(self, patch_size=3): 44 | super(Geometry, self).__init__() 45 | self.patch_size = patch_size 46 | out_channels = patch_size * patch_size 47 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) 48 | self.w = np.transpose(self.w, (3, 2, 0, 1)) 49 | self.w = torch.tensor(self.w).float().to(device) 50 | 51 | def transform(self, tensor): 52 | b, c, h, w = tensor.size() 53 | tensor_ = tensor.reshape(b*c, 1, h, w) 54 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size//2, bias=None) 55 | loc_diff = patches - tensor_ 56 | loc_diff_ = loc_diff.reshape(b, c*(self.patch_size**2), h, w) 57 | loc_diff_norm = loc_diff_ / torch.sqrt(0.81 + loc_diff_ ** 2) 58 | return loc_diff_norm 59 | 60 | def valid_mask(self, tensor): 61 | padding = self.patch_size//2 62 | b, c, h, w = tensor.size() 63 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor) 64 | mask = F.pad(inner, [padding] * 4) 65 | return mask 66 | 67 | def forward(self, x, y): 68 | loc_diff_x = self.transform(x) 69 | loc_diff_y = self.transform(y) 70 | diff = loc_diff_x - loc_diff_y 71 | dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True) 72 | mask = self.valid_mask(x) 73 | loss = (dist * mask).mean() 74 | return loss 75 | 76 | 77 | class Charbonnier_L1(nn.Module): 78 | def __init__(self): 79 | super(Charbonnier_L1, self).__init__() 80 | 81 | def forward(self, diff, mask=None): 82 | if mask is None: 83 | loss = ((diff ** 2 + 1e-6) ** 0.5).mean() 84 | else: 85 | loss = (((diff ** 2 + 1e-6) ** 0.5) * mask).mean() / (mask.mean() + 1e-9) 86 | return loss 87 | 88 | 89 | class Charbonnier_Ada(nn.Module): 90 | def __init__(self): 91 | super(Charbonnier_Ada, self).__init__() 92 | 93 | def forward(self, diff, weight): 94 | alpha = weight / 2 95 | epsilon = 10 ** (-(10 * weight - 1) / 3) 96 | loss = ((diff ** 2 + epsilon ** 2) ** alpha).mean() 97 | return loss 98 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from math import exp 5 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 11 | return gauss/gauss.sum() 12 | 13 | 14 | def create_window(window_size, channel=1): 15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) 17 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 18 | return window 19 | 20 | 21 | def create_window_3d(window_size, channel=1): 22 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 23 | _2D_window = _1D_window.mm(_1D_window.t()) 24 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) 25 | window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) 26 | return window 27 | 28 | 29 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, 
val_range=None): 30 | if val_range is None: 31 | if torch.max(img1) > 128: 32 | max_val = 255 33 | else: 34 | max_val = 1 35 | 36 | if torch.min(img1) < -0.5: 37 | min_val = -1 38 | else: 39 | min_val = 0 40 | L = max_val - min_val 41 | else: 42 | L = val_range 43 | 44 | padd = 0 45 | (_, channel, height, width) = img1.size() 46 | if window is None: 47 | real_size = min(window_size, height, width) 48 | window = create_window(real_size, channel=channel).to(img1.device) 49 | 50 | mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) 51 | mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) 52 | 53 | mu1_sq = mu1.pow(2) 54 | mu2_sq = mu2.pow(2) 55 | mu1_mu2 = mu1 * mu2 56 | 57 | sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq 58 | sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq 59 | sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2 60 | 61 | C1 = (0.01 * L) ** 2 62 | C2 = (0.03 * L) ** 2 63 | 64 | v1 = 2.0 * sigma12 + C2 65 | v2 = sigma1_sq + sigma2_sq + C2 66 | cs = torch.mean(v1 / v2) 67 | 68 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 69 | 70 | if size_average: 71 | ret = ssim_map.mean() 72 | else: 73 | ret = ssim_map.mean(1).mean(1).mean(1) 74 | 75 | if full: 76 | return ret, cs 77 | return ret 78 | 79 | 80 | def calculate_ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 81 | if val_range is None: 82 | if torch.max(img1) > 128: 83 | max_val = 255 84 | else: 85 | max_val = 1 86 | 87 | if torch.min(img1) < -0.5: 88 | min_val = -1 89 | else: 90 | min_val = 0 91 | L = max_val - min_val 92 | else: 93 | L = val_range 94 | 95 | padd = 0 96 | (_, _, height, width) = img1.size() 97 | if window is None: 98 | real_size = min(window_size, height, width) 99 | window = create_window_3d(real_size, channel=1).to(img1.device) 100 | 101 | img1 = img1.unsqueeze(1) 102 | img2 = img2.unsqueeze(1) 103 | 104 | mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) 105 | mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) 106 | 107 | mu1_sq = mu1.pow(2) 108 | mu2_sq = mu2.pow(2) 109 | mu1_mu2 = mu1 * mu2 110 | 111 | sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq 112 | sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq 113 | sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 114 | 115 | C1 = (0.01 * L) ** 2 116 | C2 = (0.03 * L) ** 2 117 | 118 | v1 = 2.0 * sigma12 + C2 119 | v2 = sigma1_sq + sigma2_sq + C2 120 | cs = torch.mean(v1 / v2) 121 | 122 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 123 | 124 | if size_average: 125 | ret = ssim_map.mean() 126 | else: 127 | ret = ssim_map.mean(1).mean(1).mean(1) 128 | 129 | if full: 130 | return ret, cs 131 | return ret 132 | 133 | 134 | def calculate_psnr(img1, img2): 135 | psnr = -10 * torch.log10(((img1 - img2) * (img1 - img2)).mean()) 136 | return psnr 137 | 138 | 139 | def calculate_ie(img1, img2): 140 | ie = torch.abs(torch.round(img1 * 255.0) - torch.round(img2 * 255.0)).mean() 141 | return ie 142 | 
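For reference, below is a minimal sketch of how the metrics above might be invoked from the repository root; the tensor shapes and random inputs are illustrative stand-ins, not part of the repository:

```
import torch
from metric import calculate_psnr, ssim  # metric.py from this repository

# Illustrative stand-ins for an interpolated frame and its ground truth:
# (N, C, H, W) tensors with values in [0, 1].
pred = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)

print('PSNR:', calculate_psnr(pred, gt).item())  # -10 * log10(MSE)
print('SSIM:', ssim(pred, gt).item())            # Gaussian-windowed SSIM; val_range inferred from the inputs
```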
-------------------------------------------------------------------------------- /models/IFRNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import warp, get_robust_weight 5 | from loss import * 6 | 7 | 8 | def resize(x, scale_factor): 9 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) 10 | 11 | 12 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 13 | return nn.Sequential( 14 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), 15 | nn.PReLU(out_channels) 16 | ) 17 | 18 | 19 | class ResBlock(nn.Module): 20 | def __init__(self, in_channels, side_channels, bias=True): 21 | super(ResBlock, self).__init__() 22 | self.side_channels = side_channels 23 | self.conv1 = nn.Sequential( 24 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 25 | nn.PReLU(in_channels) 26 | ) 27 | self.conv2 = nn.Sequential( 28 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 29 | nn.PReLU(side_channels) 30 | ) 31 | self.conv3 = nn.Sequential( 32 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 33 | nn.PReLU(in_channels) 34 | ) 35 | self.conv4 = nn.Sequential( 36 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 37 | nn.PReLU(side_channels) 38 | ) 39 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) 40 | self.prelu = nn.PReLU(in_channels) 41 | 42 | def forward(self, x): 43 | out = self.conv1(x) 44 | out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :].clone()) 45 | out = self.conv3(out) 46 | out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :].clone()) 47 | out = self.prelu(x + self.conv5(out)) 48 | return out 49 | 50 | 51 | class Encoder(nn.Module): 52 | def __init__(self): 53 | super(Encoder, self).__init__() 54 | self.pyramid1 = nn.Sequential( 55 | convrelu(3, 32, 3, 2, 1), 56 | convrelu(32, 32, 3, 1, 1) 57 | ) 58 | self.pyramid2 = nn.Sequential( 59 | convrelu(32, 48, 3, 2, 1), 60 | convrelu(48, 48, 3, 1, 1) 61 | ) 62 | self.pyramid3 = nn.Sequential( 63 | convrelu(48, 72, 3, 2, 1), 64 | convrelu(72, 72, 3, 1, 1) 65 | ) 66 | self.pyramid4 = nn.Sequential( 67 | convrelu(72, 96, 3, 2, 1), 68 | convrelu(96, 96, 3, 1, 1) 69 | ) 70 | 71 | def forward(self, img): 72 | f1 = self.pyramid1(img) 73 | f2 = self.pyramid2(f1) 74 | f3 = self.pyramid3(f2) 75 | f4 = self.pyramid4(f3) 76 | return f1, f2, f3, f4 77 | 78 | 79 | class Decoder4(nn.Module): 80 | def __init__(self): 81 | super(Decoder4, self).__init__() 82 | self.convblock = nn.Sequential( 83 | convrelu(192+1, 192), 84 | ResBlock(192, 32), 85 | nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True) 86 | ) 87 | 88 | def forward(self, f0, f1, embt): 89 | b, c, h, w = f0.shape 90 | embt = embt.repeat(1, 1, h, w) 91 | f_in = torch.cat([f0, f1, embt], 1) 92 | f_out = self.convblock(f_in) 93 | return f_out 94 | 95 | 96 | class Decoder3(nn.Module): 97 | def __init__(self): 98 | super(Decoder3, self).__init__() 99 | self.convblock = nn.Sequential( 100 | convrelu(220, 216), 101 | ResBlock(216, 32), 102 | nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True) 103 | ) 104 | 105 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 106 | f0_warp = warp(f0, 
up_flow0) 107 | f1_warp = warp(f1, up_flow1) 108 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 109 | f_out = self.convblock(f_in) 110 | return f_out 111 | 112 | 113 | class Decoder2(nn.Module): 114 | def __init__(self): 115 | super(Decoder2, self).__init__() 116 | self.convblock = nn.Sequential( 117 | convrelu(148, 144), 118 | ResBlock(144, 32), 119 | nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True) 120 | ) 121 | 122 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 123 | f0_warp = warp(f0, up_flow0) 124 | f1_warp = warp(f1, up_flow1) 125 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 126 | f_out = self.convblock(f_in) 127 | return f_out 128 | 129 | 130 | class Decoder1(nn.Module): 131 | def __init__(self): 132 | super(Decoder1, self).__init__() 133 | self.convblock = nn.Sequential( 134 | convrelu(100, 96), 135 | ResBlock(96, 32), 136 | nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True) 137 | ) 138 | 139 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 140 | f0_warp = warp(f0, up_flow0) 141 | f1_warp = warp(f1, up_flow1) 142 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 143 | f_out = self.convblock(f_in) 144 | return f_out 145 | 146 | 147 | class Model(nn.Module): 148 | def __init__(self, local_rank=-1, lr=1e-4): 149 | super(Model, self).__init__() 150 | self.encoder = Encoder() 151 | self.decoder4 = Decoder4() 152 | self.decoder3 = Decoder3() 153 | self.decoder2 = Decoder2() 154 | self.decoder1 = Decoder1() 155 | self.l1_loss = Charbonnier_L1() 156 | self.tr_loss = Ternary(7) 157 | self.rb_loss = Charbonnier_Ada() 158 | self.gc_loss = Geometry(3) 159 | 160 | 161 | def inference(self, img0, img1, embt, scale_factor=1.0): 162 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 163 | img0 = img0 - mean_ 164 | img1 = img1 - mean_ 165 | 166 | img0_ = resize(img0, scale_factor=scale_factor) 167 | img1_ = resize(img1, scale_factor=scale_factor) 168 | 169 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) 170 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) 171 | 172 | out4 = self.decoder4(f0_4, f1_4, embt) 173 | up_flow0_4 = out4[:, 0:2] 174 | up_flow1_4 = out4[:, 2:4] 175 | ft_3_ = out4[:, 4:] 176 | 177 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) 178 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) 179 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0) 180 | ft_2_ = out3[:, 4:] 181 | 182 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) 183 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) 184 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) 185 | ft_1_ = out2[:, 4:] 186 | 187 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) 188 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) 189 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) 190 | up_mask_1 = torch.sigmoid(out1[:, 4:5]) 191 | up_res_1 = out1[:, 5:] 192 | 193 | up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 194 | up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 195 | up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor)) 196 | up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor)) 197 | 198 | img0_warp = warp(img0, up_flow0_1) 199 | img1_warp = warp(img1, up_flow1_1) 200 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ 201 | 
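# up_mask_1 is a soft occlusion mask that blends the two backward-warped frames; the residual up_res_1 added below refines the merged frame before clamping.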
imgt_pred = imgt_merge + up_res_1 202 | imgt_pred = torch.clamp(imgt_pred, 0, 1) 203 | return imgt_pred 204 | 205 | 206 | def forward(self, img0, img1, embt, imgt, flow=None): 207 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 208 | img0 = img0 - mean_ 209 | img1 = img1 - mean_ 210 | imgt_ = imgt - mean_ 211 | 212 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0) 213 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1) 214 | ft_1, ft_2, ft_3, ft_4 = self.encoder(imgt_) 215 | 216 | out4 = self.decoder4(f0_4, f1_4, embt) 217 | up_flow0_4 = out4[:, 0:2] 218 | up_flow1_4 = out4[:, 2:4] 219 | ft_3_ = out4[:, 4:] 220 | 221 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) 222 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) 223 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0) 224 | ft_2_ = out3[:, 4:] 225 | 226 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) 227 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) 228 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) 229 | ft_1_ = out2[:, 4:] 230 | 231 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) 232 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) 233 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) 234 | up_mask_1 = torch.sigmoid(out1[:, 4:5]) 235 | up_res_1 = out1[:, 5:] 236 | 237 | img0_warp = warp(img0, up_flow0_1) 238 | img1_warp = warp(img1, up_flow1_1) 239 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ 240 | imgt_pred = imgt_merge + up_res_1 241 | imgt_pred = torch.clamp(imgt_pred, 0, 1) 242 | 243 | loss_rec = self.l1_loss(imgt_pred - imgt) + self.tr_loss(imgt_pred, imgt) 244 | loss_geo = 0.01 * (self.gc_loss(ft_1_, ft_1) + self.gc_loss(ft_2_, ft_2) + self.gc_loss(ft_3_, ft_3)) 245 | if flow is not None: 246 | robust_weight0 = get_robust_weight(up_flow0_1, flow[:, 0:2], beta=0.3) 247 | robust_weight1 = get_robust_weight(up_flow1_1, flow[:, 2:4], beta=0.3) 248 | loss_dis = 0.01 * (self.rb_loss(2.0 * resize(up_flow0_2, 2.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(2.0 * resize(up_flow1_2, 2.0) - flow[:, 2:4], weight=robust_weight1)) 249 | loss_dis += 0.01 * (self.rb_loss(4.0 * resize(up_flow0_3, 4.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(4.0 * resize(up_flow1_3, 4.0) - flow[:, 2:4], weight=robust_weight1)) 250 | loss_dis += 0.01 * (self.rb_loss(8.0 * resize(up_flow0_4, 8.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(8.0 * resize(up_flow1_4, 8.0) - flow[:, 2:4], weight=robust_weight1)) 251 | else: 252 | loss_dis = 0.00 * loss_geo 253 | 254 | return imgt_pred, loss_rec, loss_geo, loss_dis 255 | -------------------------------------------------------------------------------- /models/IFRNet_L.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import warp, get_robust_weight 5 | from loss import * 6 | 7 | 8 | def resize(x, scale_factor): 9 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) 10 | 11 | 12 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 13 | return nn.Sequential( 14 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), 15 | nn.PReLU(out_channels) 16 | ) 17 | 
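# Note: this large variant mirrors models/IFRNet.py with doubled channel widths (encoder 64/96/144/192 vs. 32/48/72/96, ResBlock side channels 64 vs. 32) and a 7x7 first convolution; the blocks below otherwise match the base model.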
18 | 19 | class ResBlock(nn.Module): 20 | def __init__(self, in_channels, side_channels, bias=True): 21 | super(ResBlock, self).__init__() 22 | self.side_channels = side_channels 23 | self.conv1 = nn.Sequential( 24 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 25 | nn.PReLU(in_channels) 26 | ) 27 | self.conv2 = nn.Sequential( 28 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 29 | nn.PReLU(side_channels) 30 | ) 31 | self.conv3 = nn.Sequential( 32 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 33 | nn.PReLU(in_channels) 34 | ) 35 | self.conv4 = nn.Sequential( 36 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 37 | nn.PReLU(side_channels) 38 | ) 39 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) 40 | self.prelu = nn.PReLU(in_channels) 41 | 42 | def forward(self, x): 43 | out = self.conv1(x) 44 | out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :].clone()) 45 | out = self.conv3(out) 46 | out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :].clone()) 47 | out = self.prelu(x + self.conv5(out)) 48 | return out 49 | 50 | 51 | class Encoder(nn.Module): 52 | def __init__(self): 53 | super(Encoder, self).__init__() 54 | self.pyramid1 = nn.Sequential( 55 | convrelu(3, 64, 7, 2, 3), 56 | convrelu(64, 64, 3, 1, 1) 57 | ) 58 | self.pyramid2 = nn.Sequential( 59 | convrelu(64, 96, 3, 2, 1), 60 | convrelu(96, 96, 3, 1, 1) 61 | ) 62 | self.pyramid3 = nn.Sequential( 63 | convrelu(96, 144, 3, 2, 1), 64 | convrelu(144, 144, 3, 1, 1) 65 | ) 66 | self.pyramid4 = nn.Sequential( 67 | convrelu(144, 192, 3, 2, 1), 68 | convrelu(192, 192, 3, 1, 1) 69 | ) 70 | 71 | def forward(self, img): 72 | f1 = self.pyramid1(img) 73 | f2 = self.pyramid2(f1) 74 | f3 = self.pyramid3(f2) 75 | f4 = self.pyramid4(f3) 76 | return f1, f2, f3, f4 77 | 78 | 79 | class Decoder4(nn.Module): 80 | def __init__(self): 81 | super(Decoder4, self).__init__() 82 | self.convblock = nn.Sequential( 83 | convrelu(384+1, 384), 84 | ResBlock(384, 64), 85 | nn.ConvTranspose2d(384, 148, 4, 2, 1, bias=True) 86 | ) 87 | 88 | def forward(self, f0, f1, embt): 89 | b, c, h, w = f0.shape 90 | embt = embt.repeat(1, 1, h, w) 91 | f_in = torch.cat([f0, f1, embt], 1) 92 | f_out = self.convblock(f_in) 93 | return f_out 94 | 95 | 96 | class Decoder3(nn.Module): 97 | def __init__(self): 98 | super(Decoder3, self).__init__() 99 | self.convblock = nn.Sequential( 100 | convrelu(436, 432), 101 | ResBlock(432, 64), 102 | nn.ConvTranspose2d(432, 100, 4, 2, 1, bias=True) 103 | ) 104 | 105 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 106 | f0_warp = warp(f0, up_flow0) 107 | f1_warp = warp(f1, up_flow1) 108 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 109 | f_out = self.convblock(f_in) 110 | return f_out 111 | 112 | 113 | class Decoder2(nn.Module): 114 | def __init__(self): 115 | super(Decoder2, self).__init__() 116 | self.convblock = nn.Sequential( 117 | convrelu(292, 288), 118 | ResBlock(288, 64), 119 | nn.ConvTranspose2d(288, 68, 4, 2, 1, bias=True) 120 | ) 121 | 122 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 123 | f0_warp = warp(f0, up_flow0) 124 | f1_warp = warp(f1, up_flow1) 125 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 126 | f_out = self.convblock(f_in) 127 | return f_out 128 | 129 | 130 | class Decoder1(nn.Module): 131 | def 
__init__(self): 132 | super(Decoder1, self).__init__() 133 | self.convblock = nn.Sequential( 134 | convrelu(196, 192), 135 | ResBlock(192, 64), 136 | nn.ConvTranspose2d(192, 8, 4, 2, 1, bias=True) 137 | ) 138 | 139 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 140 | f0_warp = warp(f0, up_flow0) 141 | f1_warp = warp(f1, up_flow1) 142 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 143 | f_out = self.convblock(f_in) 144 | return f_out 145 | 146 | 147 | class Model(nn.Module): 148 | def __init__(self, local_rank=-1, lr=1e-4): 149 | super(Model, self).__init__() 150 | self.encoder = Encoder() 151 | self.decoder4 = Decoder4() 152 | self.decoder3 = Decoder3() 153 | self.decoder2 = Decoder2() 154 | self.decoder1 = Decoder1() 155 | self.l1_loss = Charbonnier_L1() 156 | self.tr_loss = Ternary(7) 157 | self.rb_loss = Charbonnier_Ada() 158 | self.gc_loss = Geometry(3) 159 | 160 | 161 | def inference(self, img0, img1, embt, scale_factor=1.0): 162 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 163 | img0 = img0 - mean_ 164 | img1 = img1 - mean_ 165 | 166 | img0_ = resize(img0, scale_factor=scale_factor) 167 | img1_ = resize(img1, scale_factor=scale_factor) 168 | 169 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) 170 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) 171 | 172 | out4 = self.decoder4(f0_4, f1_4, embt) 173 | up_flow0_4 = out4[:, 0:2] 174 | up_flow1_4 = out4[:, 2:4] 175 | ft_3_ = out4[:, 4:] 176 | 177 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) 178 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) 179 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0) 180 | ft_2_ = out3[:, 4:] 181 | 182 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) 183 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) 184 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) 185 | ft_1_ = out2[:, 4:] 186 | 187 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) 188 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) 189 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) 190 | up_mask_1 = torch.sigmoid(out1[:, 4:5]) 191 | up_res_1 = out1[:, 5:] 192 | 193 | up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 194 | up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 195 | up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor)) 196 | up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor)) 197 | 198 | img0_warp = warp(img0, up_flow0_1) 199 | img1_warp = warp(img1, up_flow1_1) 200 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ 201 | imgt_pred = imgt_merge + up_res_1 202 | imgt_pred = torch.clamp(imgt_pred, 0, 1) 203 | return imgt_pred 204 | 205 | 206 | def forward(self, img0, img1, embt, imgt, flow=None): 207 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 208 | img0 = img0 - mean_ 209 | img1 = img1 - mean_ 210 | imgt_ = imgt - mean_ 211 | 212 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0) 213 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1) 214 | ft_1, ft_2, ft_3, ft_4 = self.encoder(imgt_) 215 | 216 | out4 = self.decoder4(f0_4, f1_4, embt) 217 | up_flow0_4 = out4[:, 0:2] 218 | up_flow1_4 = out4[:, 2:4] 219 | ft_3_ = out4[:, 4:] 220 | 221 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) 222 | 
up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) 223 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0) 224 | ft_2_ = out3[:, 4:] 225 | 226 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) 227 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) 228 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) 229 | ft_1_ = out2[:, 4:] 230 | 231 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) 232 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) 233 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) 234 | up_mask_1 = torch.sigmoid(out1[:, 4:5]) 235 | up_res_1 = out1[:, 5:] 236 | 237 | img0_warp = warp(img0, up_flow0_1) 238 | img1_warp = warp(img1, up_flow1_1) 239 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ 240 | imgt_pred = imgt_merge + up_res_1 241 | imgt_pred = torch.clamp(imgt_pred, 0, 1) 242 | 243 | loss_rec = self.l1_loss(imgt_pred - imgt) + self.tr_loss(imgt_pred, imgt) 244 | loss_geo = 0.01 * (self.gc_loss(ft_1_, ft_1) + self.gc_loss(ft_2_, ft_2) + self.gc_loss(ft_3_, ft_3)) 245 | if flow is not None: 246 | robust_weight0 = get_robust_weight(up_flow0_1, flow[:, 0:2], beta=0.3) 247 | robust_weight1 = get_robust_weight(up_flow1_1, flow[:, 2:4], beta=0.3) 248 | loss_dis = 0.01 * (self.rb_loss(2.0 * resize(up_flow0_2, 2.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(2.0 * resize(up_flow1_2, 2.0) - flow[:, 2:4], weight=robust_weight1)) 249 | loss_dis += 0.01 * (self.rb_loss(4.0 * resize(up_flow0_3, 4.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(4.0 * resize(up_flow1_3, 4.0) - flow[:, 2:4], weight=robust_weight1)) 250 | loss_dis += 0.01 * (self.rb_loss(8.0 * resize(up_flow0_4, 8.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(8.0 * resize(up_flow1_4, 8.0) - flow[:, 2:4], weight=robust_weight1)) 251 | else: 252 | loss_dis = 0.00 * loss_geo 253 | 254 | return imgt_pred, loss_rec, loss_geo, loss_dis 255 | -------------------------------------------------------------------------------- /models/IFRNet_S.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import warp, get_robust_weight 5 | from loss import * 6 | 7 | 8 | def resize(x, scale_factor): 9 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) 10 | 11 | 12 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 13 | return nn.Sequential( 14 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), 15 | nn.PReLU(out_channels) 16 | ) 17 | 18 | 19 | class ResBlock(nn.Module): 20 | def __init__(self, in_channels, side_channels, bias=True): 21 | super(ResBlock, self).__init__() 22 | self.side_channels = side_channels 23 | self.conv1 = nn.Sequential( 24 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 25 | nn.PReLU(in_channels) 26 | ) 27 | self.conv2 = nn.Sequential( 28 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 29 | nn.PReLU(side_channels) 30 | ) 31 | self.conv3 = nn.Sequential( 32 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 33 | nn.PReLU(in_channels) 34 | ) 35 | self.conv4 = nn.Sequential( 36 | nn.Conv2d(side_channels, side_channels, 
kernel_size=3, stride=1, padding=1, bias=bias), 37 | nn.PReLU(side_channels) 38 | ) 39 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) 40 | self.prelu = nn.PReLU(in_channels) 41 | 42 | def forward(self, x): 43 | out = self.conv1(x) 44 | out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :].clone()) 45 | out = self.conv3(out) 46 | out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :].clone()) 47 | out = self.prelu(x + self.conv5(out)) 48 | return out 49 | 50 | 51 | class Encoder(nn.Module): 52 | def __init__(self): 53 | super(Encoder, self).__init__() 54 | self.pyramid1 = nn.Sequential( 55 | convrelu(3, 24, 3, 2, 1), 56 | convrelu(24, 24, 3, 1, 1) 57 | ) 58 | self.pyramid2 = nn.Sequential( 59 | convrelu(24, 36, 3, 2, 1), 60 | convrelu(36, 36, 3, 1, 1) 61 | ) 62 | self.pyramid3 = nn.Sequential( 63 | convrelu(36, 54, 3, 2, 1), 64 | convrelu(54, 54, 3, 1, 1) 65 | ) 66 | self.pyramid4 = nn.Sequential( 67 | convrelu(54, 72, 3, 2, 1), 68 | convrelu(72, 72, 3, 1, 1) 69 | ) 70 | 71 | def forward(self, img): 72 | f1 = self.pyramid1(img) 73 | f2 = self.pyramid2(f1) 74 | f3 = self.pyramid3(f2) 75 | f4 = self.pyramid4(f3) 76 | return f1, f2, f3, f4 77 | 78 | 79 | class Decoder4(nn.Module): 80 | def __init__(self): 81 | super(Decoder4, self).__init__() 82 | self.convblock = nn.Sequential( 83 | convrelu(144+1, 144), 84 | ResBlock(144, 24), 85 | nn.ConvTranspose2d(144, 58, 4, 2, 1, bias=True) 86 | ) 87 | 88 | def forward(self, f0, f1, embt): 89 | b, c, h, w = f0.shape 90 | embt = embt.repeat(1, 1, h, w) 91 | f_in = torch.cat([f0, f1, embt], 1) 92 | f_out = self.convblock(f_in) 93 | return f_out 94 | 95 | 96 | class Decoder3(nn.Module): 97 | def __init__(self): 98 | super(Decoder3, self).__init__() 99 | self.convblock = nn.Sequential( 100 | convrelu(166, 162), 101 | ResBlock(162, 24), 102 | nn.ConvTranspose2d(162, 40, 4, 2, 1, bias=True) 103 | ) 104 | 105 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 106 | f0_warp = warp(f0, up_flow0) 107 | f1_warp = warp(f1, up_flow1) 108 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 109 | f_out = self.convblock(f_in) 110 | return f_out 111 | 112 | 113 | class Decoder2(nn.Module): 114 | def __init__(self): 115 | super(Decoder2, self).__init__() 116 | self.convblock = nn.Sequential( 117 | convrelu(112, 108), 118 | ResBlock(108, 24), 119 | nn.ConvTranspose2d(108, 28, 4, 2, 1, bias=True) 120 | ) 121 | 122 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 123 | f0_warp = warp(f0, up_flow0) 124 | f1_warp = warp(f1, up_flow1) 125 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 126 | f_out = self.convblock(f_in) 127 | return f_out 128 | 129 | 130 | class Decoder1(nn.Module): 131 | def __init__(self): 132 | super(Decoder1, self).__init__() 133 | self.convblock = nn.Sequential( 134 | convrelu(76, 72), 135 | ResBlock(72, 24), 136 | nn.ConvTranspose2d(72, 8, 4, 2, 1, bias=True) 137 | ) 138 | 139 | def forward(self, ft_, f0, f1, up_flow0, up_flow1): 140 | f0_warp = warp(f0, up_flow0) 141 | f1_warp = warp(f1, up_flow1) 142 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) 143 | f_out = self.convblock(f_in) 144 | return f_out 145 | 146 | 147 | class Model(nn.Module): 148 | def __init__(self, local_rank=-1, lr=1e-4): 149 | super(Model, self).__init__() 150 | self.encoder = Encoder() 151 | self.decoder4 = Decoder4() 152 | self.decoder3 = Decoder3() 153 | self.decoder2 = Decoder2() 154 | self.decoder1 = Decoder1() 
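        # Training objectives (see loss.py): Charbonnier L1 plus a ternary census loss
        # for reconstruction, a robustly weighted Charbonnier loss for task-oriented
        # flow distillation, and a geometry consistency loss on intermediate features.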
155 | self.l1_loss = Charbonnier_L1() 156 | self.tr_loss = Ternary(7) 157 | self.rb_loss = Charbonnier_Ada() 158 | self.gc_loss = Geometry(3) 159 | 160 | 161 | def inference(self, img0, img1, embt, scale_factor=1.0): 162 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 163 | img0 = img0 - mean_ 164 | img1 = img1 - mean_ 165 | 166 | img0_ = resize(img0, scale_factor=scale_factor) 167 | img1_ = resize(img1, scale_factor=scale_factor) 168 | 169 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) 170 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) 171 | 172 | out4 = self.decoder4(f0_4, f1_4, embt) 173 | up_flow0_4 = out4[:, 0:2] 174 | up_flow1_4 = out4[:, 2:4] 175 | ft_3_ = out4[:, 4:] 176 | 177 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) 178 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) 179 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0) 180 | ft_2_ = out3[:, 4:] 181 | 182 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) 183 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) 184 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) 185 | ft_1_ = out2[:, 4:] 186 | 187 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) 188 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) 189 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) 190 | up_mask_1 = torch.sigmoid(out1[:, 4:5]) 191 | up_res_1 = out1[:, 5:] 192 | 193 | up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 194 | up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 195 | up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor)) 196 | up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor)) 197 | 198 | img0_warp = warp(img0, up_flow0_1) 199 | img1_warp = warp(img1, up_flow1_1) 200 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ 201 | imgt_pred = imgt_merge + up_res_1 202 | imgt_pred = torch.clamp(imgt_pred, 0, 1) 203 | return imgt_pred 204 | 205 | 206 | def forward(self, img0, img1, embt, imgt, flow=None): 207 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 208 | img0 = img0 - mean_ 209 | img1 = img1 - mean_ 210 | imgt_ = imgt - mean_ 211 | 212 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0) 213 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1) 214 | ft_1, ft_2, ft_3, ft_4 = self.encoder(imgt_) 215 | 216 | out4 = self.decoder4(f0_4, f1_4, embt) 217 | up_flow0_4 = out4[:, 0:2] 218 | up_flow1_4 = out4[:, 2:4] 219 | ft_3_ = out4[:, 4:] 220 | 221 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) 222 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) 223 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0) 224 | ft_2_ = out3[:, 4:] 225 | 226 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) 227 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) 228 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) 229 | ft_1_ = out2[:, 4:] 230 | 231 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) 232 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) 233 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) 234 | up_mask_1 = torch.sigmoid(out1[:, 4:5]) 235 | up_res_1 = out1[:, 5:] 236 | 237 | img0_warp 
= warp(img0, up_flow0_1) 238 | img1_warp = warp(img1, up_flow1_1) 239 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ 240 | imgt_pred = imgt_merge + up_res_1 241 | imgt_pred = torch.clamp(imgt_pred, 0, 1) 242 | 243 | loss_rec = self.l1_loss(imgt_pred - imgt) + self.tr_loss(imgt_pred, imgt) 244 | loss_geo = 0.01 * (self.gc_loss(ft_1_, ft_1) + self.gc_loss(ft_2_, ft_2) + self.gc_loss(ft_3_, ft_3)) 245 | if flow is not None: 246 | robust_weight0 = get_robust_weight(up_flow0_1, flow[:, 0:2], beta=0.3) 247 | robust_weight1 = get_robust_weight(up_flow1_1, flow[:, 2:4], beta=0.3) 248 | loss_dis = 0.01 * (self.rb_loss(2.0 * resize(up_flow0_2, 2.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(2.0 * resize(up_flow1_2, 2.0) - flow[:, 2:4], weight=robust_weight1)) 249 | loss_dis += 0.01 * (self.rb_loss(4.0 * resize(up_flow0_3, 4.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(4.0 * resize(up_flow1_3, 4.0) - flow[:, 2:4], weight=robust_weight1)) 250 | loss_dis += 0.01 * (self.rb_loss(8.0 * resize(up_flow0_4, 8.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(8.0 * resize(up_flow1_4, 8.0) - flow[:, 2:4], weight=robust_weight1)) 251 | else: 252 | loss_dis = 0.00 * loss_geo 253 | 254 | return imgt_pred, loss_rec, loss_geo, loss_dis 255 | -------------------------------------------------------------------------------- /train_gopro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import random 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from datasets import GoPro_Train_Dataset, GoPro_Test_Dataset 15 | from metric import calculate_psnr, calculate_ssim 16 | from utils import AverageMeter 17 | import logging 18 | 19 | 20 | def get_lr(args, iters): 21 | ratio = 0.5 * (1.0 + np.cos(iters / (args.epochs * args.iters_per_epoch) * math.pi)) 22 | lr = (args.lr_start - args.lr_end) * ratio + args.lr_end 23 | return lr 24 | 25 | 26 | def set_lr(optimizer, lr): 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = lr 29 | 30 | 31 | def train(args, ddp_model): 32 | local_rank = args.local_rank 33 | print('Distributed Data Parallel Training IFRNet on Rank {}'.format(local_rank)) 34 | 35 | if local_rank == 0: 36 | os.makedirs(args.log_path, exist_ok=True) 37 | log_path = os.path.join(args.log_path, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())) 38 | os.makedirs(log_path, exist_ok=True) 39 | logger = logging.getLogger() 40 | logger.setLevel('INFO') 41 | BASIC_FORMAT = '%(asctime)s:%(levelname)s:%(message)s' 42 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S' 43 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT) 44 | chlr = logging.StreamHandler() 45 | chlr.setFormatter(formatter) 46 | chlr.setLevel('INFO') 47 | fhlr = logging.FileHandler(os.path.join(log_path, 'train.log')) 48 | fhlr.setFormatter(formatter) 49 | logger.addHandler(chlr) 50 | logger.addHandler(fhlr) 51 | logger.info(args) 52 | 53 | dataset_train = GoPro_Train_Dataset(dataset_dir='/home/ltkong/Datasets/GOPRO', augment=True) 54 | sampler = DistributedSampler(dataset_train) 55 | dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, 
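                                  # drop_last=True keeps per-rank batch counts identical under DDP,
                                  # and DistributedSampler shards the GoPro samples across ranks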
drop_last=True, sampler=sampler) 56 | args.iters_per_epoch = dataloader_train.__len__() 57 | iters = args.resume_epoch * args.iters_per_epoch 58 | 59 | dataset_val = GoPro_Test_Dataset(dataset_dir='/home/ltkong/Datasets/GOPRO') 60 | dataloader_val = DataLoader(dataset_val, batch_size=2, num_workers=4, pin_memory=True, shuffle=False, drop_last=True) 61 | 62 | optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr_start, weight_decay=0) 63 | 64 | time_stamp = time.time() 65 | avg_rec = AverageMeter() 66 | avg_geo = AverageMeter() 67 | avg_dis = AverageMeter() 68 | best_psnr = 0.0 69 | 70 | for epoch in range(args.resume_epoch, args.epochs): 71 | sampler.set_epoch(epoch) 72 | for i, data in enumerate(dataloader_train): 73 | for l in range(len(data)): 74 | data[l] = data[l].to(args.device) 75 | img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7 = data 76 | 77 | img0 = torch.cat([img0, img0, img0, img0, img0, img0, img0], 0) 78 | img8 = torch.cat([img8, img8, img8, img8, img8, img8, img8], 0) 79 | imgt = torch.cat([img1, img2, img3, img4, img5, img6, img7], 0) 80 | embt = torch.cat([emb1, emb2, emb3, emb4, emb5, emb6, emb7], 0) 81 | 82 | data_time_interval = time.time() - time_stamp 83 | time_stamp = time.time() 84 | 85 | lr = get_lr(args, iters) 86 | set_lr(optimizer, lr) 87 | 88 | optimizer.zero_grad() 89 | 90 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img8, embt, imgt, None) 91 | 92 | loss = loss_rec + loss_geo + loss_dis 93 | loss.backward() 94 | optimizer.step() 95 | 96 | avg_rec.update(loss_rec.cpu().data) 97 | avg_geo.update(loss_geo.cpu().data) 98 | avg_dis.update(loss_dis.cpu().data) 99 | train_time_interval = time.time() - time_stamp 100 | 101 | if (iters+1) % 50 == 0 and local_rank == 0: 102 | logger.info('epoch:{}/{} iter:{}/{} time:{:.2f}+{:.2f} lr:{:.5e} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e}'.format(epoch+1, args.epochs, iters+1, args.epochs * args.iters_per_epoch, data_time_interval, train_time_interval, lr, avg_rec.avg, avg_geo.avg, avg_dis.avg)) 103 | avg_rec.reset() 104 | avg_geo.reset() 105 | avg_dis.reset() 106 | 107 | iters += 1 108 | time_stamp = time.time() 109 | 110 | if (epoch+1) % args.eval_interval == 0 and local_rank == 0: 111 | psnr = evaluate(args, ddp_model, dataloader_val, epoch, logger) 112 | if psnr > best_psnr: 113 | best_psnr = psnr 114 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'best')) 115 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'latest')) 116 | 117 | dist.barrier() 118 | 119 | 120 | def evaluate(args, ddp_model, dataloader_val, epoch, logger): 121 | loss_rec_list = [] 122 | loss_geo_list = [] 123 | loss_dis_list = [] 124 | psnr_list = [] 125 | time_stamp = time.time() 126 | for i, data in enumerate(dataloader_val): 127 | for l in range(len(data)): 128 | data[l] = data[l].to(args.device) 129 | img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7 = data 130 | 131 | img0 = torch.cat([img0, img0, img0, img0, img0, img0, img0], 0) 132 | img8 = torch.cat([img8, img8, img8, img8, img8, img8, img8], 0) 133 | imgt = torch.cat([img1, img2, img3, img4, img5, img6, img7], 0) 134 | embt = torch.cat([emb1, emb2, emb3, emb4, emb5, emb6, emb7], 0) 135 | 136 | with torch.no_grad(): 137 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img8, embt, imgt, None) 138 | 139 | loss_rec_list.append(loss_rec.cpu().numpy()) 140 | 
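        # per-batch losses are collected here and averaged over the whole validation split below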
loss_geo_list.append(loss_geo.cpu().numpy()) 141 | loss_dis_list.append(loss_dis.cpu().numpy()) 142 | 143 | for j in range(img0.shape[0]): 144 | psnr = calculate_psnr(imgt_pred[j].unsqueeze(0), imgt[j].unsqueeze(0)).cpu().data 145 | psnr_list.append(psnr) 146 | 147 | eval_time_interval = time.time() - time_stamp 148 | 149 | logger.info('eval epoch:{}/{} time:{:.2f} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e} psnr:{:.3f}'.format(epoch+1, args.epochs, eval_time_interval, np.array(loss_rec_list).mean(), np.array(loss_geo_list).mean(), np.array(loss_dis_list).mean(), np.array(psnr_list).mean())) 150 | return np.array(psnr_list).mean() 151 | 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser(description='IFRNet') 156 | parser.add_argument('--model_name', default='IFRNet', type=str, help='IFRNet, IFRNet_L, IFRNet_S') 157 | parser.add_argument('--local_rank', default=-1, type=int) 158 | parser.add_argument('--world_size', default=4, type=int) 159 | parser.add_argument('--epochs', default=600, type=int) 160 | parser.add_argument('--eval_interval', default=8, type=int) 161 | parser.add_argument('--batch_size', default=2, type=int) 162 | parser.add_argument('--lr_start', default=1e-4, type=float) 163 | parser.add_argument('--lr_end', default=1e-5, type=float) 164 | parser.add_argument('--log_path', default='checkpoint', type=str) 165 | parser.add_argument('--resume_epoch', default=0, type=int) 166 | parser.add_argument('--resume_path', default=None, type=str) 167 | args = parser.parse_args() 168 | 169 | dist.init_process_group(backend='nccl', world_size=args.world_size) 170 | torch.cuda.set_device(args.local_rank) 171 | args.device = torch.device('cuda', args.local_rank) 172 | 173 | seed = 1234 174 | random.seed(seed) 175 | np.random.seed(seed) 176 | torch.manual_seed(seed) 177 | torch.cuda.manual_seed_all(seed) 178 | torch.backends.cudnn.benchmark = True 179 | 180 | if args.model_name == 'IFRNet': 181 | from models.IFRNet import Model 182 | elif args.model_name == 'IFRNet_L': 183 | from models.IFRNet_L import Model 184 | elif args.model_name == 'IFRNet_S': 185 | from models.IFRNet_S import Model 186 | 187 | args.log_path = args.log_path + '/' + args.model_name 188 | args.num_workers = args.batch_size * 4 189 | 190 | model = Model().to(args.device) 191 | 192 | if args.resume_epoch != 0: 193 | model.load_state_dict(torch.load(args.resume_path, map_location='cpu')) 194 | 195 | ddp_model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank) 196 | 197 | train(args, ddp_model) 198 | 199 | dist.destroy_process_group() 200 | -------------------------------------------------------------------------------- /train_vimeo90k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import random 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | import torch.distributed as dist 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from datasets import Vimeo90K_Train_Dataset, Vimeo90K_Test_Dataset 15 | from metric import calculate_psnr, calculate_ssim 16 | from utils import AverageMeter 17 | import logging 18 | 19 | 20 | def get_lr(args, iters): 21 | ratio = 0.5 * (1.0 + np.cos(iters / (args.epochs * args.iters_per_epoch) * math.pi)) 22 | lr = (args.lr_start - args.lr_end) * 
ratio + args.lr_end 23 | return lr 24 | 25 | 26 | def set_lr(optimizer, lr): 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = lr 29 | 30 | 31 | def train(args, ddp_model): 32 | local_rank = args.local_rank 33 | print('Distributed Data Parallel Training IFRNet on Rank {}'.format(local_rank)) 34 | 35 | if local_rank == 0: 36 | os.makedirs(args.log_path, exist_ok=True) 37 | log_path = os.path.join(args.log_path, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())) 38 | os.makedirs(log_path, exist_ok=True) 39 | logger = logging.getLogger() 40 | logger.setLevel('INFO') 41 | BASIC_FORMAT = '%(asctime)s:%(levelname)s:%(message)s' 42 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S' 43 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT) 44 | chlr = logging.StreamHandler() 45 | chlr.setFormatter(formatter) 46 | chlr.setLevel('INFO') 47 | fhlr = logging.FileHandler(os.path.join(log_path, 'train.log')) 48 | fhlr.setFormatter(formatter) 49 | logger.addHandler(chlr) 50 | logger.addHandler(fhlr) 51 | logger.info(args) 52 | 53 | dataset_train = Vimeo90K_Train_Dataset(dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet', augment=True) 54 | sampler = DistributedSampler(dataset_train) 55 | dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, drop_last=True, sampler=sampler) 56 | args.iters_per_epoch = dataloader_train.__len__() 57 | iters = args.resume_epoch * args.iters_per_epoch 58 | 59 | dataset_val = Vimeo90K_Test_Dataset(dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet') 60 | dataloader_val = DataLoader(dataset_val, batch_size=16, num_workers=16, pin_memory=True, shuffle=False, drop_last=True) 61 | 62 | optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr_start, weight_decay=0) 63 | 64 | time_stamp = time.time() 65 | avg_rec = AverageMeter() 66 | avg_geo = AverageMeter() 67 | avg_dis = AverageMeter() 68 | best_psnr = 0.0 69 | 70 | for epoch in range(args.resume_epoch, args.epochs): 71 | sampler.set_epoch(epoch) 72 | for i, data in enumerate(dataloader_train): 73 | for l in range(len(data)): 74 | data[l] = data[l].to(args.device) 75 | img0, imgt, img1, flow, embt = data 76 | 77 | data_time_interval = time.time() - time_stamp 78 | time_stamp = time.time() 79 | 80 | lr = get_lr(args, iters) 81 | set_lr(optimizer, lr) 82 | 83 | optimizer.zero_grad() 84 | 85 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img1, embt, imgt, flow) 86 | 87 | loss = loss_rec + loss_geo + loss_dis 88 | loss.backward() 89 | optimizer.step() 90 | 91 | avg_rec.update(loss_rec.cpu().data) 92 | avg_geo.update(loss_geo.cpu().data) 93 | avg_dis.update(loss_dis.cpu().data) 94 | train_time_interval = time.time() - time_stamp 95 | 96 | if (iters+1) % 100 == 0 and local_rank == 0: 97 | logger.info('epoch:{}/{} iter:{}/{} time:{:.2f}+{:.2f} lr:{:.5e} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e}'.format(epoch+1, args.epochs, iters+1, args.epochs * args.iters_per_epoch, data_time_interval, train_time_interval, lr, avg_rec.avg, avg_geo.avg, avg_dis.avg)) 98 | avg_rec.reset() 99 | avg_geo.reset() 100 | avg_dis.reset() 101 | 102 | iters += 1 103 | time_stamp = time.time() 104 | 105 | if (epoch+1) % args.eval_interval == 0 and local_rank == 0: 106 | psnr = evaluate(args, ddp_model, dataloader_val, epoch, logger) 107 | if psnr > best_psnr: 108 | best_psnr = psnr 109 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'best')) 110 | torch.save(ddp_model.module.state_dict(), 
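                   # the 'latest' checkpoint is refreshed after every evaluation pass,
                   # while the 'best' checkpoint above is only overwritten on a new best PSNR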
'{}/{}_{}.pth'.format(log_path, args.model_name, 'latest')) 111 | 112 | dist.barrier() 113 | 114 | 115 | def evaluate(args, ddp_model, dataloader_val, epoch, logger): 116 | loss_rec_list = [] 117 | loss_geo_list = [] 118 | loss_dis_list = [] 119 | psnr_list = [] 120 | time_stamp = time.time() 121 | for i, data in enumerate(dataloader_val): 122 | for l in range(len(data)): 123 | data[l] = data[l].to(args.device) 124 | img0, imgt, img1, flow, embt = data 125 | 126 | with torch.no_grad(): 127 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img1, embt, imgt, flow) 128 | 129 | loss_rec_list.append(loss_rec.cpu().numpy()) 130 | loss_geo_list.append(loss_geo.cpu().numpy()) 131 | loss_dis_list.append(loss_dis.cpu().numpy()) 132 | 133 | for j in range(img0.shape[0]): 134 | psnr = calculate_psnr(imgt_pred[j].unsqueeze(0), imgt[j].unsqueeze(0)).cpu().data 135 | psnr_list.append(psnr) 136 | 137 | eval_time_interval = time.time() - time_stamp 138 | 139 | logger.info('eval epoch:{}/{} time:{:.2f} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e} psnr:{:.3f}'.format(epoch+1, args.epochs, eval_time_interval, np.array(loss_rec_list).mean(), np.array(loss_geo_list).mean(), np.array(loss_dis_list).mean(), np.array(psnr_list).mean())) 140 | return np.array(psnr_list).mean() 141 | 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser(description='IFRNet') 146 | parser.add_argument('--model_name', default='IFRNet', type=str, help='IFRNet, IFRNet_L, IFRNet_S') 147 | parser.add_argument('--local_rank', default=-1, type=int) 148 | parser.add_argument('--world_size', default=4, type=int) 149 | parser.add_argument('--epochs', default=300, type=int) 150 | parser.add_argument('--eval_interval', default=1, type=int) 151 | parser.add_argument('--batch_size', default=6, type=int) 152 | parser.add_argument('--lr_start', default=1e-4, type=float) 153 | parser.add_argument('--lr_end', default=1e-5, type=float) 154 | parser.add_argument('--log_path', default='checkpoint', type=str) 155 | parser.add_argument('--resume_epoch', default=0, type=int) 156 | parser.add_argument('--resume_path', default=None, type=str) 157 | args = parser.parse_args() 158 | 159 | dist.init_process_group(backend='nccl', world_size=args.world_size) 160 | torch.cuda.set_device(args.local_rank) 161 | args.device = torch.device('cuda', args.local_rank) 162 | 163 | seed = 1234 164 | random.seed(seed) 165 | np.random.seed(seed) 166 | torch.manual_seed(seed) 167 | torch.cuda.manual_seed_all(seed) 168 | torch.backends.cudnn.benchmark = True 169 | 170 | if args.model_name == 'IFRNet': 171 | from models.IFRNet import Model 172 | elif args.model_name == 'IFRNet_L': 173 | from models.IFRNet_L import Model 174 | elif args.model_name == 'IFRNet_S': 175 | from models.IFRNet_S import Model 176 | 177 | args.log_path = args.log_path + '/' + args.model_name 178 | args.num_workers = args.batch_size 179 | 180 | model = Model().to(args.device) 181 | 182 | if args.resume_epoch != 0: 183 | model.load_state_dict(torch.load(args.resume_path, map_location='cpu')) 184 | 185 | ddp_model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank) 186 | 187 | train(args, ddp_model) 188 | 189 | dist.destroy_process_group() 190 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import numpy as np 5 | from imageio import imread, imwrite 6 | import numpy as np 7 | 
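# Tolerate truncated image files on disk: PIL raises by default, which would
# abort long runs over large frame datasets.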
from PIL import Image, ImageFile 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def warp(img, flow): 15 | B, _, H, W = flow.shape 16 | xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) 17 | yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) 18 | grid = torch.cat([xx, yy], 1).to(img) 19 | flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) 20 | grid_ = (grid + flow_).permute(0, 2, 3, 1) 21 | output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) 22 | return output 23 | 24 | 25 | def get_robust_weight(flow_pred, flow_gt, beta): 26 | epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=1, keepdim=True) ** 0.5 27 | robust_weight = torch.exp(-beta * epe) 28 | return robust_weight 29 | 30 | 31 | class AverageMeter(): 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.val = 0 37 | self.avg = 0 38 | self.sum = 0 39 | self.count = 0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | self.avg = self.sum / self.count 46 | 47 | 48 | def read(file): 49 | if file.endswith('.float3'): return readFloat(file) 50 | elif file.endswith('.flo'): return readFlow(file) 51 | elif file.endswith('.ppm'): return readImage(file) 52 | elif file.endswith('.pgm'): return readImage(file) 53 | elif file.endswith('.png'): return readImage(file) 54 | elif file.endswith('.jpg'): return readImage(file) 55 | elif file.endswith('.pfm'): return readPFM(file)[0] 56 | else: raise Exception('don\'t know how to read %s' % file) 57 | 58 | 59 | def write(file, data): 60 | if file.endswith('.float3'): return writeFloat(file, data) 61 | elif file.endswith('.flo'): return writeFlow(file, data) 62 | elif file.endswith('.ppm'): return writeImage(file, data) 63 | elif file.endswith('.pgm'): return writeImage(file, data) 64 | elif file.endswith('.png'): return writeImage(file, data) 65 | elif file.endswith('.jpg'): return writeImage(file, data) 66 | elif file.endswith('.pfm'): return writePFM(file, data) 67 | else: raise Exception('don\'t know how to write %s' % file) 68 | 69 | 70 | def readPFM(file): 71 | file = open(file, 'rb') 72 | 73 | color = None 74 | width = None 75 | height = None 76 | scale = None 77 | endian = None 78 | 79 | header = file.readline().rstrip() 80 | if header.decode("ascii") == 'PF': 81 | color = True 82 | elif header.decode("ascii") == 'Pf': 83 | color = False 84 | else: 85 | raise Exception('Not a PFM file.') 86 | 87 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) 88 | if dim_match: 89 | width, height = list(map(int, dim_match.groups())) 90 | else: 91 | raise Exception('Malformed PFM header.') 92 | 93 | scale = float(file.readline().decode("ascii").rstrip()) 94 | if scale < 0: 95 | endian = '<' 96 | scale = -scale 97 | else: 98 | endian = '>' 99 | 100 | data = np.fromfile(file, endian + 'f') 101 | shape = (height, width, 3) if color else (height, width) 102 | 103 | data = np.reshape(data, shape) 104 | data = np.flipud(data) 105 | return data, scale 106 | 107 | 108 | def writePFM(file, image, scale=1): 109 | file = open(file, 'wb') 110 | 111 | color = None 112 | 113 | if image.dtype.name != 'float32': 114 | raise Exception('Image dtype must be float32.') 115 | 116 | image = np.flipud(image) 117 | 118 | if len(image.shape) == 3 and image.shape[2] == 3: 119 | color = True 
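    # PFM headers: 'PF' marks a 3-channel color image, 'Pf' a single-channel one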
120 |     elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
121 |         color = False
122 |     else:
123 |         raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
124 | 
125 |     file.write(('PF\n' if color else 'Pf\n').encode())
126 |     file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
127 | 
128 |     endian = image.dtype.byteorder
129 | 
130 |     if endian == '<' or endian == '=' and sys.byteorder == 'little':
131 |         scale = -scale
132 | 
133 |     file.write('%f\n'.encode() % scale)
134 | 
135 |     image.tofile(file)
136 | 
137 | 
138 | def readFlow(name):
139 |     if name.endswith('.pfm') or name.endswith('.PFM'):
140 |         return readPFM(name)[0][:,:,0:2]
141 | 
142 |     f = open(name, 'rb')
143 | 
144 |     header = f.read(4)
145 |     if header.decode("utf-8") != 'PIEH':
146 |         raise Exception('Flow file header does not contain PIEH')
147 | 
148 |     width = np.fromfile(f, np.int32, 1).squeeze()
149 |     height = np.fromfile(f, np.int32, 1).squeeze()
150 | 
151 |     flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
152 | 
153 |     return flow.astype(np.float32)
154 | 
155 | 
156 | def readImage(name):
157 |     if name.endswith('.pfm') or name.endswith('.PFM'):
158 |         data = readPFM(name)[0]
159 |         if len(data.shape) == 3:
160 |             return data[:,:,0:3]
161 |         else:
162 |             return data
163 |     return imread(name)
164 | 
165 | 
166 | def writeImage(name, data):
167 |     if name.endswith('.pfm') or name.endswith('.PFM'):
168 |         return writePFM(name, data, 1)
169 |     return imwrite(name, data)
170 | 
171 | 
172 | def writeFlow(name, flow):
173 |     f = open(name, 'wb')
174 |     f.write('PIEH'.encode('utf-8'))
175 |     np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
176 |     flow = flow.astype(np.float32)
177 |     flow.tofile(f)
178 | 
179 | 
180 | def readFloat(name):
181 |     f = open(name, 'rb')
182 | 
183 |     if f.readline().decode("utf-8") != 'float\n':
184 |         raise Exception('float file %s did not contain keyword' % name)
185 | 
186 |     dim = int(f.readline())
187 | 
188 |     dims = []
189 |     count = 1
190 |     for i in range(0, dim):
191 |         d = int(f.readline())
192 |         dims.append(d)
193 |         count *= d
194 | 
195 |     dims = list(reversed(dims))
196 | 
197 |     data = np.fromfile(f, np.float32, count).reshape(dims)
198 |     if dim > 2:
199 |         data = np.transpose(data, (2, 1, 0))
200 |         data = np.transpose(data, (1, 0, 2))
201 | 
202 |     return data
203 | 
204 | 
205 | def writeFloat(name, data):
206 |     f = open(name, 'wb')
207 | 
208 |     dim = len(data.shape)
209 |     if dim > 3:
210 |         raise Exception('bad float file dimension: %d' % dim)
211 | 
212 |     f.write(('float\n').encode('ascii'))
213 |     f.write(('%d\n' % dim).encode('ascii'))
214 | 
215 |     if dim == 1:
216 |         f.write(('%d\n' % data.shape[0]).encode('ascii'))
217 |     else:
218 |         f.write(('%d\n' % data.shape[1]).encode('ascii'))
219 |         f.write(('%d\n' % data.shape[0]).encode('ascii'))
220 |         for i in range(2, dim):
221 |             f.write(('%d\n' % data.shape[i]).encode('ascii'))
222 | 
223 |     data = data.astype(np.float32)
224 |     if dim == 2:
225 |         data.tofile(f)
226 | 
227 |     else:
228 |         np.transpose(data, (2, 0, 1)).tofile(f)
229 | --------------------------------------------------------------------------------
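For orientation, here is a minimal, hypothetical sketch (not a file in this repository) of driving the `Model.inference` interface defined above for 2x interpolation with the small model. The checkpoint and image filenames are assumptions; the two frames are assumed to be equally sized RGB images whose height and width are divisible by 16, since the encoder downsamples four times.

```python
# Hypothetical 2x interpolation sketch built on models/IFRNet_S.py; the
# checkpoint and image paths below are placeholders, not repository files.
import numpy as np
import torch
from imageio import imread, imwrite
from models.IFRNet_S import Model

model = Model()
model.load_state_dict(torch.load('IFRNet_S_checkpoint.pth', map_location='cpu'))
model.eval()

def to_tensor(path):
    # HWC uint8 -> NCHW float32 in [0, 1]
    img = imread(path).astype(np.float32) / 255.0
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)

img0 = to_tensor('img0.png')
img1 = to_tensor('img1.png')

# embt encodes the target time step as a (B, 1, 1, 1) tensor;
# 0.5 requests the temporal midpoint between img0 and img1.
embt = torch.full((1, 1, 1, 1), 0.5)

with torch.no_grad():
    imgt_pred = model.inference(img0, img1, embt)

imwrite('imgt.png', (imgt_pred[0].permute(1, 2, 0).numpy() * 255.0).astype(np.uint8))
```

For 8x output, the same call can be repeated with `embt` set to k/8 for k = 1..7, mirroring how train_gopro.py stacks seven intermediate targets per sample.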