├── LICENSE
├── README.md
├── benchmarks
│   ├── Middlebury_Other.py
│   ├── SNU_FILM.py
│   ├── UCF101.py
│   ├── Vimeo90K.py
│   └── speed_parameters.py
├── datasets.py
├── demo_2x.py
├── demo_8x.py
├── figures
│   ├── 8x_interpolation.png
│   ├── benchmarks.png
│   ├── fig1_1.gif
│   ├── fig1_2.gif
│   ├── fig1_3.gif
│   ├── fig2_1.gif
│   ├── fig2_2.gif
│   ├── img0.png
│   ├── img1.png
│   ├── img_overlaid.png
│   ├── middlebury.png
│   ├── middlebury_other.png
│   ├── out_2x.gif
│   ├── out_8x.gif
│   └── vimeo90k.png
├── generate_flow.py
├── liteflownet
│   ├── README.md
│   ├── correlation
│   │   ├── README.md
│   │   └── correlation.py
│   ├── requirements.txt
│   └── run.py
├── loss.py
├── metric.py
├── models
│   ├── IFRNet.py
│   ├── IFRNet_L.py
│   └── IFRNet_S.py
├── train_gopro.py
├── train_vimeo90k.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Lingtong Kong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation
2 | The official PyTorch implementation of [IFRNet](https://arxiv.org/abs/2205.14620) (CVPR 2022).
3 |
4 | Authors: [Lingtong Kong](https://scholar.google.com.hk/citations?user=KKzKc_8AAAAJ&hl=zh-CN), [Boyuan Jiang](https://byjiang.com/), Donghao Luo, Wenqing Chu, Xiaoming Huang, [Ying Tai](https://tyshiwo.github.io/), Chengjie Wang, [Jie Yang](http://www.pami.sjtu.edu.cn/jieyang)
5 |
6 | ## Highlights
7 | Almost all existing flow-based frame interpolation methods first estimate or model the intermediate optical flow, and then use flow-warped context features to synthesize the target frame. However, they ignore the mutual promotion between the intermediate optical flow and the intermediate context feature. Moreover, their cascaded architectures substantially increase inference latency and model parameters, which keeps them out of many mobile and real-time applications. For the first time, we merge the above separate flow estimation and context feature refinement into a single encoder-decoder based IFRNet for compactness and fast inference, where these two crucial elements can benefit from each other. Furthermore, a task-oriented flow distillation loss and a feature space geometry consistency loss are newly proposed to promote the intermediate motion estimation and intermediate feature reconstruction of IFRNet, respectively. Benchmark results demonstrate that IFRNet not only achieves state-of-the-art VFI accuracy, but also enjoys fast inference speed and a lightweight model size.
8 |
9 |   
10 |
11 | ## YouTube Demos
12 | [[4K60p] Utawarerumono: The False Faces OP, frame interpolation + super-resolution (IFRNet and Real-CUGAN)](https://www.youtube.com/watch?v=tV2imgGS-5Q)
13 |
14 | [[4K60p] Tenshin Ranman -LUCKY or UNLUCKY!?- OP (IFRNet and Real-CUGAN)](https://www.youtube.com/watch?v=NtpJqDZaM-4)
15 |
16 | [RIFE vs. IFRNet comparison](https://www.youtube.com/watch?v=lHqnOQgpZHQ)
17 |
18 | [IFRNet frame interpolation](https://www.youtube.com/watch?v=ygSdCCZCsZU)
19 |
20 | ## Preparation
21 | 1. PyTorch >= 1.3.0 (We have verified that this repository supports Python 3.6/3.7, PyTorch 1.3.0/1.9.1).
22 | 2. Download training and test datasets: [Vimeo90K](http://toflow.csail.mit.edu/), [UCF101](https://liuziwei7.github.io/projects/VoxelFlow), [SNU-FILM](https://myungsub.github.io/CAIN/), [Middlebury](https://vision.middlebury.edu/flow/data/), [GoPro](https://seungjunnah.github.io/Datasets/gopro.html) and [Adobe240](http://www.cs.ubc.ca/labs/imager/tr/2017/DeepVideoDeblurring/).
23 | 3. Set the correct dataset paths on your machine (see the example after this list).
24 |
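The dataset locations are hard-coded near the top of the benchmark scripts and in `datasets.py`. A minimal sketch of the lines to edit (the `/home/ltkong/...` defaults below are the ones shipped in this repository; replace them with your own absolute paths):

```python
# benchmarks/Vimeo90K.py
path = '/home/ltkong/Datasets/Vimeo90K/vimeo_triplet/'

# benchmarks/UCF101.py
path = '/home/ltkong/Datasets/UCF101/ucf101_interp_ours/'

# benchmarks/SNU_FILM.py
path = '/home/ltkong/Datasets/SNU-FILM/'

# datasets.py (used by train_vimeo90k.py and train_gopro.py)
Vimeo90K_Train_Dataset(dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet')
GoPro_Train_Dataset(dataset_dir='/home/ltkong/Datasets/GOPRO')
```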
25 | ## Download Pre-trained Models and Play with Demos
26 | The figures below show, from left to right, the overlaid input frames, the 2x interpolation result, and the 8x interpolation result.
27 |
28 |   
29 |
30 |
31 |
32 |
33 | 1. Download our pre-trained models from this [link](https://www.dropbox.com/sh/hrewbpedd2cgdp3/AADbEivu0-CKDQcHtKdMNJPJa?dl=0), and then put the checkpoints folder into the root directory (expected layout shown after these steps).
34 |
35 | 2. Run the following scripts to generate the 2x and 8x frame interpolation demos:
36 | $ python demo_2x.py
37 | $ python demo_8x.py
38 |
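The demo and benchmark scripts load weights from fixed relative paths, so after unpacking the download the root directory should contain the layout below (checkpoint file names taken from the scripts in this repository; the large/small GoPro weights, if downloaded, follow the same pattern):

```
checkpoints/
├── IFRNet/
│   ├── IFRNet_Vimeo90K.pth
│   └── IFRNet_GoPro.pth
├── IFRNet_large/
│   └── IFRNet_L_Vimeo90K.pth
└── IFRNet_small/
    └── IFRNet_S_Vimeo90K.pth
```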
39 |
40 | ## Training on Vimeo90K Triplet Dataset for 2x Frame Interpolation
41 | 1. First, run this script to generate the optical flow pseudo labels:
42 | $ python generate_flow.py
43 |
44 | 2. Then, start training by executing one of the following commands with the selected model:
45 | $ python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name 'IFRNet' --epochs 300 --batch_size 6 --lr_start 1e-4 --lr_end 1e-5
46 | $ python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name 'IFRNet_L' --epochs 300 --batch_size 6 --lr_start 1e-4 --lr_end 1e-5
47 | $ python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name 'IFRNet_S' --epochs 300 --batch_size 6 --lr_start 1e-4 --lr_end 1e-5
48 |
49 | ## Benchmarks for 2x Frame Interpolation
50 | To test running time and model parameters, you can run
51 | $ python benchmarks/speed_parameters.py
52 |
53 | To test frame interpolation accuracy on Vimeo90K, UCF101 and SNU-FILM datasets, you can run
54 | $ python benchmarks/Vimeo90K.py
55 | $ python benchmarks/UCF101.py
56 | $ python benchmarks/SNU_FILM.py
57 |
58 | ## Quantitative Comparison for 2x Frame Interpolation
59 | The proposed IFRNet achieves state-of-the-art frame interpolation accuracy with lower inference time and computational complexity. We expect the proposed single encoder-decoder, joint refinement based IFRNet to be a useful component of many frame rate up-conversion, video compression and intermediate view synthesis systems. Time and FLOPs are measured at 1280 x 720 resolution (a timing sketch is given below).
60 |
61 | 
62 |
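The reported inference time can be reproduced with the same timing pattern as `benchmarks/speed_parameters.py`, switched to 1280 x 720 inputs. A minimal sketch (assumes a CUDA device; `Model` is the IFRNet variant under test; FLOPs counting requires an external profiler and is not shown):

```python
import time
import torch
from models.IFRNet import Model

model = Model().cuda().eval()
img0 = torch.randn(1, 3, 720, 1280).cuda()  # one 1280 x 720 frame, laid out as (N, C, H, W)
img1 = torch.randn(1, 3, 720, 1280).cuda()
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).cuda()  # interpolate the midpoint frame

with torch.no_grad():
    for _ in range(100):  # warm-up, so one-time initialization is not timed
        model.inference(img0, img1, embt)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        model.inference(img0, img1, embt)
    torch.cuda.synchronize()
    print('Time: {:.3f}s'.format((time.time() - start) / 100))
```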
63 |
64 | ## Qualitative Comparison for 2x Frame Interpolation
65 | Video comparison for 2x interpolation of methods using 2 input frames on SNU-FILM dataset.
66 |
67 | 
68 |
69 | 
70 |
71 |
72 | ## Middlebury Benchmark
73 | Results on the [Middlebury](https://vision.middlebury.edu/flow/eval/results/results-i1.php) online benchmark.
74 |
75 | 
76 |
77 | Results on the Middlebury Other dataset.
78 |
79 | 
80 |
81 |
82 | ## Training on GoPro Dataset for 8x Frame Interpolation
83 | 1. Start training by executing one of the following commands with the selected model:
84 | $ python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name 'IFRNet' --epochs 600 --batch_size 2 --lr_start 1e-4 --lr_end 1e-5
85 | $ python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name 'IFRNet_L' --epochs 600 --batch_size 2 --lr_start 1e-4 --lr_end 1e-5
86 | $ python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name 'IFRNet_S' --epochs 600 --batch_size 2 --lr_start 1e-4 --lr_end 1e-5
87 |
88 | Since inter-frame motion in the 8x interpolation setting is relatively small, the task-oriented flow distillation loss is omitted here. Because the GoPro training set is relatively small, we suggest training on your own data when targeting slow-motion generation, for better results.
89 |
90 | ## Quantitative Comparison for 8x Frame Interpolation
91 |
92 | 
93 |
94 | ## Qualitative Results on GoPro and Adobe240 Datasets for 8x Frame Interpolation
95 | Each video has 9 frames, where the first and the last frames are input, and the middle 7 frames are predicted by IFRNet.
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 | ## ncnn Implementation of IFRNet
104 |
105 | [ifrnet-ncnn-vulkan](https://github.com/nihui/ifrnet-ncnn-vulkan) uses [ncnn project](https://github.com/Tencent/ncnn) as the universal neural network inference framework. This package includes all the binaries and models required. It is portable, so no CUDA or PyTorch runtime environment is needed.
106 |
107 | ## Citation
108 | When using any parts of the Software or the Paper in your work, please cite the following paper:
109 | @InProceedings{Kong_2022_CVPR,
110 |     author    = {Kong, Lingtong and Jiang, Boyuan and Luo, Donghao and Chu, Wenqing and Huang, Xiaoming and Tai, Ying and Wang, Chengjie and Yang, Jie},
111 |     title     = {IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation},
112 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
113 |     year      = {2022}
114 | }
115 |
--------------------------------------------------------------------------------
/benchmarks/Middlebury_Other.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('.')
4 | import torch
5 | import numpy as np
6 | from utils import read
7 | from metric import calculate_psnr, calculate_ssim, calculate_ie
8 | from models.IFRNet import Model
9 | # from models.IFRNet_L import Model
10 | # from models.IFRNet_S import Model
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 |
15 | model = Model()
16 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth', map_location=device))
17 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth', map_location=device))
18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth', map_location=device))
19 | model.eval()
20 | model.to(device)
21 |
22 | # Replace the 'path' with your Middlebury dataset absolute path.
23 | path = '/home/ltkong/Datasets/Middlebury/'
24 | sequence = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
25 |
26 | psnr_list = []
27 | ssim_list = []
28 | ie_list = []
29 | for i in sequence:
30 |     I0 = read(path + 'other-data/{}/frame10.png'.format(i))
31 |     I1 = read(path + 'other-gt-interp/{}/frame10i11.png'.format(i))
32 |     I2 = read(path + 'other-data/{}/frame11.png'.format(i))
33 |     I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
34 |     I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
35 |     I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
36 |     embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
37 | 
38 |     I0_pad = torch.zeros([1, 3, 480, 640]).to(device)  # zero-pad inputs to a fixed 480 x 640 canvas
39 |     I2_pad = torch.zeros([1, 3, 480, 640]).to(device)
40 |     h, w = I0.shape[-2:]
41 |     I0_pad[:, :, :h, :w] = I0
42 |     I2_pad[:, :, :h, :w] = I2
43 | 
44 |     I1_pred_pad = model.inference(I0_pad, I2_pad, embt)
45 |     I1_pred = I1_pred_pad[:, :, :h, :w]  # crop the prediction back to the original size
46 | 
47 |     psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
48 |     ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()
49 |     ie = calculate_ie(I1_pred, I1).detach().cpu().numpy()
50 | 
51 |     psnr_list.append(psnr)
52 |     ssim_list.append(ssim)
53 |     ie_list.append(ie)
54 |
55 | print('Avg PSNR: {} SSIM: {} IE: {}'.format(np.mean(psnr_list), np.mean(ssim_list), np.mean(ie_list)))
56 |
--------------------------------------------------------------------------------
/benchmarks/SNU_FILM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('.')
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from utils import read
8 | from metric import calculate_psnr, calculate_ssim
9 | from models.IFRNet import Model
10 | # from models.IFRNet_L import Model
11 | # from models.IFRNet_S import Model
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 |
16 | model = Model()
17 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth', map_location=device))
18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth', map_location=device))
19 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth', map_location=device))
20 | model.eval()
21 | model.to(device)
22 |
23 | divisor = 20
24 | scale_factor = 0.8
25 |
26 | class InputPadder:
27 |     """ Pads images such that dimensions are divisible by divisor """
28 |     def __init__(self, dims, divisor=divisor):
29 |         self.ht, self.wd = dims[-2:]
30 |         pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
31 |         pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
32 |         self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
33 | 
34 |     def pad(self, *inputs):
35 |         return [F.pad(x, self._pad, mode='replicate') for x in inputs]
36 | 
37 |     def unpad(self, x):
38 |         ht, wd = x.shape[-2:]
39 |         c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
40 |         return x[..., c[0]:c[1], c[2]:c[3]]
41 |
42 | # Replace the 'path' with your SNU-FILM dataset absolute path.
43 | path = '/home/ltkong/Datasets/SNU-FILM/'
44 |
45 | psnr_list = []
46 | ssim_list = []
47 | file_list = []
48 | test_file = "test-hard.txt" # test-easy.txt, test-medium.txt, test-hard.txt, test-extreme.txt
49 | with open(os.path.join(path, test_file), "r") as f:
50 |     for line in f:
51 |         line = line.strip()
52 |         file_list.append(line.split(' '))
53 |
54 | for line in file_list:
55 |     print(os.path.join(path, line[0]))
56 |     I0_path = os.path.join(path, line[0])
57 |     I1_path = os.path.join(path, line[1])
58 |     I2_path = os.path.join(path, line[2])
59 |     I0 = read(I0_path)
60 |     I1 = read(I1_path)
61 |     I2 = read(I2_path)
62 |     I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
63 |     I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
64 |     I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
65 |     padder = InputPadder(I0.shape)
66 |     I0, I2 = padder.pad(I0, I2)
67 |     embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
68 | 
69 |     I1_pred = model.inference(I0, I2, embt, scale_factor=scale_factor)
70 |     I1_pred = padder.unpad(I1_pred)
71 | 
72 |     psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
73 |     ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()
74 | 
75 |     psnr_list.append(psnr)
76 |     ssim_list.append(ssim)
77 |
78 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list)))
79 |
--------------------------------------------------------------------------------
/benchmarks/UCF101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('.')
4 | import torch
5 | import numpy as np
6 | from utils import read
7 | from metric import calculate_psnr, calculate_ssim
8 | from models.IFRNet import Model
9 | # from models.IFRNet_L import Model
10 | # from models.IFRNet_S import Model
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 |
15 | model = Model()
16 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth', map_location=device))
17 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth', map_location=device))
18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth', map_location=device))
19 | model.eval()
20 | model.to(device)
21 |
22 | # Replace the 'path' with your UCF101 dataset absolute path.
23 | path = '/home/ltkong/Datasets/UCF101/ucf101_interp_ours/'
24 | dirs = sorted(os.listdir(path))
25 |
26 | psnr_list = []
27 | ssim_list = []
28 | for d in dirs:
29 |     print(path + d + '/frame_00.png')
30 |     I0 = read(path + d + '/frame_00.png')
31 |     I1 = read(path + d + '/frame_01_gt.png')
32 |     I2 = read(path + d + '/frame_02.png')
33 |     I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
34 |     I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
35 |     I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
36 |     embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
37 | 
38 |     I1_pred = model.inference(I0, I2, embt)
39 | 
40 |     psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
41 |     ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()
42 | 
43 |     psnr_list.append(psnr)
44 |     ssim_list.append(ssim)
45 |
46 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list)))
47 |
--------------------------------------------------------------------------------
/benchmarks/Vimeo90K.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('.')
4 | import torch
5 | import numpy as np
6 | from utils import read
7 | from metric import calculate_psnr, calculate_ssim
8 | from models.IFRNet import Model
9 | # from models.IFRNet_L import Model
10 | # from models.IFRNet_S import Model
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 |
15 | model = Model()
16 | model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth', map_location=device))
17 | # model.load_state_dict(torch.load('checkpoints/IFRNet_large/IFRNet_L_Vimeo90K.pth', map_location=device))
18 | # model.load_state_dict(torch.load('checkpoints/IFRNet_small/IFRNet_S_Vimeo90K.pth', map_location=device))
19 | model.eval()
20 | model.to(device)
21 |
22 | # Replace the 'path' with your Vimeo90K dataset absolute path.
23 | path = '/home/ltkong/Datasets/Vimeo90K/vimeo_triplet/'
24 | f = open(path + 'tri_testlist.txt', 'r')
25 |
26 | psnr_list = []
27 | ssim_list = []
28 | for i in f:
29 |     name = str(i).strip()
30 |     if len(name) <= 1:
31 |         continue
32 |     print(path + 'sequences/' + name + '/im1.png')
33 |     I0 = read(path + 'sequences/' + name + '/im1.png')
34 |     I1 = read(path + 'sequences/' + name + '/im2.png')
35 |     I2 = read(path + 'sequences/' + name + '/im3.png')
36 |     I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
37 |     I1 = (torch.tensor(I1.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
38 |     I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
39 |     embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
40 | 
41 |     I1_pred = model.inference(I0, I2, embt)
42 | 
43 |     psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
44 |     ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()
45 | 
46 |     psnr_list.append(psnr)
47 |     ssim_list.append(ssim)
48 |
49 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list)))
50 |
--------------------------------------------------------------------------------
/benchmarks/speed_parameters.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('.')
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from models.IFRNet import Model
9 | # from models.IFRNet_L import Model
10 | # from models.IFRNet_S import Model
11 |
12 | if torch.cuda.is_available():
13 |     torch.backends.cudnn.enabled = True
14 |     torch.backends.cudnn.benchmark = True
15 |
16 | img0 = torch.randn(1, 3, 256, 448).cuda()
17 | img1 = torch.randn(1, 3, 256, 448).cuda()
18 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).cuda()
19 |
20 | model = Model().cuda().eval()
21 |
22 | with torch.no_grad():
23 |     for i in range(100):  # warm-up iterations, so one-time initialization is not timed
24 |         out = model.inference(img0, img1, embt)
25 |     if torch.cuda.is_available():
26 |         torch.cuda.synchronize()
27 |     time_stamp = time.time()
28 |     for i in range(100):
29 |         out = model.inference(img0, img1, embt)
30 |     if torch.cuda.is_available():
31 |         torch.cuda.synchronize()
32 |     print('Time: {:.3f}s'.format((time.time() - time_stamp) / 100))
33 |
34 | total = sum([param.nelement() for param in model.parameters()])
35 | print('Parameters: {:.2f}M'.format(total / 1e6))
36 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import cv2
5 | import torch
6 | from torch.utils.data import Dataset
7 | from utils import read
8 |
9 |
10 | def random_resize(img0, imgt, img1, flow, p=0.1):
11 |     if random.uniform(0, 1) < p:
12 |         img0 = cv2.resize(img0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
13 |         imgt = cv2.resize(imgt, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
14 |         img1 = cv2.resize(img1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
15 |         flow = cv2.resize(flow, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) * 2.0  # flow vectors scale with the image
16 |     return img0, imgt, img1, flow
17 | 
18 | 
19 | def random_crop(img0, imgt, img1, flow, crop_size=(224, 224)):
20 |     h, w = crop_size[0], crop_size[1]
21 |     ih, iw, _ = img0.shape
22 |     x = np.random.randint(0, ih-h+1)
23 |     y = np.random.randint(0, iw-w+1)
24 |     img0 = img0[x:x+h, y:y+w, :]
25 |     imgt = imgt[x:x+h, y:y+w, :]
26 |     img1 = img1[x:x+h, y:y+w, :]
27 |     flow = flow[x:x+h, y:y+w, :]
28 |     return img0, imgt, img1, flow
29 | 
30 | 
31 | def random_reverse_channel(img0, imgt, img1, flow, p=0.5):
32 |     if random.uniform(0, 1) < p:
33 |         img0 = img0[:, :, ::-1]
34 |         imgt = imgt[:, :, ::-1]
35 |         img1 = img1[:, :, ::-1]
36 |     return img0, imgt, img1, flow
37 | 
38 | 
39 | def random_vertical_flip(img0, imgt, img1, flow, p=0.3):
40 |     if random.uniform(0, 1) < p:
41 |         img0 = img0[::-1]
42 |         imgt = imgt[::-1]
43 |         img1 = img1[::-1]
44 |         flow = flow[::-1]
45 |         flow = np.concatenate((flow[:, :, 0:1], -flow[:, :, 1:2], flow[:, :, 2:3], -flow[:, :, 3:4]), 2)  # negate the v components of flow_t0 and flow_t1
46 |     return img0, imgt, img1, flow
47 | 
48 | 
49 | def random_horizontal_flip(img0, imgt, img1, flow, p=0.5):
50 |     if random.uniform(0, 1) < p:
51 |         img0 = img0[:, ::-1]
52 |         imgt = imgt[:, ::-1]
53 |         img1 = img1[:, ::-1]
54 |         flow = flow[:, ::-1]
55 |         flow = np.concatenate((-flow[:, :, 0:1], flow[:, :, 1:2], -flow[:, :, 2:3], flow[:, :, 3:4]), 2)  # negate the u components of flow_t0 and flow_t1
56 |     return img0, imgt, img1, flow
57 | 
58 | 
59 | def random_rotate(img0, imgt, img1, flow, p=0.05):
60 |     if random.uniform(0, 1) < p:
61 |         img0 = img0.transpose((1, 0, 2))
62 |         imgt = imgt.transpose((1, 0, 2))
63 |         img1 = img1.transpose((1, 0, 2))
64 |         flow = flow.transpose((1, 0, 2))
65 |         flow = np.concatenate((flow[:, :, 1:2], flow[:, :, 0:1], flow[:, :, 3:4], flow[:, :, 2:3]), 2)  # transposing swaps the u and v components
66 |     return img0, imgt, img1, flow
67 | 
68 | 
69 | def random_reverse_time(img0, imgt, img1, flow, p=0.5):
70 |     if random.uniform(0, 1) < p:
71 |         tmp = img1
72 |         img1 = img0
73 |         img0 = tmp
74 |         flow = np.concatenate((flow[:, :, 2:4], flow[:, :, 0:2]), 2)  # swap flow_t0 and flow_t1
75 |     return img0, imgt, img1, flow
76 |
77 |
78 | class Vimeo90K_Train_Dataset(Dataset):
79 |     def __init__(self, dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet', augment=True):
80 |         self.dataset_dir = dataset_dir
81 |         self.augment = augment
82 |         self.img0_list = []
83 |         self.imgt_list = []
84 |         self.img1_list = []
85 |         self.flow_t0_list = []
86 |         self.flow_t1_list = []
87 |         with open(os.path.join(dataset_dir, 'tri_trainlist.txt'), 'r') as f:
88 |             for i in f:
89 |                 name = str(i).strip()
90 |                 if len(name) <= 1:
91 |                     continue
92 |                 self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png'))
93 |                 self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png'))
94 |                 self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png'))
95 |                 self.flow_t0_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t0.flo'))
96 |                 self.flow_t1_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t1.flo'))
97 | 
98 |     def __len__(self):
99 |         return len(self.imgt_list)
100 | 
101 |     def __getitem__(self, idx):
102 |         img0 = read(self.img0_list[idx])
103 |         imgt = read(self.imgt_list[idx])
104 |         img1 = read(self.img1_list[idx])
105 |         flow_t0 = read(self.flow_t0_list[idx])
106 |         flow_t1 = read(self.flow_t1_list[idx])
107 |         flow = np.concatenate((flow_t0, flow_t1), 2).astype(np.float64)
108 | 
109 |         if self.augment:
110 |             img0, imgt, img1, flow = random_resize(img0, imgt, img1, flow, p=0.1)
111 |             img0, imgt, img1, flow = random_crop(img0, imgt, img1, flow, crop_size=(224, 224))
112 |             img0, imgt, img1, flow = random_reverse_channel(img0, imgt, img1, flow, p=0.5)
113 |             img0, imgt, img1, flow = random_vertical_flip(img0, imgt, img1, flow, p=0.3)
114 |             img0, imgt, img1, flow = random_horizontal_flip(img0, imgt, img1, flow, p=0.5)
115 |             img0, imgt, img1, flow = random_rotate(img0, imgt, img1, flow, p=0.05)
116 |             img0, imgt, img1, flow = random_reverse_time(img0, imgt, img1, flow, p=0.5)
117 | 
118 |         img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0)
119 |         imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0)
120 |         img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0)
121 |         flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32))
122 |         embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32))  # time embedding: interpolation instant t = 0.5
123 | 
124 |         return img0, imgt, img1, flow, embt
125 | 
126 | 
127 | class Vimeo90K_Test_Dataset(Dataset):
128 |     def __init__(self, dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet'):
129 |         self.dataset_dir = dataset_dir
130 |         self.img0_list = []
131 |         self.imgt_list = []
132 |         self.img1_list = []
133 |         self.flow_t0_list = []
134 |         self.flow_t1_list = []
135 |         with open(os.path.join(dataset_dir, 'tri_testlist.txt'), 'r') as f:
136 |             for i in f:
137 |                 name = str(i).strip()
138 |                 if len(name) <= 1:
139 |                     continue
140 |                 self.img0_list.append(os.path.join(dataset_dir, 'sequences', name, 'im1.png'))
141 |                 self.imgt_list.append(os.path.join(dataset_dir, 'sequences', name, 'im2.png'))
142 |                 self.img1_list.append(os.path.join(dataset_dir, 'sequences', name, 'im3.png'))
143 |                 self.flow_t0_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t0.flo'))
144 |                 self.flow_t1_list.append(os.path.join(dataset_dir, 'flow', name, 'flow_t1.flo'))
145 | 
146 |     def __len__(self):
147 |         return len(self.imgt_list)
148 | 
149 |     def __getitem__(self, idx):
150 |         img0 = read(self.img0_list[idx])
151 |         imgt = read(self.imgt_list[idx])
152 |         img1 = read(self.img1_list[idx])
153 |         flow_t0 = read(self.flow_t0_list[idx])
154 |         flow_t1 = read(self.flow_t1_list[idx])
155 |         flow = np.concatenate((flow_t0, flow_t1), 2)
156 | 
157 |         img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0)
158 |         imgt = torch.from_numpy(imgt.transpose((2, 0, 1)).astype(np.float32) / 255.0)
159 |         img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0)
160 |         flow = torch.from_numpy(flow.transpose((2, 0, 1)).astype(np.float32))
161 |         embt = torch.from_numpy(np.array(1/2).reshape(1, 1, 1).astype(np.float32))
162 | 
163 |         return img0, imgt, img1, flow, embt
164 |
165 |
166 |
167 | def random_resize_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.1):
168 |     if random.uniform(0, 1) < p:
169 |         img0 = cv2.resize(img0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
170 |         img1 = cv2.resize(img1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
171 |         img2 = cv2.resize(img2, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
172 |         img3 = cv2.resize(img3, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
173 |         img4 = cv2.resize(img4, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
174 |         img5 = cv2.resize(img5, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
175 |         img6 = cv2.resize(img6, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
176 |         img7 = cv2.resize(img7, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
177 |         img8 = cv2.resize(img8, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR)
178 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
179 | 
180 | 
181 | def random_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(224, 224)):
182 |     h, w = crop_size[0], crop_size[1]
183 |     ih, iw, _ = img0.shape
184 |     x = np.random.randint(0, ih-h+1)
185 |     y = np.random.randint(0, iw-w+1)
186 |     img0 = img0[x:x+h, y:y+w, :]
187 |     img1 = img1[x:x+h, y:y+w, :]
188 |     img2 = img2[x:x+h, y:y+w, :]
189 |     img3 = img3[x:x+h, y:y+w, :]
190 |     img4 = img4[x:x+h, y:y+w, :]
191 |     img5 = img5[x:x+h, y:y+w, :]
192 |     img6 = img6[x:x+h, y:y+w, :]
193 |     img7 = img7[x:x+h, y:y+w, :]
194 |     img8 = img8[x:x+h, y:y+w, :]
195 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
196 | 
197 | 
198 | def center_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(512, 512)):
199 |     h, w = crop_size[0], crop_size[1]
200 |     ih, iw, _ = img0.shape
201 |     img0 = img0[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
202 |     img1 = img1[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
203 |     img2 = img2[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
204 |     img3 = img3[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
205 |     img4 = img4[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
206 |     img5 = img5[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
207 |     img6 = img6[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
208 |     img7 = img7[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
209 |     img8 = img8[(ih//2-h//2):(ih//2+h//2), (iw//2-w//2):(iw//2+w//2), :]
210 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
211 | 
212 | 
213 | def random_reverse_channel_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5):
214 |     if random.uniform(0, 1) < p:
215 |         img0 = img0[:, :, ::-1]
216 |         img1 = img1[:, :, ::-1]
217 |         img2 = img2[:, :, ::-1]
218 |         img3 = img3[:, :, ::-1]
219 |         img4 = img4[:, :, ::-1]
220 |         img5 = img5[:, :, ::-1]
221 |         img6 = img6[:, :, ::-1]
222 |         img7 = img7[:, :, ::-1]
223 |         img8 = img8[:, :, ::-1]
224 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
225 | 
226 | 
227 | def random_vertical_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.3):
228 |     if random.uniform(0, 1) < p:
229 |         img0 = img0[::-1]
230 |         img1 = img1[::-1]
231 |         img2 = img2[::-1]
232 |         img3 = img3[::-1]
233 |         img4 = img4[::-1]
234 |         img5 = img5[::-1]
235 |         img6 = img6[::-1]
236 |         img7 = img7[::-1]
237 |         img8 = img8[::-1]
238 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
239 | 
240 | 
241 | def random_horizontal_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5):
242 |     if random.uniform(0, 1) < p:
243 |         img0 = img0[:, ::-1]
244 |         img1 = img1[:, ::-1]
245 |         img2 = img2[:, ::-1]
246 |         img3 = img3[:, ::-1]
247 |         img4 = img4[:, ::-1]
248 |         img5 = img5[:, ::-1]
249 |         img6 = img6[:, ::-1]
250 |         img7 = img7[:, ::-1]
251 |         img8 = img8[:, ::-1]
252 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
253 | 
254 | 
255 | def random_rotate_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.05):
256 |     if random.uniform(0, 1) < p:
257 |         img0 = img0.transpose((1, 0, 2))
258 |         img1 = img1.transpose((1, 0, 2))
259 |         img2 = img2.transpose((1, 0, 2))
260 |         img3 = img3.transpose((1, 0, 2))
261 |         img4 = img4.transpose((1, 0, 2))
262 |         img5 = img5.transpose((1, 0, 2))
263 |         img6 = img6.transpose((1, 0, 2))
264 |         img7 = img7.transpose((1, 0, 2))
265 |         img8 = img8.transpose((1, 0, 2))
266 |     return img0, img1, img2, img3, img4, img5, img6, img7, img8
267 | 
268 | 
269 | def random_reverse_time_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5):
270 |     if random.uniform(0, 1) < p:
271 |         return img8, img7, img6, img5, img4, img3, img2, img1, img0
272 |     else:
273 |         return img0, img1, img2, img3, img4, img5, img6, img7, img8
274 |
275 |
276 | class GoPro_Train_Dataset(Dataset):
277 |     def __init__(self, dataset_dir='/home/ltkong/Datasets/GOPRO', interFrames=7, n_inputs=2, augment=True):
278 |         self.dataset_dir = dataset_dir
279 |         self.interFrames = interFrames
280 |         self.n_inputs = n_inputs
281 |         self.augment = augment
282 |         self.setLength = (n_inputs-1)*(interFrames+1)+1  # frames per sample: 2 inputs + 7 intermediate = 9
283 |         video_list = [
284 |             'GOPR0372_07_00', 'GOPR0374_11_01', 'GOPR0378_13_00', 'GOPR0384_11_01', 'GOPR0384_11_04', 'GOPR0477_11_00', 'GOPR0868_11_02', 'GOPR0884_11_00',
285 |             'GOPR0372_07_01', 'GOPR0374_11_02', 'GOPR0379_11_00', 'GOPR0384_11_02', 'GOPR0385_11_00', 'GOPR0857_11_00', 'GOPR0871_11_01', 'GOPR0374_11_00',
286 |             'GOPR0374_11_03', 'GOPR0380_11_00', 'GOPR0384_11_03', 'GOPR0386_11_00', 'GOPR0868_11_01', 'GOPR0881_11_00']
287 |         self.frames_list = []
288 |         self.file_list = []
289 |         for video in video_list:
290 |             frames = sorted(os.listdir(os.path.join(self.dataset_dir, video)))
291 |             n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1
292 |             videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength] for i in range(n_sets)]
293 |             videoInputs = [[os.path.join(video, f) for f in group] for group in videoInputs]
294 |             self.file_list.extend(videoInputs)
295 | 
296 |     def __len__(self):
297 |         return len(self.file_list)
298 | 
299 |     def __getitem__(self, idx):
300 |         imgpaths = [os.path.join(self.dataset_dir, fp) for fp in self.file_list[idx]]
301 |         pick_idxs = list(range(0, self.setLength, self.interFrames+1))  # indices of the two input frames
302 |         rem = self.interFrames%2
303 |         gt_idx = list(range(self.setLength//2-self.interFrames//2, self.setLength//2+self.interFrames//2+rem))  # indices of the 7 ground-truth frames
304 |         input_paths = [imgpaths[idx] for idx in pick_idxs]
305 |         gt_paths = [imgpaths[idx] for idx in gt_idx]
306 |         img0 = np.array(read(input_paths[0]))
307 |         img1 = np.array(read(gt_paths[0]))
308 |         img2 = np.array(read(gt_paths[1]))
309 |         img3 = np.array(read(gt_paths[2]))
310 |         img4 = np.array(read(gt_paths[3]))
311 |         img5 = np.array(read(gt_paths[4]))
312 |         img6 = np.array(read(gt_paths[5]))
313 |         img7 = np.array(read(gt_paths[6]))
314 |         img8 = np.array(read(input_paths[1]))
315 | 
316 |         if self.augment:
317 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_resize_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.1)
318 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(224, 224))
319 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_reverse_channel_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5)
320 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_vertical_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.3)
321 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_horizontal_flip_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5)
322 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_rotate_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.05)
323 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = random_reverse_time_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, p=0.5)
324 |         else:
325 |             img0, img1, img2, img3, img4, img5, img6, img7, img8 = center_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(512, 512))
326 | 
327 |         img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0)
328 |         img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0)
329 |         img2 = torch.from_numpy(img2.transpose((2, 0, 1)).astype(np.float32) / 255.0)
330 |         img3 = torch.from_numpy(img3.transpose((2, 0, 1)).astype(np.float32) / 255.0)
331 |         img4 = torch.from_numpy(img4.transpose((2, 0, 1)).astype(np.float32) / 255.0)
332 |         img5 = torch.from_numpy(img5.transpose((2, 0, 1)).astype(np.float32) / 255.0)
333 |         img6 = torch.from_numpy(img6.transpose((2, 0, 1)).astype(np.float32) / 255.0)
334 |         img7 = torch.from_numpy(img7.transpose((2, 0, 1)).astype(np.float32) / 255.0)
335 |         img8 = torch.from_numpy(img8.transpose((2, 0, 1)).astype(np.float32) / 255.0)
336 | 
337 |         emb1 = torch.from_numpy(np.array(1/8).reshape(1, 1, 1).astype(np.float32))  # time embeddings for the 7 instants t = 1/8 ... 7/8
338 |         emb2 = torch.from_numpy(np.array(2/8).reshape(1, 1, 1).astype(np.float32))
339 |         emb3 = torch.from_numpy(np.array(3/8).reshape(1, 1, 1).astype(np.float32))
340 |         emb4 = torch.from_numpy(np.array(4/8).reshape(1, 1, 1).astype(np.float32))
341 |         emb5 = torch.from_numpy(np.array(5/8).reshape(1, 1, 1).astype(np.float32))
342 |         emb6 = torch.from_numpy(np.array(6/8).reshape(1, 1, 1).astype(np.float32))
343 |         emb7 = torch.from_numpy(np.array(7/8).reshape(1, 1, 1).astype(np.float32))
344 | 
345 |         return img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7
346 |
347 |
348 | class GoPro_Test_Dataset(Dataset):
349 |     def __init__(self, dataset_dir='/home/ltkong/Datasets/GOPRO', interFrames=7, n_inputs=2):
350 |         self.dataset_dir = dataset_dir
351 |         self.interFrames = interFrames
352 |         self.n_inputs = n_inputs
353 |         self.setLength = (n_inputs-1)*(interFrames+1)+1
354 |         video_list = [
355 |             'GOPR0384_11_00', 'GOPR0385_11_01', 'GOPR0410_11_00', 'GOPR0862_11_00', 'GOPR0869_11_00', 'GOPR0881_11_01', 'GOPR0384_11_05', 'GOPR0396_11_00',
356 |             'GOPR0854_11_00', 'GOPR0868_11_00', 'GOPR0871_11_00']
357 |         self.frames_list = []
358 |         self.file_list = []
359 |         for video in video_list:
360 |             frames = sorted(os.listdir(os.path.join(self.dataset_dir, video)))
361 |             n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1
362 |             videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength] for i in range(n_sets)]
363 |             videoInputs = [[os.path.join(video, f) for f in group] for group in videoInputs]
364 |             self.file_list.extend(videoInputs)
365 | 
366 |     def __len__(self):
367 |         return len(self.file_list)
368 | 
369 |     def __getitem__(self, idx):
370 |         imgpaths = [os.path.join(self.dataset_dir, fp) for fp in self.file_list[idx]]
371 |         pick_idxs = list(range(0, self.setLength, self.interFrames+1))
372 |         rem = self.interFrames%2
373 |         gt_idx = list(range(self.setLength//2-self.interFrames//2, self.setLength//2+self.interFrames//2+rem))
374 |         input_paths = [imgpaths[idx] for idx in pick_idxs]
375 |         gt_paths = [imgpaths[idx] for idx in gt_idx]
376 |         img0 = np.array(read(input_paths[0]))
377 |         img1 = np.array(read(gt_paths[0]))
378 |         img2 = np.array(read(gt_paths[1]))
379 |         img3 = np.array(read(gt_paths[2]))
380 |         img4 = np.array(read(gt_paths[3]))
381 |         img5 = np.array(read(gt_paths[4]))
382 |         img6 = np.array(read(gt_paths[5]))
383 |         img7 = np.array(read(gt_paths[6]))
384 |         img8 = np.array(read(input_paths[1]))
385 | 
386 |         img0, img1, img2, img3, img4, img5, img6, img7, img8 = center_crop_8x(img0, img1, img2, img3, img4, img5, img6, img7, img8, crop_size=(512, 512))
387 | 
388 |         img0 = torch.from_numpy(img0.transpose((2, 0, 1)).astype(np.float32) / 255.0)
389 |         img1 = torch.from_numpy(img1.transpose((2, 0, 1)).astype(np.float32) / 255.0)
390 |         img2 = torch.from_numpy(img2.transpose((2, 0, 1)).astype(np.float32) / 255.0)
391 |         img3 = torch.from_numpy(img3.transpose((2, 0, 1)).astype(np.float32) / 255.0)
392 |         img4 = torch.from_numpy(img4.transpose((2, 0, 1)).astype(np.float32) / 255.0)
393 |         img5 = torch.from_numpy(img5.transpose((2, 0, 1)).astype(np.float32) / 255.0)
394 |         img6 = torch.from_numpy(img6.transpose((2, 0, 1)).astype(np.float32) / 255.0)
395 |         img7 = torch.from_numpy(img7.transpose((2, 0, 1)).astype(np.float32) / 255.0)
396 |         img8 = torch.from_numpy(img8.transpose((2, 0, 1)).astype(np.float32) / 255.0)
397 | 
398 |         emb1 = torch.from_numpy(np.array(1/8).reshape(1, 1, 1).astype(np.float32))
399 |         emb2 = torch.from_numpy(np.array(2/8).reshape(1, 1, 1).astype(np.float32))
400 |         emb3 = torch.from_numpy(np.array(3/8).reshape(1, 1, 1).astype(np.float32))
401 |         emb4 = torch.from_numpy(np.array(4/8).reshape(1, 1, 1).astype(np.float32))
402 |         emb5 = torch.from_numpy(np.array(5/8).reshape(1, 1, 1).astype(np.float32))
403 |         emb6 = torch.from_numpy(np.array(6/8).reshape(1, 1, 1).astype(np.float32))
404 |         emb7 = torch.from_numpy(np.array(7/8).reshape(1, 1, 1).astype(np.float32))
405 | 
406 |         return img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7
407 |
--------------------------------------------------------------------------------
/demo_2x.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from models.IFRNet import Model
5 | from utils import read
6 | from imageio import mimsave
7 |
8 |
9 | model = Model().cuda().eval()
10 | model.load_state_dict(torch.load('./checkpoints/IFRNet/IFRNet_Vimeo90K.pth'))
11 |
12 | img0_np = read('./figures/img0.png')
13 | img1_np = read('./figures/img1.png')
14 |
15 | img0 = (torch.tensor(img0_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda()
16 | img1 = (torch.tensor(img1_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda()
17 | embt = torch.tensor(1/2).view(1, 1, 1, 1).float().cuda()  # time embedding: interpolation instant t = 0.5
18 |
19 | imgt_pred = model.inference(img0, img1, embt)
20 |
21 | imgt_pred_np = (imgt_pred[0].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
22 |
23 | images = [img0_np, imgt_pred_np, img1_np]
24 | mimsave('./figures/out_2x.gif', images, fps=3)
25 |
--------------------------------------------------------------------------------
/demo_8x.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from models.IFRNet import Model
5 | from utils import read
6 | from imageio import mimsave
7 |
8 |
9 | model = Model().cuda().eval()
10 | model.load_state_dict(torch.load('./checkpoints/IFRNet/IFRNet_GoPro.pth'))
11 |
12 | img0_np = read('./figures/img0.png')
13 | img8_np = read('./figures/img1.png')
14 |
15 | img0 = (torch.tensor(img0_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda()
16 | img8 = (torch.tensor(img8_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda()
17 |
18 | emb1 = torch.tensor(1/8).view(1, 1, 1, 1).float().cuda()
19 | emb2 = torch.tensor(2/8).view(1, 1, 1, 1).float().cuda()
20 | emb3 = torch.tensor(3/8).view(1, 1, 1, 1).float().cuda()
21 | emb4 = torch.tensor(4/8).view(1, 1, 1, 1).float().cuda()
22 | emb5 = torch.tensor(5/8).view(1, 1, 1, 1).float().cuda()
23 | emb6 = torch.tensor(6/8).view(1, 1, 1, 1).float().cuda()
24 | emb7 = torch.tensor(7/8).view(1, 1, 1, 1).float().cuda()
25 |
26 | img0_ = torch.cat([img0, img0, img0, img0, img0, img0, img0], 0)  # repeat the input pair 7 times, once per target instant
27 | img8_ = torch.cat([img8, img8, img8, img8, img8, img8, img8], 0)
28 | embt = torch.cat([emb1, emb2, emb3, emb4, emb5, emb6, emb7], 0)  # t = 1/8 ... 7/8 predicted in a single batch
29 |
30 | imgt_pred = model.inference(img0_, img8_, embt)
31 |
32 | img1_pred_np = (imgt_pred[0].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
33 | img2_pred_np = (imgt_pred[1].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
34 | img3_pred_np = (imgt_pred[2].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
35 | img4_pred_np = (imgt_pred[3].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
36 | img5_pred_np = (imgt_pred[4].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
37 | img6_pred_np = (imgt_pred[5].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
38 | img7_pred_np = (imgt_pred[6].data.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
39 |
40 | images = [img0_np, img1_pred_np, img2_pred_np, img3_pred_np, img4_pred_np, img5_pred_np, img6_pred_np, img7_pred_np, img8_np]
41 | mimsave('./figures/out_8x.gif', images, fps=9)
42 |
--------------------------------------------------------------------------------
/figures/8x_interpolation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/8x_interpolation.png
--------------------------------------------------------------------------------
/figures/benchmarks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/benchmarks.png
--------------------------------------------------------------------------------
/figures/fig1_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig1_1.gif
--------------------------------------------------------------------------------
/figures/fig1_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig1_2.gif
--------------------------------------------------------------------------------
/figures/fig1_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig1_3.gif
--------------------------------------------------------------------------------
/figures/fig2_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig2_1.gif
--------------------------------------------------------------------------------
/figures/fig2_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/fig2_2.gif
--------------------------------------------------------------------------------
/figures/img0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/img0.png
--------------------------------------------------------------------------------
/figures/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/img1.png
--------------------------------------------------------------------------------
/figures/img_overlaid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/img_overlaid.png
--------------------------------------------------------------------------------
/figures/middlebury.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/middlebury.png
--------------------------------------------------------------------------------
/figures/middlebury_other.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/middlebury_other.png
--------------------------------------------------------------------------------
/figures/out_2x.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/out_2x.gif
--------------------------------------------------------------------------------
/figures/out_8x.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/out_8x.gif
--------------------------------------------------------------------------------
/figures/vimeo90k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ltkong218/IFRNet/b117bcafcf074b2de756b882f8a6ca02c3169bfe/figures/vimeo90k.png
--------------------------------------------------------------------------------
/generate_flow.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | from liteflownet.run import estimate
6 | from utils import read, write
7 |
8 |
9 | # set vimeo90k_dir with your Vimeo90K triplet dataset path, like '/.../vimeo_triplet'
10 | vimeo90k_dir = '/home/ltkong/Datasets/Vimeo90K/vimeo_triplet'
11 |
12 | vimeo90k_sequences_dir = os.path.join(vimeo90k_dir, 'sequences')
13 | vimeo90k_flow_dir = os.path.join(vimeo90k_dir, 'flow')
14 |
15 | if not os.path.exists(vimeo90k_flow_dir):
16 |     os.makedirs(vimeo90k_flow_dir)
17 | 
18 | for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)):
19 |     vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path)
20 |     vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path)
21 |     if not os.path.exists(vimeo90k_flow_path_dir):
22 |         os.mkdir(vimeo90k_flow_path_dir)
23 | 
24 |     for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)):
25 |         vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id)
26 |         if not os.path.exists(vimeo90k_flow_id_dir):
27 |             os.mkdir(vimeo90k_flow_id_dir)
28 | 
29 | print('Built Flow Path')
30 |
31 |
32 | def pred_flow(img1, img2):
33 |     img1 = torch.from_numpy(img1).float().permute(2, 0, 1) / 255.0
34 |     img2 = torch.from_numpy(img2).float().permute(2, 0, 1) / 255.0
35 | 
36 |     flow = estimate(img1, img2)
37 | 
38 |     flow = flow.permute(1, 2, 0).cpu().numpy()
39 |     return flow
40 |
41 | for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)):
42 |     vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path)
43 |     vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path)
44 | 
45 |     for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)):
46 |         vimeo90k_sequences_id_dir = os.path.join(vimeo90k_sequences_path_dir, sequences_id)
47 |         vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id)
48 | 
49 |         img0_path = vimeo90k_sequences_id_dir + '/im1.png'
50 |         imgt_path = vimeo90k_sequences_id_dir + '/im2.png'
51 |         img1_path = vimeo90k_sequences_id_dir + '/im3.png'
52 |         flow_t0_path = vimeo90k_flow_id_dir + '/flow_t0.flo'
53 |         flow_t1_path = vimeo90k_flow_id_dir + '/flow_t1.flo'
54 | 
55 |         img0 = read(img0_path)
56 |         imgt = read(imgt_path)
57 |         img1 = read(img1_path)
58 | 
59 |         flow_t0 = pred_flow(imgt, img0)  # pseudo label: flow from the intermediate frame to frame 0
60 |         flow_t1 = pred_flow(imgt, img1)  # pseudo label: flow from the intermediate frame to frame 1
61 | 
62 |         write(flow_t0_path, flow_t0)
63 |         write(flow_t1_path, flow_t1)
64 | 
65 |     print('Written Sequences {}'.format(sequences_path))
66 |
--------------------------------------------------------------------------------
/liteflownet/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-liteflownet
2 | This is a personal reimplementation of LiteFlowNet [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2].
3 |
4 |
5 |
6 | For the original Caffe version of this work, please see: https://github.com/twhui/LiteFlowNet
7 |
8 | Other optical flow implementations from me: [pytorch-pwc](https://github.com/sniklaus/pytorch-pwc), [pytorch-unflow](https://github.com/sniklaus/pytorch-unflow), [pytorch-spynet](https://github.com/sniklaus/pytorch-spynet)
9 |
10 | ## setup
11 | The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository. If you would like to use Docker, you can take a look at [this](https://github.com/sniklaus/pytorch-liteflownet/pull/43) pull request to get started.
12 |
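For orientation, the sketch below shows the general CuPy mechanism that `correlation.py` builds on: CUDA source held in a Python string is compiled at runtime and launched on GPU arrays. This is a minimal stand-alone example, not the repository's correlation kernel (which additionally templates its source and wraps the launch in a `torch.autograd.Function`):

```python
import cupy as cp

# Compile a raw CUDA kernel from source at runtime.
add_one = cp.RawKernel(r'''
extern "C" __global__ void add_one(const float* x, float* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = x[i] + 1.0f;
    }
}
''', 'add_one')

x = cp.arange(8, dtype=cp.float32)
y = cp.empty_like(x)
add_one((1,), (8,), (x, y, cp.int32(8)))  # launch: grid dims, block dims, kernel arguments
print(y)  # [1. 2. 3. 4. 5. 6. 7. 8.]
```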
13 | ## usage
14 | To run it on your own pair of images, use the following command. You can choose between three models; please see their paper / the code for more details.
15 |
16 | ```
17 | python run.py --model default --one ./images/one.png --two ./images/two.png --out ./out.flo
18 | ```
19 |
20 | I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results pretty much identical to the implementation of the original authors in the examples that I tried. There are some numerical deviations that stem from differences in the `DownsampleLayer` of Caffe and the `torch.nn.functional.interpolate` function of PyTorch. Please feel free to contribute to this repository by submitting issues and pull requests.
21 |
22 | ## comparison
23 | 
24 |
25 | ## license
26 | As stated in the licensing terms of the authors of the paper, their material is provided for research purposes only. Please make sure to further consult their licensing terms.
27 |
28 | ## references
29 | ```
30 | [1] @inproceedings{Hui_CVPR_2018,
31 |         author = {Tak-Wai Hui and Xiaoou Tang and Chen Change Loy},
32 |         title = {{LiteFlowNet}: A Lightweight Convolutional Neural Network for Optical Flow Estimation},
33 |         booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
34 |         year = {2018}
35 |     }
36 | ```
37 |
38 | ```
39 | [2] @misc{pytorch-liteflownet,
40 |         author = {Simon Niklaus},
41 |         title = {A Reimplementation of {LiteFlowNet} Using {PyTorch}},
42 |         year = {2019},
43 |         howpublished = {\url{https://github.com/sniklaus/pytorch-liteflownet}}
44 |     }
45 | ```
--------------------------------------------------------------------------------
/liteflownet/correlation/README.md:
--------------------------------------------------------------------------------
1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use of or modifying this particular implementation, please acknowledge it appropriately.
--------------------------------------------------------------------------------
/liteflownet/correlation/correlation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import cupy
4 | import math
5 | import re
6 | import torch
7 |
8 | kernel_Correlation_rearrange = '''
9 |     extern "C" __global__ void kernel_Correlation_rearrange(
10 |         const int n,
11 |         const float* input,
12 |         float* output
13 |     ) {
14 |         int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
15 | 
16 |         if (intIndex >= n) {
17 |             return;
18 |         }
19 | 
20 |         int intSample = blockIdx.z;
21 |         int intChannel = blockIdx.y;
22 | 
23 |         float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
24 | 
25 |         __syncthreads();
26 | 
27 |         int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}};
28 |         int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}};
29 |         int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX;
30 | 
31 |         output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
32 |     }
33 | '''
34 |
35 | kernel_Correlation_updateOutput = '''
36 |     extern "C" __global__ void kernel_Correlation_updateOutput(
37 |         const int n,
38 |         const float* rbot0,
39 |         const float* rbot1,
40 |         float* top
41 |     ) {
42 |         extern __shared__ char patch_data_char[];
43 | 
44 |         float *patch_data = (float *)patch_data_char;
45 | 
46 |         // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
47 |         int x1 = (blockIdx.x + 3) * {{intStride}};
48 |         int y1 = (blockIdx.y + 3) * {{intStride}};
49 |         int item = blockIdx.z;
50 |         int ch_off = threadIdx.x;
51 | 
52 |         // Load 3D patch into shared memory
53 |         for (int j = 0; j < 1; j++) { // HEIGHT
54 |             for (int i = 0; i < 1; i++) { // WIDTH
55 |                 int ji_off = (j + i) * SIZE_3(rbot0);
56 |                 for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
57 |                     int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
58 |                     int idxPatchData = ji_off + ch;
59 |                     patch_data[idxPatchData] = rbot0[idx1];
60 |                 }
61 |             }
62 |         }
63 | 
64 |         __syncthreads();
65 | 
66 |         __shared__ float sum[32];
67 | 
68 |         // Compute correlation
69 |         for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
70 |             sum[ch_off] = 0;
71 | 
72 |             int s2o = (top_channel % 7 - 3) * {{intStride}};
73 |             int s2p = (top_channel / 7 - 3) * {{intStride}};
74 | 
75 |             for (int j = 0; j < 1; j++) { // HEIGHT
76 |                 for (int i = 0; i < 1; i++) { // WIDTH
77 |                     int ji_off = (j + i) * SIZE_3(rbot0);
78 |                     for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
79 |                         int x2 = x1 + s2o;
80 |                         int y2 = y1 + s2p;
81 | 
82 |                         int idxPatchData = ji_off + ch;
83 |                         int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
84 | 
85 |                         sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
86 |                     }
87 |                 }
88 |             }
89 | 
90 |             __syncthreads();
91 | 
92 |             if (ch_off == 0) {
93 |                 float total_sum = 0;
94 |                 for (int idx = 0; idx < 32; idx++) {
95 |                     total_sum += sum[idx];
96 |                 }
97 |                 const int sumelems = SIZE_3(rbot0);
98 |                 const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
99 |                 top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
100 |             }
101 |         }
102 |     }
103 | '''
104 |
105 | kernel_Correlation_updateGradOne = '''
106 | #define ROUND_OFF 50000
107 |
108 | extern "C" __global__ void kernel_Correlation_updateGradOne(
109 | const int n,
110 | const int intSample,
111 | const float* rbot0,
112 | const float* rbot1,
113 | const float* gradOutput,
114 | float* gradOne,
115 | float* gradTwo
116 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
117 | int n = intIndex % SIZE_1(gradOne); // channels
118 | int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 3*{{intStride}}; // w-pos
119 | int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 3*{{intStride}}; // h-pos
120 |
121 | // round_off is a trick to enable integer division with ceil, even for negative numbers
122 | // We use a large offset so that the inner part does not become negative.
123 | const int round_off = ROUND_OFF;
124 | const int round_off_s1 = {{intStride}} * round_off;
125 |
126 | // We add round_off before the int division and subtract it afterwards, so the formula matches ceil behavior:
127 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
128 | int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (m - 3*{{intStride}}) / {{intStride}}
129 |
130 | // Same here:
131 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}}
132 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}}
133 |
134 | float sum = 0;
135 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
136 | xmin = max(0,xmin);
137 | xmax = min(SIZE_3(gradOutput)-1,xmax);
138 |
139 | ymin = max(0,ymin);
140 | ymax = min(SIZE_2(gradOutput)-1,ymax);
141 |
142 | for (int p = -3; p <= 3; p++) {
143 | for (int o = -3; o <= 3; o++) {
144 | // Get rbot1 data:
145 | int s2o = {{intStride}} * o;
146 | int s2p = {{intStride}} * p;
147 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
148 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
149 |
150 | // Index offset for gradOutput in following loops:
151 | int op = (p+3) * 7 + (o+3); // index[o,p]
152 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
153 |
154 | for (int y = ymin; y <= ymax; y++) {
155 | for (int x = xmin; x <= xmax; x++) {
156 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
157 | sum += gradOutput[idxgradOutput] * bot1tmp;
158 | }
159 | }
160 | }
161 | }
162 | }
163 | const int sumelems = SIZE_1(gradOne);
164 | const int bot0index = ((n * SIZE_2(gradOne)) + (m-3*{{intStride}})) * SIZE_3(gradOne) + (l-3*{{intStride}});
165 | gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems;
166 | } }
167 | '''
168 |
169 | kernel_Correlation_updateGradTwo = '''
170 | #define ROUND_OFF 50000
171 |
172 | extern "C" __global__ void kernel_Correlation_updateGradTwo(
173 | const int n,
174 | const int intSample,
175 | const float* rbot0,
176 | const float* rbot1,
177 | const float* gradOutput,
178 | float* gradOne,
179 | float* gradTwo
180 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
181 | int n = intIndex % SIZE_1(gradTwo); // channels
182 | int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 3*{{intStride}}; // w-pos
183 | int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 3*{{intStride}}; // h-pos
184 |
185 | // round_off is a trick to enable integer division with ceil, even for negative numbers
186 | // We use a large offset so that the inner part does not become negative.
187 | const int round_off = ROUND_OFF;
188 | const int round_off_s1 = {{intStride}} * round_off;
189 |
190 | float sum = 0;
191 | for (int p = -3; p <= 3; p++) {
192 | for (int o = -3; o <= 3; o++) {
193 | int s2o = {{intStride}} * o;
194 | int s2p = {{intStride}} * p;
195 |
196 | // Get X,Y ranges and clamp
197 | // We add round_off before the int division and subtract it afterwards, so the formula matches ceil behavior:
198 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
199 | int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (m - 3*{{intStride}} - s2p) / {{intStride}}
200 |
201 | // Same here:
202 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}}
203 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}}
204 |
205 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
206 | xmin = max(0,xmin);
207 | xmax = min(SIZE_3(gradOutput)-1,xmax);
208 |
209 | ymin = max(0,ymin);
210 | ymax = min(SIZE_2(gradOutput)-1,ymax);
211 |
212 | // Get rbot0 data:
213 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
214 | float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n]
215 |
216 | // Index offset for gradOutput in following loops:
217 | int op = (p+3) * 7 + (o+3); // index[o,p]
218 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
219 |
220 | for (int y = ymin; y <= ymax; y++) {
221 | for (int x = xmin; x <= xmax; x++) {
222 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
223 | sum += gradOutput[idxgradOutput] * bot0tmp;
224 | }
225 | }
226 | }
227 | }
228 | }
229 | const int sumelems = SIZE_1(gradTwo);
230 | const int bot1index = ((n * SIZE_2(gradTwo)) + (m-3*{{intStride}})) * SIZE_3(gradTwo) + (l-3*{{intStride}});
231 | gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems;
232 | } }
233 | '''
234 |
235 | def cupy_kernel(strFunction, objVariables):
236 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride']))
237 |
238 | while True:
239 | objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
240 |
241 | if objMatch is None:
242 | break
243 | # end
244 |
245 | intArg = int(objMatch.group(2))
246 |
247 | strTensor = objMatch.group(4)
248 | intSizes = objVariables[strTensor].size()
249 |
250 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
251 | # end
252 |
253 | while True:
254 | objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
255 |
256 | if objMatch is None:
257 | break
258 | # end
259 |
260 | intArgs = int(objMatch.group(2))
261 | strArgs = objMatch.group(4).split(',')
262 |
263 | strTensor = strArgs[0]
264 | intStrides = objVariables[strTensor].stride()
265 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
266 |
267 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
268 | # end
269 |
270 | return strKernel
271 | # end
272 |
273 | @cupy.memoize(for_each_device=True)
274 | def cupy_launch(strFunction, strKernel):
275 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
276 | # end
277 |
278 | class _FunctionCorrelation(torch.autograd.Function):
279 | @staticmethod
280 | def forward(self, one, two, intStride):
281 | rbot0 = one.new_zeros([ one.shape[0], one.shape[2] + (6 * intStride), one.shape[3] + (6 * intStride), one.shape[1] ])
282 | rbot1 = one.new_zeros([ one.shape[0], one.shape[2] + (6 * intStride), one.shape[3] + (6 * intStride), one.shape[1] ])
283 |
284 | self.intStride = intStride
285 |
286 | one = one.contiguous(); assert(one.is_cuda == True)
287 | two = two.contiguous(); assert(two.is_cuda == True)
288 |
289 | output = one.new_zeros([ one.shape[0], 49, int(math.ceil(one.shape[2] / intStride)), int(math.ceil(one.shape[3] / intStride)) ])
290 |
291 | if one.is_cuda == True:
292 | n = one.shape[2] * one.shape[3]
293 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
294 | 'intStride': self.intStride,
295 | 'input': one,
296 | 'output': rbot0
297 | }))(
298 | grid=tuple([ int((n + 16 - 1) / 16), one.shape[1], one.shape[0] ]),
299 | block=tuple([ 16, 1, 1 ]),
300 | args=[ cupy.int32(n), one.data_ptr(), rbot0.data_ptr() ]
301 | )
302 |
303 | n = two.shape[2] * two.shape[3]
304 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
305 | 'intStride': self.intStride,
306 | 'input': two,
307 | 'output': rbot1
308 | }))(
309 | grid=tuple([ int((n + 16 - 1) / 16), two.shape[1], two.shape[0] ]),
310 | block=tuple([ 16, 1, 1 ]),
311 | args=[ cupy.int32(n), two.data_ptr(), rbot1.data_ptr() ]
312 | )
313 |
314 | n = output.shape[1] * output.shape[2] * output.shape[3]
315 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
316 | 'intStride': self.intStride,
317 | 'rbot0': rbot0,
318 | 'rbot1': rbot1,
319 | 'top': output
320 | }))(
321 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
322 | block=tuple([ 32, 1, 1 ]),
323 | shared_mem=one.shape[1] * 4,
324 | args=[ cupy.int32(n), rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
325 | )
326 |
327 | elif one.is_cuda == False:
328 | raise NotImplementedError()
329 |
330 | # end
331 |
332 | self.save_for_backward(one, two, rbot0, rbot1)
333 |
334 | return output
335 | # end
336 |
337 | @staticmethod
338 | def backward(self, gradOutput):
339 | one, two, rbot0, rbot1 = self.saved_tensors
340 |
341 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
342 |
343 | gradOne = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[0] == True else None
344 | gradTwo = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[1] == True else None
345 |
346 | if one.is_cuda == True:
347 | if gradOne is not None:
348 | for intSample in range(one.shape[0]):
349 | n = one.shape[1] * one.shape[2] * one.shape[3]
350 | cupy_launch('kernel_Correlation_updateGradOne', cupy_kernel('kernel_Correlation_updateGradOne', {
351 | 'intStride': self.intStride,
352 | 'rbot0': rbot0,
353 | 'rbot1': rbot1,
354 | 'gradOutput': gradOutput,
355 | 'gradOne': gradOne,
356 | 'gradTwo': None
357 | }))(
358 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
359 | block=tuple([ 512, 1, 1 ]),
360 | args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradOne.data_ptr(), None ]
361 | )
362 | # end
363 | # end
364 |
365 | if gradTwo is not None:
366 | for intSample in range(one.shape[0]):
367 | n = one.shape[1] * one.shape[2] * one.shape[3]
368 | cupy_launch('kernel_Correlation_updateGradTwo', cupy_kernel('kernel_Correlation_updateGradTwo', {
369 | 'intStride': self.intStride,
370 | 'rbot0': rbot0,
371 | 'rbot1': rbot1,
372 | 'gradOutput': gradOutput,
373 | 'gradOne': None,
374 | 'gradTwo': gradTwo
375 | }))(
376 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
377 | block=tuple([ 512, 1, 1 ]),
378 | args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradTwo.data_ptr() ]
379 | )
380 | # end
381 | # end
382 |
383 | elif one.is_cuda == False:
384 | raise NotImplementedError()
385 |
386 | # end
387 |
388 | return gradOne, gradTwo, None
389 | # end
390 | # end
391 |
392 | def FunctionCorrelation(tenOne, tenTwo, intStride):
393 | return _FunctionCorrelation.apply(tenOne, tenTwo, intStride)
394 | # end
395 |
396 | class ModuleCorrelation(torch.nn.Module):
397 | def __init__(self):
398 | super().__init__()
399 | # end
400 |
401 | def forward(self, tenOne, tenTwo, intStride):
402 | return _FunctionCorrelation.apply(tenOne, tenTwo, intStride)
403 | # end
404 | # end
--------------------------------------------------------------------------------
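A hedged usage sketch for the layer above, assuming a CUDA device with cupy installed and this directory on the import path:

```python
import torch
from correlation import FunctionCorrelation

one = torch.randn(1, 32, 64, 64).cuda()
two = torch.randn(1, 32, 64, 64).cuda()

cost = FunctionCorrelation(tenOne=one, tenTwo=two, intStride=1)
print(cost.shape)  # torch.Size([1, 49, 64, 64])
```

`intStride=2` additionally subsamples the output grid; `run.py` uses it at the two finest pyramid levels and then upsamples the correlation again with `netUpcorr`.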
/liteflownet/requirements.txt:
--------------------------------------------------------------------------------
1 | cupy>=5.0.0
2 | numpy>=1.15.0
3 | Pillow>=5.0.0
4 | torch>=1.3.0
--------------------------------------------------------------------------------
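Note that the listed versions are minimums, and that cupy wheels are built per CUDA toolkit version (e.g. `cupy-cuda11x` on recent releases), so the exact cupy package name should match the local installation before running `pip install -r liteflownet/requirements.txt`.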
/liteflownet/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import getopt
4 | import math
5 | import numpy
6 | import PIL
7 | import PIL.Image
8 | import sys
9 | import torch
10 |
11 | try:
12 | from .correlation import correlation # the custom cost volume layer
13 | except ImportError:
14 | 	sys.path.insert(0, './correlation'); import correlation # fall back to a plain import when not run as a package
15 | # end
16 |
17 | ##########################################################
18 |
19 | assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0
20 |
21 | torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance
22 |
23 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
24 |
25 | ##########################################################
26 |
27 | arguments_strModel = 'default' # 'default', or 'kitti', or 'sintel'
28 | arguments_strOne = './images/one.png'
29 | arguments_strTwo = './images/two.png'
30 | arguments_strOut = './out.flo'
31 |
32 | for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
33 | if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use
34 | if strOption == '--one' and strArgument != '': arguments_strOne = strArgument # path to the first frame
35 | if strOption == '--two' and strArgument != '': arguments_strTwo = strArgument # path to the second frame
36 | if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored
37 | # end
38 |
39 | ##########################################################
40 |
41 | backwarp_tenGrid = {}
42 |
43 | def backwarp(tenInput, tenFlow):
44 | if str(tenFlow.shape) not in backwarp_tenGrid:
45 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
46 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])
47 |
48 | backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda()
49 | # end
50 |
51 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
52 |
53 | return torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)
54 | # end
55 |
56 | ##########################################################
57 |
58 | class Network(torch.nn.Module):
59 | def __init__(self):
60 | super().__init__()
61 |
62 | class Features(torch.nn.Module):
63 | def __init__(self):
64 | super().__init__()
65 |
66 | self.netOne = torch.nn.Sequential(
67 | torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=7, stride=1, padding=3),
68 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
69 | )
70 |
71 | self.netTwo = torch.nn.Sequential(
72 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1),
73 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
74 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
75 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
76 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
77 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
78 | )
79 |
80 | self.netThr = torch.nn.Sequential(
81 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
82 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
83 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
84 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
85 | )
86 |
87 | self.netFou = torch.nn.Sequential(
88 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
89 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
90 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
91 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
92 | )
93 |
94 | self.netFiv = torch.nn.Sequential(
95 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
96 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
97 | )
98 |
99 | self.netSix = torch.nn.Sequential(
100 | torch.nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=2, padding=1),
101 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
102 | )
103 | # end
104 |
105 | def forward(self, tenInput):
106 | tenOne = self.netOne(tenInput)
107 | tenTwo = self.netTwo(tenOne)
108 | tenThr = self.netThr(tenTwo)
109 | tenFou = self.netFou(tenThr)
110 | tenFiv = self.netFiv(tenFou)
111 | tenSix = self.netSix(tenFiv)
112 |
113 | return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ]
114 | # end
115 | # end
116 |
117 | class Matching(torch.nn.Module):
118 | def __init__(self, intLevel):
119 | super().__init__()
120 |
121 | self.fltBackwarp = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel]
122 |
123 | if intLevel != 2:
124 | self.netFeat = torch.nn.Sequential()
125 |
126 | elif intLevel == 2:
127 | self.netFeat = torch.nn.Sequential(
128 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0),
129 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
130 | )
131 |
132 | # end
133 |
134 | if intLevel == 6:
135 | self.netUpflow = None
136 |
137 | elif intLevel != 6:
138 | self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1, bias=False, groups=2)
139 |
140 | # end
141 |
142 | if intLevel >= 4:
143 | self.netUpcorr = None
144 |
145 | elif intLevel < 4:
146 | self.netUpcorr = torch.nn.ConvTranspose2d(in_channels=49, out_channels=49, kernel_size=4, stride=2, padding=1, bias=False, groups=49)
147 |
148 | # end
149 |
150 | self.netMain = torch.nn.Sequential(
151 | torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1),
152 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
153 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
154 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
155 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
156 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
157 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel])
158 | )
159 | # end
160 |
161 | def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow):
162 | tenFeaturesOne = self.netFeat(tenFeaturesOne)
163 | tenFeaturesTwo = self.netFeat(tenFeaturesTwo)
164 |
165 | if tenFlow is not None:
166 | tenFlow = self.netUpflow(tenFlow)
167 | # end
168 |
169 | if tenFlow is not None:
170 | tenFeaturesTwo = backwarp(tenInput=tenFeaturesTwo, tenFlow=tenFlow * self.fltBackwarp)
171 | # end
172 |
173 | if self.netUpcorr is None:
174 | tenCorrelation = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenFeaturesOne, tenTwo=tenFeaturesTwo, intStride=1), negative_slope=0.1, inplace=False)
175 |
176 | elif self.netUpcorr is not None:
177 | tenCorrelation = self.netUpcorr(torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenFeaturesOne, tenTwo=tenFeaturesTwo, intStride=2), negative_slope=0.1, inplace=False))
178 |
179 | # end
180 |
181 | return (tenFlow if tenFlow is not None else 0.0) + self.netMain(tenCorrelation)
182 | # end
183 | # end
184 |
185 | class Subpixel(torch.nn.Module):
186 | def __init__(self, intLevel):
187 | super().__init__()
188 |
189 | self.fltBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel]
190 |
191 | if intLevel != 2:
192 | self.netFeat = torch.nn.Sequential()
193 |
194 | elif intLevel == 2:
195 | self.netFeat = torch.nn.Sequential(
196 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0),
197 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
198 | )
199 |
200 | # end
201 |
202 | self.netMain = torch.nn.Sequential(
203 | torch.nn.Conv2d(in_channels=[ 0, 0, 130, 130, 194, 258, 386 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1),
204 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
205 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
206 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
207 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
208 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
209 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel])
210 | )
211 | # end
212 |
213 | def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow):
214 | tenFeaturesOne = self.netFeat(tenFeaturesOne)
215 | tenFeaturesTwo = self.netFeat(tenFeaturesTwo)
216 |
217 | if tenFlow is not None:
218 | tenFeaturesTwo = backwarp(tenInput=tenFeaturesTwo, tenFlow=tenFlow * self.fltBackward)
219 | # end
220 |
221 | return (tenFlow if tenFlow is not None else 0.0) + self.netMain(torch.cat([ tenFeaturesOne, tenFeaturesTwo, tenFlow ], 1))
222 | # end
223 | # end
224 |
225 | class Regularization(torch.nn.Module):
226 | def __init__(self, intLevel):
227 | super().__init__()
228 |
229 | self.fltBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel]
230 |
231 | self.intUnfold = [ 0, 0, 7, 5, 5, 3, 3 ][intLevel]
232 |
233 | if intLevel >= 5:
234 | self.netFeat = torch.nn.Sequential()
235 |
236 | elif intLevel < 5:
237 | self.netFeat = torch.nn.Sequential(
238 | torch.nn.Conv2d(in_channels=[ 0, 0, 32, 64, 96, 128, 192 ][intLevel], out_channels=128, kernel_size=1, stride=1, padding=0),
239 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
240 | )
241 |
242 | # end
243 |
244 | self.netMain = torch.nn.Sequential(
245 | torch.nn.Conv2d(in_channels=[ 0, 0, 131, 131, 131, 131, 195 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1),
246 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
247 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
248 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
249 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
250 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
251 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
252 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
253 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
254 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
255 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
256 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
257 | )
258 |
259 | if intLevel >= 5:
260 | self.netDist = torch.nn.Sequential(
261 | torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel])
262 | )
263 |
264 | elif intLevel < 5:
265 | self.netDist = torch.nn.Sequential(
266 | torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=([ 0, 0, 7, 5, 5, 3, 3 ][intLevel], 1), stride=1, padding=([ 0, 0, 3, 2, 2, 1, 1 ][intLevel], 0)),
267 | torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=(1, [ 0, 0, 7, 5, 5, 3, 3 ][intLevel]), stride=1, padding=(0, [ 0, 0, 3, 2, 2, 1, 1 ][intLevel]))
268 | )
269 |
270 | # end
271 |
272 | self.netScaleX = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0)
273 | self.netScaleY = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0)
274 | # end
275 |
276 | def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow):
277 | tenDifference = ((tenOne - backwarp(tenInput=tenTwo, tenFlow=tenFlow * self.fltBackward)) ** 2).sum(1, True).sqrt().detach()
278 |
279 | tenDist = self.netDist(self.netMain(torch.cat([ tenDifference, tenFlow - tenFlow.view(tenFlow.shape[0], 2, -1).mean(2, True).view(tenFlow.shape[0], 2, 1, 1), self.netFeat(tenFeaturesOne) ], 1)))
280 | tenDist = (tenDist ** 2).neg()
281 | tenDist = (tenDist - tenDist.max(1, True)[0]).exp()
282 |
283 | tenDivisor = tenDist.sum(1, True).reciprocal()
284 |
285 | tenScaleX = self.netScaleX(tenDist * torch.nn.functional.unfold(input=tenFlow[:, 0:1, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tenDist)) * tenDivisor
286 | tenScaleY = self.netScaleY(tenDist * torch.nn.functional.unfold(input=tenFlow[:, 1:2, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tenDist)) * tenDivisor
287 |
288 | return torch.cat([ tenScaleX, tenScaleY ], 1)
289 | # end
290 | # end
291 |
292 | self.netFeatures = Features()
293 | self.netMatching = torch.nn.ModuleList([ Matching(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ])
294 | self.netSubpixel = torch.nn.ModuleList([ Subpixel(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ])
295 | self.netRegularization = torch.nn.ModuleList([ Regularization(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ])
296 |
297 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-liteflownet/network-' + arguments_strModel + '.pytorch').items() })
298 | # self.load_state_dict(torch.load('./liteflownet/network-default.pth'))
299 | # end
300 |
301 | def forward(self, tenOne, tenTwo):
302 | tenOne[:, 0, :, :] = tenOne[:, 0, :, :] - 0.411618
303 | tenOne[:, 1, :, :] = tenOne[:, 1, :, :] - 0.434631
304 | tenOne[:, 2, :, :] = tenOne[:, 2, :, :] - 0.454253
305 |
306 | tenTwo[:, 0, :, :] = tenTwo[:, 0, :, :] - 0.410782
307 | tenTwo[:, 1, :, :] = tenTwo[:, 1, :, :] - 0.433645
308 | tenTwo[:, 2, :, :] = tenTwo[:, 2, :, :] - 0.452793
309 |
310 | tenFeaturesOne = self.netFeatures(tenOne)
311 | tenFeaturesTwo = self.netFeatures(tenTwo)
312 |
313 | tenOne = [ tenOne ]
314 | tenTwo = [ tenTwo ]
315 |
316 | for intLevel in [ 1, 2, 3, 4, 5 ]:
317 | tenOne.append(torch.nn.functional.interpolate(input=tenOne[-1], size=(tenFeaturesOne[intLevel].shape[2], tenFeaturesOne[intLevel].shape[3]), mode='bilinear', align_corners=False))
318 | tenTwo.append(torch.nn.functional.interpolate(input=tenTwo[-1], size=(tenFeaturesTwo[intLevel].shape[2], tenFeaturesTwo[intLevel].shape[3]), mode='bilinear', align_corners=False))
319 | # end
320 |
321 | tenFlow = None
322 |
323 | for intLevel in [ -1, -2, -3, -4, -5 ]:
324 | tenFlow = self.netMatching[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow)
325 | tenFlow = self.netSubpixel[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow)
326 | tenFlow = self.netRegularization[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow)
327 | # end
328 |
329 | return tenFlow * 20.0
330 | # end
331 | # end
332 |
333 | netNetwork = None
334 |
335 | ##########################################################
336 |
337 | def estimate(tenOne, tenTwo):
338 | global netNetwork
339 |
340 | if netNetwork is None:
341 | netNetwork = Network().cuda().eval()
342 | # end
343 |
344 | assert(tenOne.shape[1] == tenTwo.shape[1])
345 | assert(tenOne.shape[2] == tenTwo.shape[2])
346 |
347 | intWidth = tenOne.shape[2]
348 | intHeight = tenOne.shape[1]
349 |
350 | # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
351 | # assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
352 |
353 | tenPreprocessedOne = tenOne.cuda().view(1, 3, intHeight, intWidth)
354 | tenPreprocessedTwo = tenTwo.cuda().view(1, 3, intHeight, intWidth)
355 |
356 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0))
357 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0))
358 |
359 | tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
360 | tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
361 |
362 | tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
363 |
364 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
365 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
366 |
367 | return tenFlow[0, :, :, :].cpu()
368 | # end
369 |
370 | ##########################################################
371 |
372 | if __name__ == '__main__':
373 | tenOne = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strOne))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
374 | tenTwo = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strTwo))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
375 |
376 | tenOutput = estimate(tenOne, tenTwo)
377 |
378 | objOutput = open(arguments_strOut, 'wb')
379 |
380 | numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) # the Middlebury .flo magic, b'PIEH'
381 | numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
382 | numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
383 |
384 | objOutput.close()
385 | # end
386 |
--------------------------------------------------------------------------------
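`run.py` writes Middlebury-style `.flo` files: a 4-byte `PIEH` magic, two int32 values for width and height, then H x W x 2 float32 flow vectors. It can be invoked as, e.g., `python run.py --model default --one ./images/one.png --two ./images/two.png --out ./out.flo`. A minimal reader for the output, sketched here for illustration (`read_flo` is not part of the repository):

```python
import numpy as np

def read_flo(path):
    with open(path, 'rb') as f:
        assert f.read(4) == b'PIEH'                # magic written by run.py
        w, h = np.fromfile(f, np.int32, 2)         # width, height
        return np.fromfile(f, np.float32, w * h * 2).reshape(h, w, 2)

flow = read_flo('./out.flo')                       # flow[..., 0] horizontal, flow[..., 1] vertical
```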
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 |
8 |
9 | class Ternary(nn.Module):
10 | def __init__(self, patch_size=7):
11 | super(Ternary, self).__init__()
12 | self.patch_size = patch_size
13 | out_channels = patch_size * patch_size
14 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
15 | self.w = np.transpose(self.w, (3, 2, 0, 1))
16 | self.w = torch.tensor(self.w).float().to(device)
17 |
18 | def transform(self, tensor):
19 | tensor_ = tensor.mean(dim=1, keepdim=True)
20 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size//2, bias=None)
21 | loc_diff = patches - tensor_
22 | loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff ** 2)
23 | return loc_diff_norm
24 |
25 | def valid_mask(self, tensor):
26 | padding = self.patch_size//2
27 | b, c, h, w = tensor.size()
28 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
29 | mask = F.pad(inner, [padding] * 4)
30 | return mask
31 |
32 | def forward(self, x, y):
33 | loc_diff_x = self.transform(x)
34 | loc_diff_y = self.transform(y)
35 | diff = loc_diff_x - loc_diff_y.detach()
36 | dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
37 | mask = self.valid_mask(x)
38 | loss = (dist * mask).mean()
39 | return loss
40 |
41 |
42 | class Geometry(nn.Module):
43 | def __init__(self, patch_size=3):
44 | super(Geometry, self).__init__()
45 | self.patch_size = patch_size
46 | out_channels = patch_size * patch_size
47 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
48 | self.w = np.transpose(self.w, (3, 2, 0, 1))
49 | self.w = torch.tensor(self.w).float().to(device)
50 |
51 | def transform(self, tensor):
52 | b, c, h, w = tensor.size()
53 | tensor_ = tensor.reshape(b*c, 1, h, w)
54 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size//2, bias=None)
55 | loc_diff = patches - tensor_
56 | loc_diff_ = loc_diff.reshape(b, c*(self.patch_size**2), h, w)
57 | loc_diff_norm = loc_diff_ / torch.sqrt(0.81 + loc_diff_ ** 2)
58 | return loc_diff_norm
59 |
60 | def valid_mask(self, tensor):
61 | padding = self.patch_size//2
62 | b, c, h, w = tensor.size()
63 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
64 | mask = F.pad(inner, [padding] * 4)
65 | return mask
66 |
67 | def forward(self, x, y):
68 | loc_diff_x = self.transform(x)
69 | loc_diff_y = self.transform(y)
70 | diff = loc_diff_x - loc_diff_y
71 | dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
72 | mask = self.valid_mask(x)
73 | loss = (dist * mask).mean()
74 | return loss
75 |
76 |
77 | class Charbonnier_L1(nn.Module):
78 | def __init__(self):
79 | super(Charbonnier_L1, self).__init__()
80 |
81 | def forward(self, diff, mask=None):
82 | if mask is None:
83 | loss = ((diff ** 2 + 1e-6) ** 0.5).mean()
84 | else:
85 | loss = (((diff ** 2 + 1e-6) ** 0.5) * mask).mean() / (mask.mean() + 1e-9)
86 | return loss
87 |
88 |
89 | class Charbonnier_Ada(nn.Module):
90 | def __init__(self):
91 | super(Charbonnier_Ada, self).__init__()
92 |
93 | def forward(self, diff, weight):
94 | alpha = weight / 2
95 | epsilon = 10 ** (-(10 * weight - 1) / 3)
96 | loss = ((diff ** 2 + epsilon ** 2) ** alpha).mean()
97 | return loss
98 |
--------------------------------------------------------------------------------
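All four losses reduce to scalar tensors and can be exercised in isolation. A small sketch with random stand-ins for predictions and targets; note that `Ternary` and `Geometry` build their filter banks on the module-level `device`, so inputs should live on that same device:

```python
import torch
from loss import Charbonnier_L1, Charbonnier_Ada, Ternary, Geometry

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pred = torch.rand(2, 3, 64, 64, device=device)
gt = torch.rand(2, 3, 64, 64, device=device)

rec = Charbonnier_L1()(pred - gt)              # robust L1 on the residual
census = Ternary(7)(pred, gt)                  # census-style loss over 7x7 patches
geo = Geometry(3)(pred, gt)                    # feature geometry consistency, 3x3 patches
ada = Charbonnier_Ada()(pred - gt, weight=torch.full_like(pred[:, :1], 0.5))
```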
/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from math import exp
5 |
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 |
8 |
9 | def gaussian(window_size, sigma):
10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
11 | return gauss/gauss.sum()
12 |
13 |
14 | def create_window(window_size, channel=1):
15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
17 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
18 | return window
19 |
20 |
21 | def create_window_3d(window_size, channel=1):
22 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
23 | _2D_window = _1D_window.mm(_1D_window.t())
24 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
25 | window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
26 | return window
27 |
28 |
29 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
30 | if val_range is None:
31 | if torch.max(img1) > 128:
32 | max_val = 255
33 | else:
34 | max_val = 1
35 |
36 | if torch.min(img1) < -0.5:
37 | min_val = -1
38 | else:
39 | min_val = 0
40 | L = max_val - min_val
41 | else:
42 | L = val_range
43 |
44 | padd = 0
45 | (_, channel, height, width) = img1.size()
46 | if window is None:
47 | real_size = min(window_size, height, width)
48 | window = create_window(real_size, channel=channel).to(img1.device)
49 |
50 | mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) # the fixed (5, 5, 5, 5) padding assumes an 11x11 window
51 | mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
52 |
53 | mu1_sq = mu1.pow(2)
54 | mu2_sq = mu2.pow(2)
55 | mu1_mu2 = mu1 * mu2
56 |
57 | sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
58 | sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
59 | sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
60 |
61 | C1 = (0.01 * L) ** 2
62 | C2 = (0.03 * L) ** 2
63 |
64 | v1 = 2.0 * sigma12 + C2
65 | v2 = sigma1_sq + sigma2_sq + C2
66 | cs = torch.mean(v1 / v2)
67 |
68 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
69 |
70 | if size_average:
71 | ret = ssim_map.mean()
72 | else:
73 | ret = ssim_map.mean(1).mean(1).mean(1)
74 |
75 | if full:
76 | return ret, cs
77 | return ret
78 |
79 |
80 | def calculate_ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
81 | if val_range is None:
82 | if torch.max(img1) > 128:
83 | max_val = 255
84 | else:
85 | max_val = 1
86 |
87 | if torch.min(img1) < -0.5:
88 | min_val = -1
89 | else:
90 | min_val = 0
91 | L = max_val - min_val
92 | else:
93 | L = val_range
94 |
95 | padd = 0
96 | (_, _, height, width) = img1.size()
97 | if window is None:
98 | real_size = min(window_size, height, width)
99 | window = create_window_3d(real_size, channel=1).to(img1.device)
100 |
101 | img1 = img1.unsqueeze(1)
102 | img2 = img2.unsqueeze(1)
103 |
104 | mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
105 | mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
106 |
107 | mu1_sq = mu1.pow(2)
108 | mu2_sq = mu2.pow(2)
109 | mu1_mu2 = mu1 * mu2
110 |
111 | sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
112 | sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
113 | sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
114 |
115 | C1 = (0.01 * L) ** 2
116 | C2 = (0.03 * L) ** 2
117 |
118 | v1 = 2.0 * sigma12 + C2
119 | v2 = sigma1_sq + sigma2_sq + C2
120 | cs = torch.mean(v1 / v2)
121 |
122 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
123 |
124 | if size_average:
125 | ret = ssim_map.mean()
126 | else:
127 | ret = ssim_map.mean(1).mean(1).mean(1)
128 |
129 | if full:
130 | return ret, cs
131 | return ret
132 |
133 |
134 | def calculate_psnr(img1, img2):
135 | psnr = -10 * torch.log10(((img1 - img2) * (img1 - img2)).mean())
136 | return psnr
137 |
138 |
139 | def calculate_ie(img1, img2):
140 | ie = torch.abs(torch.round(img1 * 255.0) - torch.round(img2 * 255.0)).mean()
141 | return ie
142 |
--------------------------------------------------------------------------------
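A quick sanity check of the metrics; the `val_range` heuristic picks L = 1 for inputs in [0, 1]:

```python
import torch
from metric import calculate_psnr, calculate_ie, ssim

img1 = torch.rand(1, 3, 256, 256)
img2 = (img1 + 0.01 * torch.randn_like(img1)).clamp(0, 1)

print(calculate_psnr(img1, img2))  # roughly 40 dB for additive noise with sigma = 0.01
print(calculate_ie(img1, img2))    # mean absolute error measured in 8-bit levels
print(ssim(img1, img2))            # close to 1.0 for near-identical images
```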
/models/IFRNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils import warp, get_robust_weight
5 | from loss import *
6 |
7 |
8 | def resize(x, scale_factor):
9 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
10 |
11 |
12 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
13 | return nn.Sequential(
14 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
15 | nn.PReLU(out_channels)
16 | )
17 |
18 |
19 | class ResBlock(nn.Module):
20 | def __init__(self, in_channels, side_channels, bias=True):
21 | super(ResBlock, self).__init__()
22 | self.side_channels = side_channels
23 | self.conv1 = nn.Sequential(
24 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
25 | nn.PReLU(in_channels)
26 | )
27 | self.conv2 = nn.Sequential(
28 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
29 | nn.PReLU(side_channels)
30 | )
31 | self.conv3 = nn.Sequential(
32 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
33 | nn.PReLU(in_channels)
34 | )
35 | self.conv4 = nn.Sequential(
36 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
37 | nn.PReLU(side_channels)
38 | )
39 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
40 | self.prelu = nn.PReLU(in_channels)
41 |
42 | def forward(self, x):
43 | out = self.conv1(x)
44 | out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :].clone())
45 | out = self.conv3(out)
46 | out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :].clone())
47 | out = self.prelu(x + self.conv5(out))
48 | return out
49 |
50 |
51 | class Encoder(nn.Module):
52 | def __init__(self):
53 | super(Encoder, self).__init__()
54 | self.pyramid1 = nn.Sequential(
55 | convrelu(3, 32, 3, 2, 1),
56 | convrelu(32, 32, 3, 1, 1)
57 | )
58 | self.pyramid2 = nn.Sequential(
59 | convrelu(32, 48, 3, 2, 1),
60 | convrelu(48, 48, 3, 1, 1)
61 | )
62 | self.pyramid3 = nn.Sequential(
63 | convrelu(48, 72, 3, 2, 1),
64 | convrelu(72, 72, 3, 1, 1)
65 | )
66 | self.pyramid4 = nn.Sequential(
67 | convrelu(72, 96, 3, 2, 1),
68 | convrelu(96, 96, 3, 1, 1)
69 | )
70 |
71 | def forward(self, img):
72 | f1 = self.pyramid1(img)
73 | f2 = self.pyramid2(f1)
74 | f3 = self.pyramid3(f2)
75 | f4 = self.pyramid4(f3)
76 | return f1, f2, f3, f4
77 |
78 |
79 | class Decoder4(nn.Module):
80 | def __init__(self):
81 | super(Decoder4, self).__init__()
82 | self.convblock = nn.Sequential(
83 | convrelu(192+1, 192),
84 | ResBlock(192, 32),
85 | nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True)
86 | )
87 |
88 | def forward(self, f0, f1, embt):
89 | b, c, h, w = f0.shape
90 | embt = embt.repeat(1, 1, h, w)
91 | f_in = torch.cat([f0, f1, embt], 1)
92 | f_out = self.convblock(f_in)
93 | return f_out
94 |
95 |
96 | class Decoder3(nn.Module):
97 | def __init__(self):
98 | super(Decoder3, self).__init__()
99 | self.convblock = nn.Sequential(
100 | convrelu(220, 216),
101 | ResBlock(216, 32),
102 | nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True)
103 | )
104 |
105 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
106 | f0_warp = warp(f0, up_flow0)
107 | f1_warp = warp(f1, up_flow1)
108 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
109 | f_out = self.convblock(f_in)
110 | return f_out
111 |
112 |
113 | class Decoder2(nn.Module):
114 | def __init__(self):
115 | super(Decoder2, self).__init__()
116 | self.convblock = nn.Sequential(
117 | convrelu(148, 144),
118 | ResBlock(144, 32),
119 | nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True)
120 | )
121 |
122 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
123 | f0_warp = warp(f0, up_flow0)
124 | f1_warp = warp(f1, up_flow1)
125 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
126 | f_out = self.convblock(f_in)
127 | return f_out
128 |
129 |
130 | class Decoder1(nn.Module):
131 | def __init__(self):
132 | super(Decoder1, self).__init__()
133 | self.convblock = nn.Sequential(
134 | convrelu(100, 96),
135 | ResBlock(96, 32),
136 | nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True)
137 | )
138 |
139 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
140 | f0_warp = warp(f0, up_flow0)
141 | f1_warp = warp(f1, up_flow1)
142 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
143 | f_out = self.convblock(f_in)
144 | return f_out
145 |
146 |
147 | class Model(nn.Module):
148 | def __init__(self, local_rank=-1, lr=1e-4):
149 | super(Model, self).__init__()
150 | self.encoder = Encoder()
151 | self.decoder4 = Decoder4()
152 | self.decoder3 = Decoder3()
153 | self.decoder2 = Decoder2()
154 | self.decoder1 = Decoder1()
155 | self.l1_loss = Charbonnier_L1()
156 | self.tr_loss = Ternary(7)
157 | self.rb_loss = Charbonnier_Ada()
158 | self.gc_loss = Geometry(3)
159 |
160 |
161 | def inference(self, img0, img1, embt, scale_factor=1.0):
162 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
163 | img0 = img0 - mean_
164 | img1 = img1 - mean_
165 |
166 | img0_ = resize(img0, scale_factor=scale_factor)
167 | img1_ = resize(img1, scale_factor=scale_factor)
168 |
169 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
170 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)
171 |
172 | out4 = self.decoder4(f0_4, f1_4, embt)
173 | up_flow0_4 = out4[:, 0:2]
174 | up_flow1_4 = out4[:, 2:4]
175 | ft_3_ = out4[:, 4:]
176 |
177 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
178 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
179 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
180 | ft_2_ = out3[:, 4:]
181 |
182 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
183 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
184 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
185 | ft_1_ = out2[:, 4:]
186 |
187 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
188 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
189 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
190 | up_mask_1 = torch.sigmoid(out1[:, 4:5])
191 | up_res_1 = out1[:, 5:]
192 |
193 | up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
194 | up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
195 | up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
196 | up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))
197 |
198 | img0_warp = warp(img0, up_flow0_1)
199 | img1_warp = warp(img1, up_flow1_1)
200 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
201 | imgt_pred = imgt_merge + up_res_1
202 | imgt_pred = torch.clamp(imgt_pred, 0, 1)
203 | return imgt_pred
204 |
205 |
206 | def forward(self, img0, img1, embt, imgt, flow=None):
207 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
208 | img0 = img0 - mean_
209 | img1 = img1 - mean_
210 | imgt_ = imgt - mean_
211 |
212 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0)
213 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1)
214 | ft_1, ft_2, ft_3, ft_4 = self.encoder(imgt_)
215 |
216 | out4 = self.decoder4(f0_4, f1_4, embt)
217 | up_flow0_4 = out4[:, 0:2]
218 | up_flow1_4 = out4[:, 2:4]
219 | ft_3_ = out4[:, 4:]
220 |
221 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
222 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
223 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
224 | ft_2_ = out3[:, 4:]
225 |
226 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
227 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
228 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
229 | ft_1_ = out2[:, 4:]
230 |
231 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
232 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
233 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
234 | up_mask_1 = torch.sigmoid(out1[:, 4:5])
235 | up_res_1 = out1[:, 5:]
236 |
237 | img0_warp = warp(img0, up_flow0_1)
238 | img1_warp = warp(img1, up_flow1_1)
239 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
240 | imgt_pred = imgt_merge + up_res_1
241 | imgt_pred = torch.clamp(imgt_pred, 0, 1)
242 |
243 | loss_rec = self.l1_loss(imgt_pred - imgt) + self.tr_loss(imgt_pred, imgt)
244 | loss_geo = 0.01 * (self.gc_loss(ft_1_, ft_1) + self.gc_loss(ft_2_, ft_2) + self.gc_loss(ft_3_, ft_3))
245 | if flow is not None:
246 | robust_weight0 = get_robust_weight(up_flow0_1, flow[:, 0:2], beta=0.3)
247 | robust_weight1 = get_robust_weight(up_flow1_1, flow[:, 2:4], beta=0.3)
248 | loss_dis = 0.01 * (self.rb_loss(2.0 * resize(up_flow0_2, 2.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(2.0 * resize(up_flow1_2, 2.0) - flow[:, 2:4], weight=robust_weight1))
249 | loss_dis += 0.01 * (self.rb_loss(4.0 * resize(up_flow0_3, 4.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(4.0 * resize(up_flow1_3, 4.0) - flow[:, 2:4], weight=robust_weight1))
250 | loss_dis += 0.01 * (self.rb_loss(8.0 * resize(up_flow0_4, 8.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(8.0 * resize(up_flow1_4, 8.0) - flow[:, 2:4], weight=robust_weight1))
251 | else:
252 | loss_dis = 0.00 * loss_geo
253 |
254 | return imgt_pred, loss_rec, loss_geo, loss_dis
255 |
--------------------------------------------------------------------------------
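A minimal inference sketch for the midpoint (2x) case. It should be run from the repository root so that `utils` and `loss` are importable; the checkpoint path is a placeholder for the released weights, and the spatial size must be divisible by 16 because the encoder downsamples four times:

```python
import torch
from models.IFRNet import Model

model = Model().eval()
# model.load_state_dict(torch.load('checkpoints/IFRNet/IFRNet_Vimeo90K.pth'))  # placeholder path

img0 = torch.rand(1, 3, 256, 256)        # frames in [0, 1]
img1 = torch.rand(1, 3, 256, 256)
embt = torch.full((1, 1, 1, 1), 0.5)     # t = 0.5 selects the temporal midpoint

with torch.no_grad():
    imgt = model.inference(img0, img1, embt)
print(imgt.shape)                        # torch.Size([1, 3, 256, 256])
```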
/models/IFRNet_L.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils import warp, get_robust_weight
5 | from loss import *
6 |
7 |
8 | def resize(x, scale_factor):
9 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
10 |
11 |
12 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
13 | return nn.Sequential(
14 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
15 | nn.PReLU(out_channels)
16 | )
17 |
18 |
19 | class ResBlock(nn.Module):
20 | def __init__(self, in_channels, side_channels, bias=True):
21 | super(ResBlock, self).__init__()
22 | self.side_channels = side_channels
23 | self.conv1 = nn.Sequential(
24 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
25 | nn.PReLU(in_channels)
26 | )
27 | self.conv2 = nn.Sequential(
28 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
29 | nn.PReLU(side_channels)
30 | )
31 | self.conv3 = nn.Sequential(
32 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
33 | nn.PReLU(in_channels)
34 | )
35 | self.conv4 = nn.Sequential(
36 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
37 | nn.PReLU(side_channels)
38 | )
39 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
40 | self.prelu = nn.PReLU(in_channels)
41 |
42 | def forward(self, x):
43 | out = self.conv1(x)
44 | out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :].clone())
45 | out = self.conv3(out)
46 | out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :].clone())
47 | out = self.prelu(x + self.conv5(out))
48 | return out
49 |
50 |
51 | class Encoder(nn.Module):
52 | def __init__(self):
53 | super(Encoder, self).__init__()
54 | self.pyramid1 = nn.Sequential(
55 | convrelu(3, 64, 7, 2, 3),
56 | convrelu(64, 64, 3, 1, 1)
57 | )
58 | self.pyramid2 = nn.Sequential(
59 | convrelu(64, 96, 3, 2, 1),
60 | convrelu(96, 96, 3, 1, 1)
61 | )
62 | self.pyramid3 = nn.Sequential(
63 | convrelu(96, 144, 3, 2, 1),
64 | convrelu(144, 144, 3, 1, 1)
65 | )
66 | self.pyramid4 = nn.Sequential(
67 | convrelu(144, 192, 3, 2, 1),
68 | convrelu(192, 192, 3, 1, 1)
69 | )
70 |
71 | def forward(self, img):
72 | f1 = self.pyramid1(img)
73 | f2 = self.pyramid2(f1)
74 | f3 = self.pyramid3(f2)
75 | f4 = self.pyramid4(f3)
76 | return f1, f2, f3, f4
77 |
78 |
79 | class Decoder4(nn.Module):
80 | def __init__(self):
81 | super(Decoder4, self).__init__()
82 | self.convblock = nn.Sequential(
83 | convrelu(384+1, 384),
84 | ResBlock(384, 64),
85 | nn.ConvTranspose2d(384, 148, 4, 2, 1, bias=True)
86 | )
87 |
88 | def forward(self, f0, f1, embt):
89 | b, c, h, w = f0.shape
90 | embt = embt.repeat(1, 1, h, w)
91 | f_in = torch.cat([f0, f1, embt], 1)
92 | f_out = self.convblock(f_in)
93 | return f_out
94 |
95 |
96 | class Decoder3(nn.Module):
97 | def __init__(self):
98 | super(Decoder3, self).__init__()
99 | self.convblock = nn.Sequential(
100 | convrelu(436, 432),
101 | ResBlock(432, 64),
102 | nn.ConvTranspose2d(432, 100, 4, 2, 1, bias=True)
103 | )
104 |
105 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
106 | f0_warp = warp(f0, up_flow0)
107 | f1_warp = warp(f1, up_flow1)
108 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
109 | f_out = self.convblock(f_in)
110 | return f_out
111 |
112 |
113 | class Decoder2(nn.Module):
114 | def __init__(self):
115 | super(Decoder2, self).__init__()
116 | self.convblock = nn.Sequential(
117 | convrelu(292, 288),
118 | ResBlock(288, 64),
119 | nn.ConvTranspose2d(288, 68, 4, 2, 1, bias=True)
120 | )
121 |
122 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
123 | f0_warp = warp(f0, up_flow0)
124 | f1_warp = warp(f1, up_flow1)
125 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
126 | f_out = self.convblock(f_in)
127 | return f_out
128 |
129 |
130 | class Decoder1(nn.Module):
131 | def __init__(self):
132 | super(Decoder1, self).__init__()
133 | self.convblock = nn.Sequential(
134 | convrelu(196, 192),
135 | ResBlock(192, 64),
136 | nn.ConvTranspose2d(192, 8, 4, 2, 1, bias=True)
137 | )
138 |
139 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
140 | f0_warp = warp(f0, up_flow0)
141 | f1_warp = warp(f1, up_flow1)
142 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
143 | f_out = self.convblock(f_in)
144 | return f_out
145 |
146 |
147 | class Model(nn.Module):
148 | def __init__(self, local_rank=-1, lr=1e-4):
149 | super(Model, self).__init__()
150 | self.encoder = Encoder()
151 | self.decoder4 = Decoder4()
152 | self.decoder3 = Decoder3()
153 | self.decoder2 = Decoder2()
154 | self.decoder1 = Decoder1()
155 | self.l1_loss = Charbonnier_L1()
156 | self.tr_loss = Ternary(7)
157 | self.rb_loss = Charbonnier_Ada()
158 | self.gc_loss = Geometry(3)
159 |
160 |
161 | def inference(self, img0, img1, embt, scale_factor=1.0):
162 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
163 | img0 = img0 - mean_
164 | img1 = img1 - mean_
165 |
166 | img0_ = resize(img0, scale_factor=scale_factor)
167 | img1_ = resize(img1, scale_factor=scale_factor)
168 |
169 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
170 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)
171 |
172 | out4 = self.decoder4(f0_4, f1_4, embt)
173 | up_flow0_4 = out4[:, 0:2]
174 | up_flow1_4 = out4[:, 2:4]
175 | ft_3_ = out4[:, 4:]
176 |
177 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
178 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
179 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
180 | ft_2_ = out3[:, 4:]
181 |
182 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
183 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
184 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
185 | ft_1_ = out2[:, 4:]
186 |
187 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
188 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
189 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
190 | up_mask_1 = torch.sigmoid(out1[:, 4:5])
191 | up_res_1 = out1[:, 5:]
192 |
193 | up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
194 | up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
195 | up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
196 | up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))
197 |
198 | img0_warp = warp(img0, up_flow0_1)
199 | img1_warp = warp(img1, up_flow1_1)
200 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
201 | imgt_pred = imgt_merge + up_res_1
202 | imgt_pred = torch.clamp(imgt_pred, 0, 1)
203 | return imgt_pred
204 |
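# Coarse-to-fine refinement: each decoder predicts a flow residual at twice the
# previous resolution, so the upsampled coarse flow is also multiplied by 2.0
# to keep displacements expressed in pixels at the new scale.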
205 |
206 | def forward(self, img0, img1, embt, imgt, flow=None):
207 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
208 | img0 = img0 - mean_
209 | img1 = img1 - mean_
210 | imgt_ = imgt - mean_
211 |
212 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0)
213 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1)
214 | ft_1, ft_2, ft_3, ft_4 = self.encoder(imgt_)
215 |
216 | out4 = self.decoder4(f0_4, f1_4, embt)
217 | up_flow0_4 = out4[:, 0:2]
218 | up_flow1_4 = out4[:, 2:4]
219 | ft_3_ = out4[:, 4:]
220 |
221 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
222 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
223 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
224 | ft_2_ = out3[:, 4:]
225 |
226 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
227 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
228 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
229 | ft_1_ = out2[:, 4:]
230 |
231 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
232 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
233 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
234 | up_mask_1 = torch.sigmoid(out1[:, 4:5])
235 | up_res_1 = out1[:, 5:]
236 |
237 | img0_warp = warp(img0, up_flow0_1)
238 | img1_warp = warp(img1, up_flow1_1)
239 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
240 | imgt_pred = imgt_merge + up_res_1
241 | imgt_pred = torch.clamp(imgt_pred, 0, 1)
242 |
243 | loss_rec = self.l1_loss(imgt_pred - imgt) + self.tr_loss(imgt_pred, imgt)
244 | loss_geo = 0.01 * (self.gc_loss(ft_1_, ft_1) + self.gc_loss(ft_2_, ft_2) + self.gc_loss(ft_3_, ft_3))
245 | if flow is not None:
246 | robust_weight0 = get_robust_weight(up_flow0_1, flow[:, 0:2], beta=0.3)
247 | robust_weight1 = get_robust_weight(up_flow1_1, flow[:, 2:4], beta=0.3)
248 | loss_dis = 0.01 * (self.rb_loss(2.0 * resize(up_flow0_2, 2.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(2.0 * resize(up_flow1_2, 2.0) - flow[:, 2:4], weight=robust_weight1))
249 | loss_dis += 0.01 * (self.rb_loss(4.0 * resize(up_flow0_3, 4.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(4.0 * resize(up_flow1_3, 4.0) - flow[:, 2:4], weight=robust_weight1))
250 | loss_dis += 0.01 * (self.rb_loss(8.0 * resize(up_flow0_4, 8.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(8.0 * resize(up_flow1_4, 8.0) - flow[:, 2:4], weight=robust_weight1))
251 | else:
252 |             loss_dis = 0.00 * loss_geo  # no flow supervision: keep a zero tensor on the same device and graph
253 |
254 | return imgt_pred, loss_rec, loss_geo, loss_dis
255 |
--------------------------------------------------------------------------------
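A minimal single-frame usage sketch for the Model.inference API above, run from the repository root (the checkpoint path is hypothetical, and the 0.5 time step simply requests the midpoint frame; demo_2x.py in this repository follows the same pattern):

import torch
from models.IFRNet_L import Model

model = Model().eval()
# model.load_state_dict(torch.load('IFRNet_L_Vimeo90K.pth', map_location='cpu'))  # hypothetical checkpoint file

img0 = torch.rand(1, 3, 256, 448)      # frame at t=0, values in [0, 1], H and W divisible by 16
img1 = torch.rand(1, 3, 256, 448)      # frame at t=1
embt = torch.full((1, 1, 1, 1), 0.5)   # target time step t in (0, 1); 0.5 = midpoint

with torch.no_grad():
    imgt = model.inference(img0, img1, embt)  # (1, 3, 256, 448), clamped to [0, 1]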
/models/IFRNet_S.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils import warp, get_robust_weight
5 | from loss import *
6 |
7 |
8 | def resize(x, scale_factor):
9 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
10 |
11 |
12 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
13 | return nn.Sequential(
14 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
15 | nn.PReLU(out_channels)
16 | )
17 |
18 |
19 | class ResBlock(nn.Module):
20 | def __init__(self, in_channels, side_channels, bias=True):
21 | super(ResBlock, self).__init__()
22 | self.side_channels = side_channels
23 | self.conv1 = nn.Sequential(
24 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
25 | nn.PReLU(in_channels)
26 | )
27 | self.conv2 = nn.Sequential(
28 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
29 | nn.PReLU(side_channels)
30 | )
31 | self.conv3 = nn.Sequential(
32 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
33 | nn.PReLU(in_channels)
34 | )
35 | self.conv4 = nn.Sequential(
36 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
37 | nn.PReLU(side_channels)
38 | )
39 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
40 | self.prelu = nn.PReLU(in_channels)
41 |
42 | def forward(self, x):
43 | out = self.conv1(x)
44 |         out[:, -self.side_channels:, :, :] = self.conv2(out[:, -self.side_channels:, :, :].clone())  # .clone() keeps this in-place slice update autograd-safe
45 |         out = self.conv3(out)
46 |         out[:, -self.side_channels:, :, :] = self.conv4(out[:, -self.side_channels:, :, :].clone())
47 | out = self.prelu(x + self.conv5(out))
48 | return out
49 |
50 |
51 | class Encoder(nn.Module):
52 | def __init__(self):
53 | super(Encoder, self).__init__()
54 | self.pyramid1 = nn.Sequential(
55 | convrelu(3, 24, 3, 2, 1),
56 | convrelu(24, 24, 3, 1, 1)
57 | )
58 | self.pyramid2 = nn.Sequential(
59 | convrelu(24, 36, 3, 2, 1),
60 | convrelu(36, 36, 3, 1, 1)
61 | )
62 | self.pyramid3 = nn.Sequential(
63 | convrelu(36, 54, 3, 2, 1),
64 | convrelu(54, 54, 3, 1, 1)
65 | )
66 | self.pyramid4 = nn.Sequential(
67 | convrelu(54, 72, 3, 2, 1),
68 | convrelu(72, 72, 3, 1, 1)
69 | )
70 |
71 | def forward(self, img):
72 | f1 = self.pyramid1(img)
73 | f2 = self.pyramid2(f1)
74 | f3 = self.pyramid3(f2)
75 | f4 = self.pyramid4(f3)
76 | return f1, f2, f3, f4
77 |
78 |
79 | class Decoder4(nn.Module):
80 | def __init__(self):
81 | super(Decoder4, self).__init__()
82 | self.convblock = nn.Sequential(
83 | convrelu(144+1, 144),
84 | ResBlock(144, 24),
85 | nn.ConvTranspose2d(144, 58, 4, 2, 1, bias=True)
86 | )
87 |
88 | def forward(self, f0, f1, embt):
89 | b, c, h, w = f0.shape
90 | embt = embt.repeat(1, 1, h, w)
91 | f_in = torch.cat([f0, f1, embt], 1)
92 | f_out = self.convblock(f_in)
93 | return f_out
94 |
95 |
96 | class Decoder3(nn.Module):
97 | def __init__(self):
98 | super(Decoder3, self).__init__()
99 | self.convblock = nn.Sequential(
100 | convrelu(166, 162),
101 | ResBlock(162, 24),
102 | nn.ConvTranspose2d(162, 40, 4, 2, 1, bias=True)
103 | )
104 |
105 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
106 | f0_warp = warp(f0, up_flow0)
107 | f1_warp = warp(f1, up_flow1)
108 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
109 | f_out = self.convblock(f_in)
110 | return f_out
111 |
112 |
113 | class Decoder2(nn.Module):
114 | def __init__(self):
115 | super(Decoder2, self).__init__()
116 | self.convblock = nn.Sequential(
117 | convrelu(112, 108),
118 | ResBlock(108, 24),
119 | nn.ConvTranspose2d(108, 28, 4, 2, 1, bias=True)
120 | )
121 |
122 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
123 | f0_warp = warp(f0, up_flow0)
124 | f1_warp = warp(f1, up_flow1)
125 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
126 | f_out = self.convblock(f_in)
127 | return f_out
128 |
129 |
130 | class Decoder1(nn.Module):
131 | def __init__(self):
132 | super(Decoder1, self).__init__()
133 | self.convblock = nn.Sequential(
134 | convrelu(76, 72),
135 | ResBlock(72, 24),
136 | nn.ConvTranspose2d(72, 8, 4, 2, 1, bias=True)
137 | )
138 |
139 | def forward(self, ft_, f0, f1, up_flow0, up_flow1):
140 | f0_warp = warp(f0, up_flow0)
141 | f1_warp = warp(f1, up_flow1)
142 | f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1)
143 | f_out = self.convblock(f_in)
144 | return f_out
145 |
146 |
147 | class Model(nn.Module):
148 | def __init__(self, local_rank=-1, lr=1e-4):
149 | super(Model, self).__init__()
150 | self.encoder = Encoder()
151 | self.decoder4 = Decoder4()
152 | self.decoder3 = Decoder3()
153 | self.decoder2 = Decoder2()
154 | self.decoder1 = Decoder1()
155 | self.l1_loss = Charbonnier_L1()
156 | self.tr_loss = Ternary(7)
157 | self.rb_loss = Charbonnier_Ada()
158 | self.gc_loss = Geometry(3)
159 |
160 |
161 | def inference(self, img0, img1, embt, scale_factor=1.0):
162 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
163 | img0 = img0 - mean_
164 | img1 = img1 - mean_
165 |
166 | img0_ = resize(img0, scale_factor=scale_factor)
167 | img1_ = resize(img1, scale_factor=scale_factor)
168 |
169 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
170 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)
171 |
172 | out4 = self.decoder4(f0_4, f1_4, embt)
173 | up_flow0_4 = out4[:, 0:2]
174 | up_flow1_4 = out4[:, 2:4]
175 | ft_3_ = out4[:, 4:]
176 |
177 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
178 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
179 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
180 | ft_2_ = out3[:, 4:]
181 |
182 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
183 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
184 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
185 | ft_1_ = out2[:, 4:]
186 |
187 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
188 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
189 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
190 | up_mask_1 = torch.sigmoid(out1[:, 4:5])
191 | up_res_1 = out1[:, 5:]
192 |
193 | up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
194 | up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
195 | up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
196 | up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))
197 |
198 | img0_warp = warp(img0, up_flow0_1)
199 | img1_warp = warp(img1, up_flow1_1)
200 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
201 | imgt_pred = imgt_merge + up_res_1
202 | imgt_pred = torch.clamp(imgt_pred, 0, 1)
203 | return imgt_pred
204 |
205 |
206 | def forward(self, img0, img1, embt, imgt, flow=None):
207 | mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
208 | img0 = img0 - mean_
209 | img1 = img1 - mean_
210 | imgt_ = imgt - mean_
211 |
212 | f0_1, f0_2, f0_3, f0_4 = self.encoder(img0)
213 | f1_1, f1_2, f1_3, f1_4 = self.encoder(img1)
214 | ft_1, ft_2, ft_3, ft_4 = self.encoder(imgt_)
215 |
216 | out4 = self.decoder4(f0_4, f1_4, embt)
217 | up_flow0_4 = out4[:, 0:2]
218 | up_flow1_4 = out4[:, 2:4]
219 | ft_3_ = out4[:, 4:]
220 |
221 | out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
222 | up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
223 | up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
224 | ft_2_ = out3[:, 4:]
225 |
226 | out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
227 | up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
228 | up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
229 | ft_1_ = out2[:, 4:]
230 |
231 | out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
232 | up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
233 | up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
234 | up_mask_1 = torch.sigmoid(out1[:, 4:5])
235 | up_res_1 = out1[:, 5:]
236 |
237 | img0_warp = warp(img0, up_flow0_1)
238 | img1_warp = warp(img1, up_flow1_1)
239 | imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
240 | imgt_pred = imgt_merge + up_res_1
241 | imgt_pred = torch.clamp(imgt_pred, 0, 1)
242 |
243 | loss_rec = self.l1_loss(imgt_pred - imgt) + self.tr_loss(imgt_pred, imgt)
244 | loss_geo = 0.01 * (self.gc_loss(ft_1_, ft_1) + self.gc_loss(ft_2_, ft_2) + self.gc_loss(ft_3_, ft_3))
245 | if flow is not None:
246 | robust_weight0 = get_robust_weight(up_flow0_1, flow[:, 0:2], beta=0.3)
247 | robust_weight1 = get_robust_weight(up_flow1_1, flow[:, 2:4], beta=0.3)
248 | loss_dis = 0.01 * (self.rb_loss(2.0 * resize(up_flow0_2, 2.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(2.0 * resize(up_flow1_2, 2.0) - flow[:, 2:4], weight=robust_weight1))
249 | loss_dis += 0.01 * (self.rb_loss(4.0 * resize(up_flow0_3, 4.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(4.0 * resize(up_flow1_3, 4.0) - flow[:, 2:4], weight=robust_weight1))
250 | loss_dis += 0.01 * (self.rb_loss(8.0 * resize(up_flow0_4, 8.0) - flow[:, 0:2], weight=robust_weight0) + self.rb_loss(8.0 * resize(up_flow1_4, 8.0) - flow[:, 2:4], weight=robust_weight1))
251 | else:
252 | loss_dis = 0.00 * loss_geo
253 |
254 | return imgt_pred, loss_rec, loss_geo, loss_dis
255 |
--------------------------------------------------------------------------------
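IFRNet_S shares the architecture above with slimmer channel widths (24/36/54/72 in the encoder). A quick way to see the size difference (a sketch; the repository reports exact figures via benchmarks/speed_parameters.py):

import torch
from models.IFRNet_S import Model

model = Model()
n_params = sum(p.numel() for p in model.parameters())
print('IFRNet_S parameters: {:.2f}M'.format(n_params / 1e6))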
/train_gopro.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import random
5 | import argparse
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from torch.utils.data.distributed import DistributedSampler
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from datasets import GoPro_Train_Dataset, GoPro_Test_Dataset
15 | from metric import calculate_psnr, calculate_ssim
16 | from utils import AverageMeter
17 | import logging
18 |
19 |
20 | def get_lr(args, iters):
21 | ratio = 0.5 * (1.0 + np.cos(iters / (args.epochs * args.iters_per_epoch) * math.pi))
22 | lr = (args.lr_start - args.lr_end) * ratio + args.lr_end
23 | return lr
24 |
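# In closed form: lr(i) = lr_end + (lr_start - lr_end) * 0.5 * (1 + cos(pi * i / T))
# with T = epochs * iters_per_epoch, so the schedule starts at lr_start and
# anneals smoothly to lr_end over the whole run.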
25 |
26 | def set_lr(optimizer, lr):
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = lr
29 |
30 |
31 | def train(args, ddp_model):
32 | local_rank = args.local_rank
33 | print('Distributed Data Parallel Training IFRNet on Rank {}'.format(local_rank))
34 |
35 | if local_rank == 0:
36 | os.makedirs(args.log_path, exist_ok=True)
37 | log_path = os.path.join(args.log_path, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
38 | os.makedirs(log_path, exist_ok=True)
39 | logger = logging.getLogger()
40 | logger.setLevel('INFO')
41 | BASIC_FORMAT = '%(asctime)s:%(levelname)s:%(message)s'
42 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
43 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
44 | chlr = logging.StreamHandler()
45 | chlr.setFormatter(formatter)
46 | chlr.setLevel('INFO')
47 | fhlr = logging.FileHandler(os.path.join(log_path, 'train.log'))
48 | fhlr.setFormatter(formatter)
49 | logger.addHandler(chlr)
50 | logger.addHandler(fhlr)
51 | logger.info(args)
52 |
53 | dataset_train = GoPro_Train_Dataset(dataset_dir='/home/ltkong/Datasets/GOPRO', augment=True)
54 | sampler = DistributedSampler(dataset_train)
55 | dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, drop_last=True, sampler=sampler)
56 |     args.iters_per_epoch = len(dataloader_train)
57 | iters = args.resume_epoch * args.iters_per_epoch
58 |
59 | dataset_val = GoPro_Test_Dataset(dataset_dir='/home/ltkong/Datasets/GOPRO')
60 | dataloader_val = DataLoader(dataset_val, batch_size=2, num_workers=4, pin_memory=True, shuffle=False, drop_last=True)
61 |
62 | optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr_start, weight_decay=0)
63 |
64 | time_stamp = time.time()
65 | avg_rec = AverageMeter()
66 | avg_geo = AverageMeter()
67 | avg_dis = AverageMeter()
68 | best_psnr = 0.0
69 |
70 | for epoch in range(args.resume_epoch, args.epochs):
71 | sampler.set_epoch(epoch)
72 | for i, data in enumerate(dataloader_train):
73 | for l in range(len(data)):
74 | data[l] = data[l].to(args.device)
75 | img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7 = data
76 |
77 | img0 = torch.cat([img0, img0, img0, img0, img0, img0, img0], 0)
78 | img8 = torch.cat([img8, img8, img8, img8, img8, img8, img8], 0)
79 | imgt = torch.cat([img1, img2, img3, img4, img5, img6, img7], 0)
80 | embt = torch.cat([emb1, emb2, emb3, emb4, emb5, emb6, emb7], 0)
81 |
82 | data_time_interval = time.time() - time_stamp
83 | time_stamp = time.time()
84 |
85 | lr = get_lr(args, iters)
86 | set_lr(optimizer, lr)
87 |
88 | optimizer.zero_grad()
89 |
90 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img8, embt, imgt, None)
91 |
92 | loss = loss_rec + loss_geo + loss_dis
93 | loss.backward()
94 | optimizer.step()
95 |
96 | avg_rec.update(loss_rec.cpu().data)
97 | avg_geo.update(loss_geo.cpu().data)
98 | avg_dis.update(loss_dis.cpu().data)
99 | train_time_interval = time.time() - time_stamp
100 |
101 | if (iters+1) % 50 == 0 and local_rank == 0:
102 | logger.info('epoch:{}/{} iter:{}/{} time:{:.2f}+{:.2f} lr:{:.5e} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e}'.format(epoch+1, args.epochs, iters+1, args.epochs * args.iters_per_epoch, data_time_interval, train_time_interval, lr, avg_rec.avg, avg_geo.avg, avg_dis.avg))
103 | avg_rec.reset()
104 | avg_geo.reset()
105 | avg_dis.reset()
106 |
107 | iters += 1
108 | time_stamp = time.time()
109 |
110 | if (epoch+1) % args.eval_interval == 0 and local_rank == 0:
111 | psnr = evaluate(args, ddp_model, dataloader_val, epoch, logger)
112 | if psnr > best_psnr:
113 | best_psnr = psnr
114 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'best'))
115 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'latest'))
116 |
117 | dist.barrier()
118 |
119 |
120 | def evaluate(args, ddp_model, dataloader_val, epoch, logger):
121 | loss_rec_list = []
122 | loss_geo_list = []
123 | loss_dis_list = []
124 | psnr_list = []
125 | time_stamp = time.time()
126 | for i, data in enumerate(dataloader_val):
127 | for l in range(len(data)):
128 | data[l] = data[l].to(args.device)
129 | img0, img1, img2, img3, img4, img5, img6, img7, img8, emb1, emb2, emb3, emb4, emb5, emb6, emb7 = data
130 |
131 | img0 = torch.cat([img0, img0, img0, img0, img0, img0, img0], 0)
132 | img8 = torch.cat([img8, img8, img8, img8, img8, img8, img8], 0)
133 | imgt = torch.cat([img1, img2, img3, img4, img5, img6, img7], 0)
134 | embt = torch.cat([emb1, emb2, emb3, emb4, emb5, emb6, emb7], 0)
135 |
136 | with torch.no_grad():
137 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img8, embt, imgt, None)
138 |
139 | loss_rec_list.append(loss_rec.cpu().numpy())
140 | loss_geo_list.append(loss_geo.cpu().numpy())
141 | loss_dis_list.append(loss_dis.cpu().numpy())
142 |
143 | for j in range(img0.shape[0]):
144 | psnr = calculate_psnr(imgt_pred[j].unsqueeze(0), imgt[j].unsqueeze(0)).cpu().data
145 | psnr_list.append(psnr)
146 |
147 | eval_time_interval = time.time() - time_stamp
148 |
149 | logger.info('eval epoch:{}/{} time:{:.2f} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e} psnr:{:.3f}'.format(epoch+1, args.epochs, eval_time_interval, np.array(loss_rec_list).mean(), np.array(loss_geo_list).mean(), np.array(loss_dis_list).mean(), np.array(psnr_list).mean()))
150 | return np.array(psnr_list).mean()
151 |
152 |
153 |
154 | if __name__ == '__main__':
155 | parser = argparse.ArgumentParser(description='IFRNet')
156 | parser.add_argument('--model_name', default='IFRNet', type=str, help='IFRNet, IFRNet_L, IFRNet_S')
157 | parser.add_argument('--local_rank', default=-1, type=int)
158 | parser.add_argument('--world_size', default=4, type=int)
159 | parser.add_argument('--epochs', default=600, type=int)
160 | parser.add_argument('--eval_interval', default=8, type=int)
161 | parser.add_argument('--batch_size', default=2, type=int)
162 | parser.add_argument('--lr_start', default=1e-4, type=float)
163 | parser.add_argument('--lr_end', default=1e-5, type=float)
164 | parser.add_argument('--log_path', default='checkpoint', type=str)
165 | parser.add_argument('--resume_epoch', default=0, type=int)
166 | parser.add_argument('--resume_path', default=None, type=str)
167 | args = parser.parse_args()
168 |
169 |     dist.init_process_group(backend='nccl', world_size=args.world_size)  # rank is read from the RANK env var set by the launcher
170 | torch.cuda.set_device(args.local_rank)
171 | args.device = torch.device('cuda', args.local_rank)
172 |
173 | seed = 1234
174 | random.seed(seed)
175 | np.random.seed(seed)
176 | torch.manual_seed(seed)
177 | torch.cuda.manual_seed_all(seed)
178 | torch.backends.cudnn.benchmark = True
179 |
180 | if args.model_name == 'IFRNet':
181 | from models.IFRNet import Model
182 | elif args.model_name == 'IFRNet_L':
183 | from models.IFRNet_L import Model
184 | elif args.model_name == 'IFRNet_S':
185 | from models.IFRNet_S import Model
186 |
187 | args.log_path = args.log_path + '/' + args.model_name
188 | args.num_workers = args.batch_size * 4
189 |
190 | model = Model().to(args.device)
191 |
192 | if args.resume_epoch != 0:
193 | model.load_state_dict(torch.load(args.resume_path, map_location='cpu'))
194 |
195 | ddp_model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
196 |
197 | train(args, ddp_model)
198 |
199 | dist.destroy_process_group()
200 |
--------------------------------------------------------------------------------
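The script expects to be started by PyTorch's distributed launcher, which sets the rank environment variables and passes --local_rank. An illustrative invocation for a 4-GPU node (the GPU count and flag values are examples, not repo-mandated settings):

python -m torch.distributed.launch --nproc_per_node=4 train_gopro.py --world_size 4 --model_name IFRNet --epochs 600 --batch_size 2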
/train_vimeo90k.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import random
5 | import argparse
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torch.utils.data import DataLoader
11 | import torch.distributed as dist
12 | from torch.utils.data.distributed import DistributedSampler
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from datasets import Vimeo90K_Train_Dataset, Vimeo90K_Test_Dataset
15 | from metric import calculate_psnr, calculate_ssim
16 | from utils import AverageMeter
17 | import logging
18 |
19 |
20 | def get_lr(args, iters):
21 | ratio = 0.5 * (1.0 + np.cos(iters / (args.epochs * args.iters_per_epoch) * math.pi))
22 | lr = (args.lr_start - args.lr_end) * ratio + args.lr_end
23 | return lr
24 |
25 |
26 | def set_lr(optimizer, lr):
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = lr
29 |
30 |
31 | def train(args, ddp_model):
32 | local_rank = args.local_rank
33 | print('Distributed Data Parallel Training IFRNet on Rank {}'.format(local_rank))
34 |
35 | if local_rank == 0:
36 | os.makedirs(args.log_path, exist_ok=True)
37 | log_path = os.path.join(args.log_path, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
38 | os.makedirs(log_path, exist_ok=True)
39 | logger = logging.getLogger()
40 | logger.setLevel('INFO')
41 | BASIC_FORMAT = '%(asctime)s:%(levelname)s:%(message)s'
42 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
43 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
44 | chlr = logging.StreamHandler()
45 | chlr.setFormatter(formatter)
46 | chlr.setLevel('INFO')
47 | fhlr = logging.FileHandler(os.path.join(log_path, 'train.log'))
48 | fhlr.setFormatter(formatter)
49 | logger.addHandler(chlr)
50 | logger.addHandler(fhlr)
51 | logger.info(args)
52 |
53 | dataset_train = Vimeo90K_Train_Dataset(dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet', augment=True)
54 | sampler = DistributedSampler(dataset_train)
55 | dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, drop_last=True, sampler=sampler)
56 |     args.iters_per_epoch = len(dataloader_train)
57 | iters = args.resume_epoch * args.iters_per_epoch
58 |
59 | dataset_val = Vimeo90K_Test_Dataset(dataset_dir='/home/ltkong/Datasets/Vimeo90K/vimeo_triplet')
60 | dataloader_val = DataLoader(dataset_val, batch_size=16, num_workers=16, pin_memory=True, shuffle=False, drop_last=True)
61 |
62 | optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr_start, weight_decay=0)
63 |
64 | time_stamp = time.time()
65 | avg_rec = AverageMeter()
66 | avg_geo = AverageMeter()
67 | avg_dis = AverageMeter()
68 | best_psnr = 0.0
69 |
70 | for epoch in range(args.resume_epoch, args.epochs):
71 | sampler.set_epoch(epoch)
72 | for i, data in enumerate(dataloader_train):
73 | for l in range(len(data)):
74 | data[l] = data[l].to(args.device)
75 | img0, imgt, img1, flow, embt = data
76 |
77 | data_time_interval = time.time() - time_stamp
78 | time_stamp = time.time()
79 |
80 | lr = get_lr(args, iters)
81 | set_lr(optimizer, lr)
82 |
83 | optimizer.zero_grad()
84 |
85 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img1, embt, imgt, flow)
86 |
87 | loss = loss_rec + loss_geo + loss_dis
88 | loss.backward()
89 | optimizer.step()
90 |
91 | avg_rec.update(loss_rec.cpu().data)
92 | avg_geo.update(loss_geo.cpu().data)
93 | avg_dis.update(loss_dis.cpu().data)
94 | train_time_interval = time.time() - time_stamp
95 |
96 | if (iters+1) % 100 == 0 and local_rank == 0:
97 | logger.info('epoch:{}/{} iter:{}/{} time:{:.2f}+{:.2f} lr:{:.5e} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e}'.format(epoch+1, args.epochs, iters+1, args.epochs * args.iters_per_epoch, data_time_interval, train_time_interval, lr, avg_rec.avg, avg_geo.avg, avg_dis.avg))
98 | avg_rec.reset()
99 | avg_geo.reset()
100 | avg_dis.reset()
101 |
102 | iters += 1
103 | time_stamp = time.time()
104 |
105 | if (epoch+1) % args.eval_interval == 0 and local_rank == 0:
106 | psnr = evaluate(args, ddp_model, dataloader_val, epoch, logger)
107 | if psnr > best_psnr:
108 | best_psnr = psnr
109 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'best'))
110 | torch.save(ddp_model.module.state_dict(), '{}/{}_{}.pth'.format(log_path, args.model_name, 'latest'))
111 |
112 | dist.barrier()
113 |
114 |
115 | def evaluate(args, ddp_model, dataloader_val, epoch, logger):
116 | loss_rec_list = []
117 | loss_geo_list = []
118 | loss_dis_list = []
119 | psnr_list = []
120 | time_stamp = time.time()
121 | for i, data in enumerate(dataloader_val):
122 | for l in range(len(data)):
123 | data[l] = data[l].to(args.device)
124 | img0, imgt, img1, flow, embt = data
125 |
126 | with torch.no_grad():
127 | imgt_pred, loss_rec, loss_geo, loss_dis = ddp_model(img0, img1, embt, imgt, flow)
128 |
129 | loss_rec_list.append(loss_rec.cpu().numpy())
130 | loss_geo_list.append(loss_geo.cpu().numpy())
131 | loss_dis_list.append(loss_dis.cpu().numpy())
132 |
133 | for j in range(img0.shape[0]):
134 | psnr = calculate_psnr(imgt_pred[j].unsqueeze(0), imgt[j].unsqueeze(0)).cpu().data
135 | psnr_list.append(psnr)
136 |
137 | eval_time_interval = time.time() - time_stamp
138 |
139 | logger.info('eval epoch:{}/{} time:{:.2f} loss_rec:{:.4e} loss_geo:{:.4e} loss_dis:{:.4e} psnr:{:.3f}'.format(epoch+1, args.epochs, eval_time_interval, np.array(loss_rec_list).mean(), np.array(loss_geo_list).mean(), np.array(loss_dis_list).mean(), np.array(psnr_list).mean()))
140 | return np.array(psnr_list).mean()
141 |
142 |
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser(description='IFRNet')
146 | parser.add_argument('--model_name', default='IFRNet', type=str, help='IFRNet, IFRNet_L, IFRNet_S')
147 | parser.add_argument('--local_rank', default=-1, type=int)
148 | parser.add_argument('--world_size', default=4, type=int)
149 | parser.add_argument('--epochs', default=300, type=int)
150 | parser.add_argument('--eval_interval', default=1, type=int)
151 | parser.add_argument('--batch_size', default=6, type=int)
152 | parser.add_argument('--lr_start', default=1e-4, type=float)
153 | parser.add_argument('--lr_end', default=1e-5, type=float)
154 | parser.add_argument('--log_path', default='checkpoint', type=str)
155 | parser.add_argument('--resume_epoch', default=0, type=int)
156 | parser.add_argument('--resume_path', default=None, type=str)
157 | args = parser.parse_args()
158 |
159 | dist.init_process_group(backend='nccl', world_size=args.world_size)
160 | torch.cuda.set_device(args.local_rank)
161 | args.device = torch.device('cuda', args.local_rank)
162 |
163 | seed = 1234
164 | random.seed(seed)
165 | np.random.seed(seed)
166 | torch.manual_seed(seed)
167 | torch.cuda.manual_seed_all(seed)
168 | torch.backends.cudnn.benchmark = True
169 |
170 | if args.model_name == 'IFRNet':
171 | from models.IFRNet import Model
172 | elif args.model_name == 'IFRNet_L':
173 | from models.IFRNet_L import Model
174 | elif args.model_name == 'IFRNet_S':
175 | from models.IFRNet_S import Model
176 |
177 | args.log_path = args.log_path + '/' + args.model_name
178 | args.num_workers = args.batch_size
179 |
180 | model = Model().to(args.device)
181 |
182 | if args.resume_epoch != 0:
183 | model.load_state_dict(torch.load(args.resume_path, map_location='cpu'))
184 |
185 | ddp_model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
186 |
187 | train(args, ddp_model)
188 |
189 | dist.destroy_process_group()
190 |
--------------------------------------------------------------------------------
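Resuming follows the same launcher pattern; the checkpoint path below is hypothetical (checkpoints land under a timestamped subdirectory of checkpoint/<model_name>/):

python -m torch.distributed.launch --nproc_per_node=4 train_vimeo90k.py --world_size 4 --model_name IFRNet --resume_epoch 100 --resume_path checkpoint/IFRNet/IFRNet_latest.pth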
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import numpy as np
5 | from imageio import imread, imwrite
7 | from PIL import Image, ImageFile
8 | ImageFile.LOAD_TRUNCATED_IMAGES = True
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | def warp(img, flow):
15 | B, _, H, W = flow.shape
16 |     xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)  # normalized x coordinates in [-1, 1]
17 |     yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)  # normalized y coordinates in [-1, 1]
18 |     grid = torch.cat([xx, yy], 1).to(img)
19 |     flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)  # pixel displacements -> normalized units
20 |     grid_ = (grid + flow_).permute(0, 2, 3, 1)
21 |     output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)  # backward warping of img along flow
22 | return output
23 |
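# Sanity check (illustrative): a zero flow field leaves the image unchanged,
# since the sampling grid then reduces to the identity:
#   x = torch.rand(1, 3, 8, 8)
#   assert torch.allclose(warp(x, torch.zeros(1, 2, 8, 8)), x, atol=1e-6)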
24 |
25 | def get_robust_weight(flow_pred, flow_gt, beta):
26 | epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=1, keepdim=True) ** 0.5
27 | robust_weight = torch.exp(-beta * epe)
28 | return robust_weight
29 |
30 |
31 | class AverageMeter():
32 | def __init__(self):
33 | self.reset()
34 |
35 | def reset(self):
36 | self.val = 0
37 | self.avg = 0
38 | self.sum = 0
39 | self.count = 0
40 |
41 | def update(self, val, n=1):
42 | self.val = val
43 | self.sum += val * n
44 | self.count += n
45 | self.avg = self.sum / self.count
46 |
47 |
48 | def read(file):
49 | if file.endswith('.float3'): return readFloat(file)
50 | elif file.endswith('.flo'): return readFlow(file)
51 | elif file.endswith('.ppm'): return readImage(file)
52 | elif file.endswith('.pgm'): return readImage(file)
53 | elif file.endswith('.png'): return readImage(file)
54 | elif file.endswith('.jpg'): return readImage(file)
55 | elif file.endswith('.pfm'): return readPFM(file)[0]
56 | else: raise Exception('don\'t know how to read %s' % file)
57 |
58 |
59 | def write(file, data):
60 | if file.endswith('.float3'): return writeFloat(file, data)
61 | elif file.endswith('.flo'): return writeFlow(file, data)
62 | elif file.endswith('.ppm'): return writeImage(file, data)
63 | elif file.endswith('.pgm'): return writeImage(file, data)
64 | elif file.endswith('.png'): return writeImage(file, data)
65 | elif file.endswith('.jpg'): return writeImage(file, data)
66 | elif file.endswith('.pfm'): return writePFM(file, data)
67 | else: raise Exception('don\'t know how to write %s' % file)
68 |
69 |
70 | def readPFM(file):
71 | file = open(file, 'rb')
72 |
73 | color = None
74 | width = None
75 | height = None
76 | scale = None
77 | endian = None
78 |
79 | header = file.readline().rstrip()
80 | if header.decode("ascii") == 'PF':
81 | color = True
82 | elif header.decode("ascii") == 'Pf':
83 | color = False
84 | else:
85 | raise Exception('Not a PFM file.')
86 |
87 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
88 | if dim_match:
89 | width, height = list(map(int, dim_match.groups()))
90 | else:
91 | raise Exception('Malformed PFM header.')
92 |
93 | scale = float(file.readline().decode("ascii").rstrip())
94 | if scale < 0:
95 | endian = '<'
96 | scale = -scale
97 | else:
98 | endian = '>'
99 |
100 | data = np.fromfile(file, endian + 'f')
101 | shape = (height, width, 3) if color else (height, width)
102 |
103 | data = np.reshape(data, shape)
104 | data = np.flipud(data)
105 | return data, scale
106 |
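# PFM stores rows bottom-to-top (hence the flipud above) and encodes byte
# order in the sign of the scale line: negative means little-endian.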
107 |
108 | def writePFM(file, image, scale=1):
109 | file = open(file, 'wb')
110 |
111 | color = None
112 |
113 | if image.dtype.name != 'float32':
114 | raise Exception('Image dtype must be float32.')
115 |
116 | image = np.flipud(image)
117 |
118 | if len(image.shape) == 3 and image.shape[2] == 3:
119 | color = True
120 |     elif len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1):
121 | color = False
122 | else:
123 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
124 |
125 |     file.write(('PF\n' if color else 'Pf\n').encode())  # encode the whole header line ('PF' = color, 'Pf' = grayscale)
126 | file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
127 |
128 | endian = image.dtype.byteorder
129 |
130 |     if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
131 | scale = -scale
132 |
133 | file.write('%f\n'.encode() % scale)
134 |
135 | image.tofile(file)
136 |
137 |
138 | def readFlow(name):
139 | if name.endswith('.pfm') or name.endswith('.PFM'):
140 | return readPFM(name)[0][:,:,0:2]
141 |
142 | f = open(name, 'rb')
143 |
144 | header = f.read(4)
145 | if header.decode("utf-8") != 'PIEH':
146 | raise Exception('Flow file header does not contain PIEH')
147 |
148 | width = np.fromfile(f, np.int32, 1).squeeze()
149 | height = np.fromfile(f, np.int32, 1).squeeze()
150 |
151 | flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
152 |
153 | return flow.astype(np.float32)
154 |
155 |
156 | def readImage(name):
157 | if name.endswith('.pfm') or name.endswith('.PFM'):
158 | data = readPFM(name)[0]
159 | if len(data.shape)==3:
160 | return data[:,:,0:3]
161 | else:
162 | return data
163 | return imread(name)
164 |
165 |
166 | def writeImage(name, data):
167 | if name.endswith('.pfm') or name.endswith('.PFM'):
168 | return writePFM(name, data, 1)
169 | return imwrite(name, data)
170 |
171 |
172 | def writeFlow(name, flow):
173 | f = open(name, 'wb')
174 | f.write('PIEH'.encode('utf-8'))
175 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
176 | flow = flow.astype(np.float32)
177 | flow.tofile(f)
178 |
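# Round trip (illustrative): a .flo file is the 'PIEH' magic, int32 width and
# height, then row-major float32 data, so write followed by read is lossless:
#   flow = np.random.rand(64, 64, 2).astype(np.float32)
#   writeFlow('/tmp/test.flo', flow)
#   assert np.allclose(readFlow('/tmp/test.flo'), flow)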
179 |
180 | def readFloat(name):
181 | f = open(name, 'rb')
182 |
183 |     if f.readline().decode("utf-8") != 'float\n':
184 | raise Exception('float file %s did not contain keyword' % name)
185 |
186 | dim = int(f.readline())
187 |
188 | dims = []
189 | count = 1
190 | for i in range(0, dim):
191 | d = int(f.readline())
192 | dims.append(d)
193 | count *= d
194 |
195 | dims = list(reversed(dims))
196 |
197 | data = np.fromfile(f, np.float32, count).reshape(dims)
198 | if dim > 2:
199 | data = np.transpose(data, (2, 1, 0))
200 | data = np.transpose(data, (1, 0, 2))
201 |
202 | return data
203 |
204 |
205 | def writeFloat(name, data):
206 | f = open(name, 'wb')
207 |
208 | dim=len(data.shape)
209 | if dim>3:
210 | raise Exception('bad float file dimension: %d' % dim)
211 |
212 | f.write(('float\n').encode('ascii'))
213 | f.write(('%d\n' % dim).encode('ascii'))
214 |
215 | if dim == 1:
216 | f.write(('%d\n' % data.shape[0]).encode('ascii'))
217 | else:
218 | f.write(('%d\n' % data.shape[1]).encode('ascii'))
219 | f.write(('%d\n' % data.shape[0]).encode('ascii'))
220 | for i in range(2, dim):
221 | f.write(('%d\n' % data.shape[i]).encode('ascii'))
222 |
223 | data = data.astype(np.float32)
224 | if dim==2:
225 | data.tofile(f)
226 |
227 | else:
228 | np.transpose(data, (2, 0, 1)).tofile(f)
229 |
--------------------------------------------------------------------------------