├── .gitignore
├── AverageMeter.py
├── LICENSE
├── README.md
├── data
│   ├── Vimeo-90K.py
│   ├── adobe240fps
│   │   ├── test_list.txt
│   │   └── train_list.txt
│   ├── create_dataset.py
│   ├── figures
│   │   ├── SYA_1.png
│   │   ├── SYA_2.png
│   │   ├── check_all.png
│   │   ├── shana.png
│   │   ├── tianqi_all.png
│   │   ├── video.png
│   │   └── youtube.png
│   └── video_transformer.py
├── eval_Vimeo90K.py
├── models
│   ├── DCNv2
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── dcn_v2.py
│   │   ├── make.sh
│   │   ├── setup.py
│   │   └── src
│   │       ├── cpu
│   │       │   ├── dcn_v2_cpu.cpp
│   │       │   └── vision.h
│   │       ├── cuda
│   │       │   ├── dcn_v2_cuda.cu
│   │       │   ├── dcn_v2_im2col_cuda.cu
│   │       │   ├── dcn_v2_im2col_cuda.h
│   │       │   ├── dcn_v2_psroi_pooling_cuda.cu
│   │       │   └── vision.h
│   │       ├── dcn_v2.h
│   │       └── vision.cpp
│   ├── ResBlock.py
│   ├── Unet.py
│   ├── __init__.py
│   ├── bdcn
│   │   ├── __init__.py
│   │   ├── bdcn.py
│   │   └── vgg16_c.py
│   └── warp.py
├── paper
│   ├── FeatureFlow.pdf
│   └── Supp.pdf
├── pure_run.py
├── requirements.txt
├── run.py
├── sequence_run.py
├── src
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dataloader.cpython-37.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   └── pure_network.cpython-37.pyc
│   ├── dataloader.py
│   ├── eval.py
│   ├── layers.py
│   ├── loss.py
│   ├── model.py
│   └── pure_network.py
├── train.py
├── utils
│   └── visualize.py
└── video_process.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/**
2 | log/**
3 | checkpoints/**
4 | models/bdcn/final-model/**
5 |
--------------------------------------------------------------------------------
/AverageMeter.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class AverageMeter(object):
4 | """Computes and stores the average and current value"""
5 | def __init__(self):
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = 0
10 | self.avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += val * n
17 | self.count += n
18 | self.avg = self.sum / self.count
19 |
--------------------------------------------------------------------------------
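A minimal usage sketch for the `AverageMeter` above (hypothetical values): `val` keeps the latest measurement while `avg` keeps the running, count-weighted mean.

```python
from AverageMeter import AverageMeter

psnr_meter = AverageMeter()
psnr_meter.update(30.0, n=4)   # e.g. a batch of 4 samples with mean PSNR 30.0
psnr_meter.update(33.0, n=2)   # a smaller batch with mean PSNR 33.0
print(psnr_meter.val)          # 33.0 -- latest value
print(psnr_meter.avg)          # 31.0 -- (30.0*4 + 33.0*2) / 6
```
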
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Citrine
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FeatureFlow
2 |
3 | [Paper](https://github.com/CM-BF/FeatureFlow/blob/master/paper/FeatureFlow.pdf) | [Supp](https://github.com/CM-BF/FeatureFlow/blob/master/paper/Supp.pdf)
4 |
5 | A state-of-the-art video frame interpolation method based on blending deep semantic flows.
6 |
7 | FeatureFlow: Robust Video Interpolation via Structure-to-texture Generation (IEEE Conference on Computer Vision and Pattern Recognition 2020)
8 |
9 | ## To Do List
10 | - [x] Preprint
11 | - [x] Training code
12 |
13 | ## Table of Contents
14 |
15 | 1. [Requirements](#requirements)
16 | 1. [Demos](#video-demos)
17 | 1. [Installation](#installation)
18 | 1. [Pre-trained Model](#pre-trained-model)
19 | 1. [Download Results](#download-results)
20 | 1. [Evaluation](#evaluation)
21 | 1. [Test your video](#test-your-video)
22 | 1. [Training](#training)
23 | 1. [Citation](#citation)
24 |
25 | ## Requirements
26 |
27 | * Ubuntu
28 | * PyTorch (>=1.1)
29 | * Cuda (>=10.0) & Cudnn (>=7.0)
30 | * mmdet 1.0rc (from https://github.com/open-mmlab/mmdetection.git)
31 | * visdom (not necessary)
32 | * NVIDIA GPU
33 |
34 | P.S.: A `requirements.txt` is provided, but do not install from it directly; it is for reference only, because it also contains another project's dependencies.
35 |
36 | ## Video demos
37 |
38 | Click a picture to download the corresponding demo, or click [Here(Google)](https://drive.google.com/open?id=1QUYoplBNjaWXJZPO90NiwQwqQz7yF7TX) or [Here(Baidu)](https://pan.baidu.com/s/1J9seoqgC2p9zZ7pegMlH1A) (key: oav2) to download the **360p demos**.
39 |
40 | **360p demos**(including comparisons):
41 |
42 | [![youtube](data/figures/youtube.png)](https://github.com/CM-BF/storage/tree/master/videos/youtube.mp4 "video1")
43 | [![check_all](data/figures/check_all.png)](https://github.com/CM-BF/storage/tree/master/videos/check_all.mp4 "video2")
44 | [![tianqi_all](data/figures/tianqi_all.png)](https://github.com/CM-BF/storage/tree/master/videos/tianqi_all.mp4 "video3")
45 | [![video](data/figures/video.png)](https://github.com/CM-BF/storage/tree/master/videos/video.mp4 "video4")
46 | [![shana](data/figures/shana.png)](https://github.com/CM-BF/storage/tree/master/videos/shana.mp4 "video5")
47 |
48 | **720p demos**:
49 |
50 | [![SYA_1](data/figures/SYA_1.png)](https://github.com/CM-BF/storage/tree/master/videos/SYA_1.mp4 "video6")
51 | [![SYA_2](data/figures/SYA_2.png)](https://github.com/CM-BF/storage/tree/master/videos/SYA_2.mp4 "video7")
52 |
53 | ## Installation
54 | * clone this repo
55 | * git clone https://github.com/open-mmlab/mmdetection.git
56 | * install mmdetection: please follow the guidance in its GitHub repository
57 | ```bash
58 | $ cd mmdetection
59 | $ pip install -r requirements/build.txt
60 | $ pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
61 | $ pip install -v -e . # or "python setup.py develop"
62 | $ pip list | grep mmdet
63 | ```
64 | * Download [test set](http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip)
65 | ```bash
66 | $ unzip vimeo_interp_test.zip
67 | $ cd vimeo_interp_test
68 | $ mkdir sequences
69 | $ cp target/* sequences/ -r
70 | $ cp input/* sequences/ -r
71 | ```
72 | * Download BDCN's pre-trained model bdcn_pretrained_on_bsds500.pth to ./models/bdcn/final-model/
73 |
74 | P.S.: For your convenience, you can download only bdcn_pretrained_on_bsds500.pth: [Google Drive](https://drive.google.com/file/d/1Zot0P8pSBawnehTE32T-8J0rjYq6CSAT/view?usp=sharing), or all of the pre-trained BDCN models provided by its authors: [Google Drive](https://drive.google.com/file/d/1CmDMypSlLM6EAvOt5yjwUQ7O5w-xCm1n/view). For a Baidu Cloud link, please refer to BDCN's GitHub repository.
75 |
76 | ```
77 | $ pip install scikit-image visdom tqdm prefetch-generator
78 | ```
79 |
80 | ## Pre-trained Model
81 |
82 | [Google Drive](https://drive.google.com/open?id=1S8C0chFV6Bip6W9lJdZkog0T3xiNxbEx)
83 |
84 | [Baidu Cloud](https://pan.baidu.com/s/1LxVw-89f3GX5r0mZ6wmsJw): ae4x
85 |
86 | Place FeFlow.ckpt in ./checkpoints/.
87 |
88 | ## Download Results
89 |
90 | [Google Drive](https://drive.google.com/open?id=1OtrExUiyIBJe0D6_ZwDfztqJBqji4lmt)
91 |
92 | [Baidu Cloud](https://pan.baidu.com/s/1BaJJ82nSKagly6XZ8KNtAw): pc0k
93 |
94 | ## Evaluation
95 | ```bash
96 | $ CUDA_VISIBLE_DEVICES=0 python eval_Vimeo90K.py --checkpoint ./checkpoints/FeFlow.ckpt --dataset_root ~/datasets/videos/vimeo_interp_test --visdom_env test --vimeo90k --imgpath ./results/
97 | ```
98 |
99 | ## Test your video
100 | ```bash
101 | $ CUDA_VISIBLE_DEVICES=0 python sequence_run.py --checkpoint checkpoints/FeFlow.ckpt --video_path ./yourvideo.mp4 --t_interp 4 --slow_motion
102 | ```
103 | `--t_interp` sets the frame-rate multiple; only powers of 2 (2, 4, 8, ...) are supported. Use the `--slow_motion` flag to slow the video down by keeping the original fps instead of increasing it.
104 |
105 | The output video will be saved as output.mp4 in your working directory.
106 |
107 | ## Training
108 |
109 | Training code **train.py** is available now. I can't run it for confirmation at the moment because I've left the lab, but I'm confident it will work with the right argument settings.
110 |
111 | ```bash
112 | $ CUDA_VISIBLE_DEVICES=0,1 python train.py
113 | ```
114 |
115 | * Please read the **arguments' help** carefully to take full control of the **two-step training**.
116 | * Pay attention to `--GEN_DE`, the flag that switches the model between Stage-I and Stage-II.
117 | * 2 GPUs are necessary for training; otherwise the small batch_size will cause the training process to crash.
118 | * Deformable convolution is not entirely stable, so you may occasionally encounter a training crash (the random seed is not fixed), but a crash can be detected soon after training starts by visualizing the results with Visdom.
119 | * Visdom visualization code [lines 75, 201-216 and 338-353] is included, which is useful for monitoring the training process and spotting crashes.
120 |
121 | ## Citation
122 | ```
123 | @InProceedings{Gui_2020_CVPR,
124 | author = {Gui, Shurui and Wang, Chaoyue and Chen, Qihua and Tao, Dacheng},
125 | title = {FeatureFlow: Robust Video Interpolation via Structure-to-Texture Generation},
126 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
127 | month = {June},
128 | year = {2020}
129 | }
130 | ```
131 |
132 | ## Contact
133 | [Shurui Gui](mailto:citrinegui@gmail.com); [Chaoyue Wang](mailto:chaoyue.wang@sydney.edu.au)
134 |
135 | ## License
136 | See [MIT License](https://github.com/CM-BF/FeatureFlow/blob/master/LICENSE)
137 |
--------------------------------------------------------------------------------
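A note on the `--t_interp` option documented in the README above: since each pass of the model synthesizes one middle frame between two inputs, higher multiples are presumably obtained by interpolating recursively, which is consistent with only powers of 2 being supported. A minimal sketch of that recursion, where `interp_midpoint(a, b)` is a hypothetical stand-in for one forward pass of the network:

```python
def interp_sequence(frame0, frame1, t_interp, interp_midpoint):
    """Return the (t_interp - 1) intermediate frames between frame0 and frame1.

    t_interp must be a power of 2; interp_midpoint(a, b) is a hypothetical
    callable returning the single middle frame between a and b.
    """
    assert t_interp >= 2 and (t_interp & (t_interp - 1)) == 0, 't_interp must be a power of 2'
    mid = interp_midpoint(frame0, frame1)
    if t_interp == 2:
        return [mid]
    # recurse on both halves, then stitch: left frames, midpoint, right frames
    left = interp_sequence(frame0, mid, t_interp // 2, interp_midpoint)
    right = interp_sequence(mid, frame1, t_interp // 2, interp_midpoint)
    return left + [mid] + right
```
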
/data/Vimeo-90K.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--dataset', type=str, required=True, help='path to Vimeo-90K dataset')
6 | parser.add_argument('--out', type=str, required=True, help='path to output dataset place')
7 | args = parser.parse_args()
8 |
9 | def main():
10 | train_out_path = args.out + 'train'
11 | test_out_path = args.out + 'test'
12 | validation_out_path = args.out + 'validation'
13 | os.mkdir(train_out_path)
14 | os.mkdir(test_out_path)
15 | os.mkdir(validation_out_path)
16 |
17 | with open(args.dataset + '/tri_trainlist.txt', 'r') as f:
18 | train_paths = f.read().split('\n')
19 | with open(args.dataset + '/tri_testlist.txt', 'r') as f:  # a second read() on the exhausted train-list handle would return '', so read the test list separately
20 | test_paths = f.read().split('\n')
21 |
22 | if __name__ == '__main__':
23 | main()
24 |
--------------------------------------------------------------------------------
/data/adobe240fps/test_list.txt:
--------------------------------------------------------------------------------
1 | 720p_240fps_1.mov
2 | GOPR9635.mp4
3 | GOPR9637a.mp4
4 | IMG_0004a.mov
5 | IMG_0015.mov
6 | IMG_0023.mov
7 | IMG_0179.m4v
8 | IMG_0183.MOV
--------------------------------------------------------------------------------
/data/adobe240fps/train_list.txt:
--------------------------------------------------------------------------------
1 | 720p_240fps_2.mov
2 | 720p_240fps_3.mov
3 | 720p_240fps_5.mov
4 | 720p_240fps_6.mov
5 | GOPR9633.mp4
6 | GOPR9634.mp4
7 | GOPR9636.mp4
8 | GOPR9637b.mp4
9 | GOPR9638.mp4
10 | GOPR9639.mp4
11 | GOPR9640.mp4
12 | GOPR9641.mp4
13 | GOPR9642.mp4
14 | GOPR9643.mp4
15 | GOPR9644.mp4
16 | GOPR9645.mp4
17 | GOPR9646.mp4
18 | GOPR9647.mp4
19 | GOPR9648.mp4
20 | GOPR9649.mp4
21 | GOPR9650.mp4
22 | GOPR9651.mp4
23 | GOPR9652.mp4
24 | GOPR9653.mp4
25 | GOPR9654a.mp4
26 | GOPR9654b.mp4
27 | GOPR9655a.mp4
28 | GOPR9655b.mp4
29 | GOPR9656.mp4
30 | GOPR9657.mp4
31 | GOPR9658.MP4
32 | GOPR9659.mp4
33 | GOPR9660.MP4
34 | IMG_0001.mov
35 | IMG_0002.mov
36 | IMG_0003.mov
37 | IMG_0004b.mov
38 | IMG_0005.mov
39 | IMG_0006.mov
40 | IMG_0007.mov
41 | IMG_0008.mov
42 | IMG_0009.mov
43 | IMG_0010.mov
44 | IMG_0011.mov
45 | IMG_0012.mov
46 | IMG_0013.mov
47 | IMG_0014.mov
48 | IMG_0016.mov
49 | IMG_0017.mov
50 | IMG_0018.mov
51 | IMG_0019.mov
52 | IMG_0020.mov
53 | IMG_0021.mov
54 | IMG_0022.mov
55 | IMG_0024.mov
56 | IMG_0025.mov
57 | IMG_0026.mov
58 | IMG_0028.mov
59 | IMG_0029.mov
60 | IMG_0030.mov
61 | IMG_0031.mov
62 | IMG_0032.mov
63 | IMG_0033.mov
64 | IMG_0034.mov
65 | IMG_0034a.mov
66 | IMG_0035.mov
67 | IMG_0036.mov
68 | IMG_0037.mov
69 | IMG_0037a.mov
70 | IMG_0038.mov
71 | IMG_0039.mov
72 | IMG_0040.mov
73 | IMG_0041.mov
74 | IMG_0042.mov
75 | IMG_0043.mov
76 | IMG_0044.mov
77 | IMG_0045.mov
78 | IMG_0046.mov
79 | IMG_0047.mov
80 | IMG_0052.mov
81 | IMG_0054a.mov
82 | IMG_0054b.mov
83 | IMG_0055.mov
84 | IMG_0056.mov
85 | IMG_0058.mov
86 | IMG_0150.m4v
87 | IMG_0151.m4v
88 | IMG_0152.m4v
89 | IMG_0153.m4v
90 | IMG_0154.m4v
91 | IMG_0155.m4v
92 | IMG_0156.m4v
93 | IMG_0157.m4v
94 | IMG_0160.m4v
95 | IMG_0161.m4v
96 | IMG_0162.m4v
97 | IMG_0163.m4v
98 | IMG_0164.m4v
99 | IMG_0167.m4v
100 | IMG_0169.m4v
101 | IMG_0170.m4v
102 | IMG_0171.m4v
103 | IMG_0172.m4v
104 | IMG_0173.m4v
105 | IMG_0174.m4v
106 | IMG_0175.m4v
107 | IMG_0176.m4v
108 | IMG_0177.m4v
109 | IMG_0178.m4v
110 | IMG_0180.m4v
111 | IMG_0200.MOV
112 | IMG_0212.MOV
--------------------------------------------------------------------------------
/data/create_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path
4 | from shutil import rmtree, move
5 | import random
6 |
7 | # For parsing commandline arguments
8 | parser = argparse.ArgumentParser()
9 | # parser.add_argument("--ffmpeg_dir", type=str, required=True, help='path to ffmpeg.exe')
10 | parser.add_argument("--dataset", type=str, default="custom", help='specify if using "adobe240fps" or custom video dataset')
11 | parser.add_argument("--videos_folder", type=str, required=True, help='path to the folder containing videos')
12 | parser.add_argument("--dataset_folder", type=str, required=True, help='path to the output dataset folder')
13 | parser.add_argument("--img_width", type=int, default=640, help="output image width")
14 | parser.add_argument("--img_height", type=int, default=360, help="output image height")
15 | parser.add_argument("--train_test_split", type=tuple, default=(90, 10), help="train test split for custom dataset")
16 | args = parser.parse_args()
17 |
18 |
19 | def extract_frames(videos, inDir, outDir):
20 | """
21 | Converts all the videos passed in `videos` list to images.
22 |
23 | Parameters
24 | ----------
25 | videos : list
26 | name of all video files.
27 | inDir : string
28 | path to input directory containing videos in `videos` list.
29 | outDir : string
30 | path to directory to output the extracted images.
31 |
32 | Returns
33 | -------
34 | None
35 | """
36 |
37 |
38 | for video in videos:
39 | os.mkdir(os.path.join(outDir, os.path.splitext(video)[0]))
40 | retn = os.system('{} -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.png'.format('ffmpeg', os.path.join(inDir, video), args.img_width, args.img_height, os.path.join(outDir, os.path.splitext(video)[0])))
41 | if retn:
42 | print("Error converting file:{}. Exiting.".format(video))
43 |
44 |
45 | def create_clips(root, destination):
46 | """
47 | Distributes the images extracted by `extract_frames()` in
48 | clips containing 12 frames each.
49 |
50 | Parameters
51 | ----------
52 | root : string
53 | path containing extracted image folders.
54 | destination : string
55 | path to output clips.
56 |
57 | Returns
58 | -------
59 | None
60 | """
61 |
62 |
63 | folderCounter = -1
64 |
65 | files = os.listdir(root)
66 |
67 | # Iterate over each folder containing extracted video frames.
68 | for file in files:
69 | images = sorted(os.listdir(os.path.join(root, file)))
70 |
71 | for imageCounter, image in enumerate(images):
72 | # Bunch images in groups of 12 frames
73 | if (imageCounter % 12 == 0):
74 | if (imageCounter + 11 >= len(images)):
75 | break
76 | folderCounter += 1
77 | os.mkdir("{}/{}".format(destination, folderCounter))
78 | move("{}/{}/{}".format(root, file, image), "{}/{}/{}".format(destination, folderCounter, image))
79 | rmtree(os.path.join(root, file))
80 |
81 | def main():
82 | # Create dataset folder if it doesn't exist already.
83 | if not os.path.isdir(args.dataset_folder):
84 | os.mkdir(args.dataset_folder)
85 |
86 | extractPath = os.path.join(args.dataset_folder, "extracted")
87 | trainPath = os.path.join(args.dataset_folder, "train")
88 | testPath = os.path.join(args.dataset_folder, "test")
89 | validationPath = os.path.join(args.dataset_folder, "validation")
90 | os.mkdir(extractPath)
91 | os.mkdir(trainPath)
92 | os.mkdir(testPath)
93 | os.mkdir(validationPath)
94 |
95 | if(args.dataset == "adobe240fps"):
96 | f = open("data/adobe240fps/test_list.txt", "r")
97 | videos = f.read().split('\n')
98 | extract_frames(videos, args.videos_folder, extractPath)
99 | create_clips(extractPath, testPath)
100 |
101 | f = open("data/adobe240fps/train_list.txt", "r")
102 | videos = f.read().split('\n')
103 | extract_frames(videos, args.videos_folder, extractPath)
104 | create_clips(extractPath, trainPath)
105 |
106 | # Select 100 clips at random from test set for validation set.
107 | testClips = os.listdir(testPath)
108 | indices = random.sample(range(len(testClips)), 100)
109 | for index in indices:
110 | move("{}/{}".format(testPath, index), "{}/{}".format(validationPath, index))
111 |
112 | else: # custom dataset
113 |
114 | # Extract video names
115 | videos = os.listdir(args.videos_folder)
116 |
117 | # Create random train-test split.
118 | testIndices = random.sample(range(len(videos)), int((args.train_test_split[1] * len(videos)) / 100))
119 | trainIndices = [x for x in range((len(videos))) if x not in testIndices]
120 |
121 | # Create list of video names
122 | testVideoNames = [videos[index] for index in testIndices]
123 | trainVideoNames = [videos[index] for index in trainIndices]
124 |
125 | # Create train-test dataset
126 | extract_frames(testVideoNames, args.videos_folder, extractPath)
127 | create_clips(extractPath, testPath)
128 | extract_frames(trainVideoNames, args.videos_folder, extractPath)
129 | create_clips(extractPath, trainPath)
130 |
131 | # Select clips at random from test set for validation set.
132 | testClips = os.listdir(testPath)
133 | indices = random.sample(range(len(testClips)), min(100, int(len(testClips) / 5)))
134 | for index in indices:
135 | move("{}/{}".format(testPath, index), "{}/{}".format(validationPath, index))
136 |
137 | rmtree(extractPath)
138 |
139 | main()
140 |
--------------------------------------------------------------------------------
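A hedged invocation example for `create_dataset.py` above, using only the arguments it declares (paths are placeholders); run it from the repository root so the `data/adobe240fps/*_list.txt` files resolve:

```bash
# Extract 640x360 frames with ffmpeg and group them into 12-frame clips
python data/create_dataset.py --dataset adobe240fps \
    --videos_folder ~/datasets/adobe240fps/videos \
    --dataset_folder ~/datasets/adobe240fps_clips
```
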
/data/figures/SYA_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/SYA_1.png
--------------------------------------------------------------------------------
/data/figures/SYA_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/SYA_2.png
--------------------------------------------------------------------------------
/data/figures/check_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/check_all.png
--------------------------------------------------------------------------------
/data/figures/shana.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/shana.png
--------------------------------------------------------------------------------
/data/figures/tianqi_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/tianqi_all.png
--------------------------------------------------------------------------------
/data/figures/video.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/video.png
--------------------------------------------------------------------------------
/data/figures/youtube.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/youtube.png
--------------------------------------------------------------------------------
/data/video_transformer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import argparse
5 | import shutil
6 |
7 | import multiprocessing
8 | from tqdm import tqdm
9 |
10 | # command line parser
11 | parser = argparse.ArgumentParser()
12 |
13 | parser.add_argument('--videos_folder', type=str, required=True, help='the path to video dataset folder.')
14 | parser.add_argument('--output_folder', type=str, default='../pre_dataset/', help='the path to output dataset folder.')
15 | parser.add_argument('--lower_rate', type=int, default=5, help='lower the video fps by n times.')
16 | args = parser.parse_args()
17 |
18 |
19 | class DataCreator(object):
20 |
21 | def __init__(self):
22 |
23 | self.videos_folder = args.videos_folder
24 | self.output_folder = args.output_folder
25 | self.lower_rate = args.lower_rate
26 | self.tmp = '../.tmp/'
27 | try:
28 | os.mkdir(self.tmp)
29 | except:
30 | pass
31 |
32 | def _listener(self, pbar, q):
33 | for item in iter(q.get, None):
34 | pbar.update(1)
35 |
36 | def _lower_fps(self, p_args):
37 | video_name, q = p_args
38 | # pbar.set_description("Processing %s" % video_name)
39 |
40 | # read a video and create video_writer for lower fps video output
41 | video = cv2.VideoCapture(os.path.join(self.videos_folder, video_name))
42 | fps = video.get(cv2.CAP_PROP_FPS)
43 | size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
44 | # fourcc = cv2.VideoWriter_fourcc(*'XVID')
45 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
46 | video_writer = [cv2.VideoWriter(self.tmp + video_name[:-4] + '_%s' % str(i) + '.mp4',
47 | fourcc,
48 | fps / self.lower_rate,
49 | size)
50 | for i in range(self.lower_rate)]
51 |
52 | count = 0
53 | while video.isOpened():
54 | ret, frame = video.read()
55 | if ret:
56 | video_writer[count % self.lower_rate].write(frame)
57 | if cv2.waitKey(1) & 0xFF == ord('q'):
58 | raise KeyboardInterrupt
59 | else:
60 | break
61 | count += 1
62 |
63 | for i in range(self.lower_rate):
64 | video_writer[i].release()
65 |
66 | q.put(1)
67 |
68 |
69 | def lower_fps(self):
70 |
71 | videos_name = os.listdir(self.videos_folder)
72 | pbar = tqdm(total=len(videos_name))
73 | m = multiprocessing.Manager()
74 | q = m.Queue()
75 |
76 | listener = multiprocessing.Process(target=self._listener, args=(pbar, q))
77 | listener.start()
78 |
79 | p_args = [(video_name, q) for video_name in videos_name]
80 | pool = multiprocessing.Pool()
81 | pool.map(self._lower_fps, p_args)
82 |
83 | pool.close()
84 | pool.join()
85 | q.put(None)
86 | listener.join()
87 |
88 | def output(self):
89 | os.system('mkdir %s' % self.output_folder)
90 | os.system('cp %s %s' % (self.tmp + '*', self.output_folder))
91 | os.system('rm -rf %s' % self.tmp)
92 |
93 |
94 | if __name__ == '__main__':
95 | data_creator = DataCreator()
96 | data_creator.lower_fps()
97 | data_creator.output()
98 |
99 |
--------------------------------------------------------------------------------
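A hedged invocation example for `video_transformer.py` above, again using only its declared arguments; the 240fps-to-30fps numbers are illustrative:

```bash
# Split every 240fps video into 8 interleaved 30fps videos written to ../pre_dataset/
python data/video_transformer.py --videos_folder ~/datasets/raw_240fps_videos \
    --output_folder ../pre_dataset/ --lower_rate 8
```
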
/eval_Vimeo90K.py:
--------------------------------------------------------------------------------
1 | # SeDraw
2 | import argparse
3 | from skimage.measure import compare_psnr, compare_ssim
4 | import numpy as np
5 | import os
6 | import torch
7 | import cv2
8 | import torchvision.transforms as transforms
9 | import torch.optim as optim
10 | import torch.nn as nn
11 | import time
12 | import src.dataloader as dataloader
13 | import src.layers as layers
14 | from math import log10
15 | import datetime
16 | from tqdm import tqdm
17 | from prefetch_generator import BackgroundGenerator
18 | import visdom
19 | from utils.visualize import feature_transform
20 | import numpy as np
21 | import models.bdcn.bdcn as bdcn
22 |
23 | # For parsing commandline arguments
24 | def str2bool(v):
25 | if isinstance(v, bool):
26 | return v
27 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
28 | return True
29 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
30 | return False
31 | else:
32 | raise argparse.ArgumentTypeError('Boolean value expected.')
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--dataset_root", type=str, required=True,
36 | help='path to dataset folder containing train-test-validation folders')
37 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model')
38 | parser.add_argument("--test_batch_size", type=int, default=1, help='batch size for training. Default: 6.')
39 | parser.add_argument('--visdom_env', type=str, default='SeDraw_0', help='Environment for visdom show')
40 | parser.add_argument('--vimeo90k', action='store_true', help='use this flag if using Vimeo-90K dataset')
41 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3')
42 | parser.add_argument('--bdcn_model', default='./models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth')
43 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.')
44 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint')
45 | parser.add_argument('--imgpath', type=str, required=True)
46 | args = parser.parse_args()
47 |
48 |
49 |
50 | # --For visualizing loss and interpolated frames--
51 |
52 |
53 | # Visdom for real-time visualizing
54 | vis = visdom.Visdom(env=args.visdom_env, port=8098)
55 |
56 | # device
57 | device_count = torch.cuda.device_count()
58 |
59 | # --Initialize network--
60 | bdcn = bdcn.BDCN()
61 | bdcn.cuda()
62 | structure_gen = layers.StructureGen(feature_level=args.feature_level)
63 | structure_gen.cuda()
64 | detail_enhance = layers.DetailEnhance()
65 | detail_enhance.cuda()
66 |
67 |
68 | # --Load Datasets--
69 |
70 |
71 | # Channel wise mean calculated on adobe240-fps training dataset
72 | mean = [0.5, 0.5, 0.5]
73 | std = [0.5, 0.5, 0.5]
74 | normalize = transforms.Normalize(mean=mean,
75 | std=std)
76 | transform = transforms.Compose([transforms.ToTensor(), normalize])
77 |
78 | if args.vimeo90k:
79 | testset = dataloader.SeDraw_vimeo90k(root=args.dataset_root, transform=transform,
80 | randomCropSize=(448, 256), train=False, test=True)
81 | else:
82 | testset = dataloader.SeDraw(root=args.dataset_root + '/validation', transform=transform,
83 | randomCropSize=(448, 256), train=False)
84 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size * device_count, shuffle=False)
85 |
86 | print(testset)
87 |
88 | # --Create transform to display image from tensor--
89 |
90 |
91 | negmean = [-1 for x in mean]
92 | restd = [2, 2, 2]
93 | revNormalize = transforms.Normalize(mean=negmean, std=restd)
94 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()])
95 |
96 |
97 | # --Utils--
98 |
99 | def get_lr(optimizer):
100 | for param_group in optimizer.param_groups:
101 | return param_group['lr']
102 |
103 |
104 |
105 | # --Validation function--
106 | #
107 |
108 |
109 | def validate():
110 | # For details see training.
111 | psnr = 0
112 | ie = 0
113 | tloss = 0
114 |
115 | with torch.no_grad():
116 | for testIndex, testData in tqdm(enumerate(testloader, 0)):
117 | frame0, frameT, frame1 = testData
118 |
119 | img0 = frame0.cuda()
120 | img1 = frame1.cuda()
121 | IFrame = frameT.cuda()
122 |
123 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
124 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
125 | IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)
126 | _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
127 | loss, MSE_val, IE, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
128 | imgt = torch.clamp(imgt, max=1., min=-1.)
129 | IFrame_np = IFrame.squeeze(0).cpu().numpy()
130 | imgt_np = imgt.squeeze(0).cpu().numpy()
131 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
132 | IFrame_png = np.uint8(((IFrame_np + 1.0) /2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
133 | imgpath = args.imgpath + '/' + str(testIndex)
134 | if not os.path.isdir(imgpath):
135 | os.system('mkdir -p %s' % imgpath)
136 | cv2.imwrite(imgpath + '/imgt.png', imgt_png)
137 | cv2.imwrite(imgpath + '/IFrame.png', IFrame_png)
138 |
139 | PSNR = compare_psnr(IFrame_np, imgt_np, data_range=2)
140 | print('PSNR:', PSNR)
141 |
142 | loss = torch.mean(loss)
143 | MSE_val = torch.mean(MSE_val)
144 |
145 | if testIndex % 100 == 99:
146 | vImg = torch.cat([revNormalize(frame0[0]).unsqueeze(0), revNormalize(frame1[0]).unsqueeze(0),
147 | revNormalize(imgt.cpu()[0]).unsqueeze(0), revNormalize(frameT[0]).unsqueeze(0),
148 | revNormalize(ref_imgt.cpu()[0]).unsqueeze(0)],
149 | dim=0)
150 |
151 |
152 | vImg = torch.clamp(vImg, max=1., min=0)
153 | vis.images(vImg, win='vImage', env=args.visdom_env, nrow=2, opts={'title': 'visual_image'})
154 |
155 | # psnr
156 | tloss += loss.item()
157 |
158 | psnr += PSNR
159 | ie += IE
160 |
161 | return (psnr / len(testloader)), (tloss / len(testloader)), MSE_val, (ie / len(testloader))
162 |
163 |
164 | # --Initialization--
165 |
166 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
167 |
168 | dict1 = torch.load(args.checkpoint)
169 | structure_gen.load_state_dict(dict1['state_dictGEN'])
170 | detail_enhance.load_state_dict(dict1['state_dictDE'])
171 |
172 | start = time.time()
173 |
174 | bdcn.eval()
175 | structure_gen.eval()
176 | detail_enhance.eval()
177 | psnr, vLoss, MSE_val, ie = validate()
178 | end = time.time()
179 |
180 | print(" Loss: %0.6f TestExecTime: %0.1f ValPSNR: %0.4f ValIE: %0.4f" % (
181 | vLoss, end - start, psnr, ie))
182 |
183 |
--------------------------------------------------------------------------------
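One detail of the evaluation script above that is easy to miss: frames are normalized with mean 0.5 / std 0.5 into [-1, 1] before entering the network, and `revNormalize` (mean -1, std 2) maps them back to [0, 1] for display, which is also why `compare_psnr` is called with `data_range=2`. A small self-contained check of that round trip (no model required):

```python
import torch
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])   # [0, 1] -> [-1, 1]
revNormalize = transforms.Normalize(mean=[-1, -1, -1], std=[2, 2, 2])         # [-1, 1] -> [0, 1]

x = torch.rand(3, 8, 8)           # fake image in [0, 1]
y = normalize(x)                  # values now span [-1, 1], hence data_range=2 for PSNR
x_back = revNormalize(y)          # (y - (-1)) / 2 == (y + 1) / 2
print(torch.allclose(x, x_back))  # True
```
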
/models/DCNv2/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .idea
3 | *.so
4 | *.o
5 | *pyc
6 | _ext
7 | build
8 | DCNv2.egg-info
9 | dist
--------------------------------------------------------------------------------
/models/DCNv2/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Charles Shang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/models/DCNv2/README.md:
--------------------------------------------------------------------------------
1 | ## Deformable Convolutional Networks V2 with Pytorch 1.0
2 |
3 | ### Build
4 | ```bash
5 | ./make.sh # build
6 | python test.py # run examples and gradient check
7 | ```
8 |
9 | ### An Example
10 | - deformable conv
11 | ```python
12 | from dcn_v2 import DCN
13 | input = torch.randn(2, 64, 128, 128).cuda()
14 | # wrap all things (offset and mask) in DCN
15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
16 | output = dcn(input)
17 | print(output.shape)
18 | ```
19 | - deformable roi pooling
20 | ```python
21 | from dcn_v2 import DCNPooling
22 | input = torch.randn(2, 32, 64, 64).cuda()
23 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
24 | x = torch.randint(256, (20, 1)).cuda().float()
25 | y = torch.randint(256, (20, 1)).cuda().float()
26 | w = torch.randint(64, (20, 1)).cuda().float()
27 | h = torch.randint(64, (20, 1)).cuda().float()
28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
29 |
30 | # mdformable pooling (V2)
31 | # wrap all things (offset and mask) in DCNPooling
32 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
33 | pooled_size=7,
34 | output_dim=32,
35 | no_trans=False,
36 | group_size=1,
37 | trans_std=0.1).cuda()
38 |
39 | dout = dpooling(input, rois)
40 | ```
41 | ### Note
42 | The master branch is now for PyTorch 1.0 (new ATen API); you can switch back to PyTorch 0.4 with:
43 | ```bash
44 | git checkout pytorch_0.4
45 | ```
46 |
47 | ### Known Issues:
48 |
49 | - [x] Gradient check w.r.t offset (solved)
50 | - [ ] Backward is not reentrant (minor)
51 |
52 | This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
53 |
54 | I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
55 | However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
56 | non-differentiable points?
57 |
58 | Update: all gradient check passes with double precision.
59 |
60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
61 | float, `<1e-15` for double),
62 | so it may not be a serious problem (?)
63 |
64 | Please post an issue or PR if you have any comments.
65 |
--------------------------------------------------------------------------------
/models/DCNv2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/models/DCNv2/__init__.py
--------------------------------------------------------------------------------
/models/DCNv2/dcn_v2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import math
4 | import logging
5 | import torch
6 | from torch import nn
7 | from torch.autograd import Function
8 | from torch.nn.modules.utils import _pair
9 | from torch.autograd.function import once_differentiable
10 |
11 | import _ext as _backend
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class _DCNv2(Function):
16 | @staticmethod
17 | def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation,
18 | deformable_groups):
19 | ctx.stride = _pair(stride)
20 | ctx.padding = _pair(padding)
21 | ctx.dilation = _pair(dilation)
22 | ctx.kernel_size = _pair(weight.shape[2:4])
23 | ctx.deformable_groups = deformable_groups
24 | output = _backend.dcn_v2_forward(input, weight, bias, offset, mask, ctx.kernel_size[0],
25 | ctx.kernel_size[1], ctx.stride[0], ctx.stride[1],
26 | ctx.padding[0], ctx.padding[1], ctx.dilation[0],
27 | ctx.dilation[1], ctx.deformable_groups)
28 | ctx.save_for_backward(input, offset, mask, weight, bias)
29 | return output
30 |
31 | @staticmethod
32 | @once_differentiable
33 | def backward(ctx, grad_output):
34 | input, offset, mask, weight, bias = ctx.saved_tensors
35 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \
36 | _backend.dcn_v2_backward(input, weight,
37 | bias,
38 | offset, mask,
39 | grad_output,
40 | ctx.kernel_size[0], ctx.kernel_size[1],
41 | ctx.stride[0], ctx.stride[1],
42 | ctx.padding[0], ctx.padding[1],
43 | ctx.dilation[0], ctx.dilation[1],
44 | ctx.deformable_groups)
45 |
46 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\
47 | None, None, None, None,
48 |
49 |
50 | dcn_v2_conv = _DCNv2.apply
51 |
52 |
53 | class DCNv2(nn.Module):
54 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
55 | deformable_groups=1):
56 | super(DCNv2, self).__init__()
57 | self.in_channels = in_channels
58 | self.out_channels = out_channels
59 | self.kernel_size = _pair(kernel_size)
60 | self.stride = _pair(stride)
61 | self.padding = _pair(padding)
62 | self.dilation = _pair(dilation)
63 | self.deformable_groups = deformable_groups
64 |
65 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
66 | self.bias = nn.Parameter(torch.Tensor(out_channels))
67 | self.reset_parameters()
68 |
69 | def reset_parameters(self):
70 | n = self.in_channels
71 | for k in self.kernel_size:
72 | n *= k
73 | stdv = 1. / math.sqrt(n)
74 | self.weight.data.uniform_(-stdv, stdv)
75 | self.bias.data.zero_()
76 |
77 | def forward(self, input, offset, mask):
78 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
79 | offset.shape[1]
80 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
81 | mask.shape[1]
82 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
83 | self.dilation, self.deformable_groups)
84 |
85 |
86 | class DCN(DCNv2):
87 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
88 | deformable_groups=1):
89 | super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation,
90 | deformable_groups)
91 |
92 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
93 | self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size,
94 | stride=self.stride, padding=self.padding, bias=True)
95 | self.init_offset()
96 |
97 | def init_offset(self):
98 | self.conv_offset_mask.weight.data.zero_()
99 | self.conv_offset_mask.bias.data.zero_()
100 |
101 | def forward(self, input):
102 | out = self.conv_offset_mask(input)
103 | o1, o2, mask = torch.chunk(out, 3, dim=1)
104 | offset = torch.cat((o1, o2), dim=1)
105 | mask = torch.sigmoid(mask)
106 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
107 | self.dilation, self.deformable_groups)
108 |
109 |
110 | class DCN_sep(DCNv2):
111 | '''Use other features to generate offsets and masks'''
112 |
113 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
114 | deformable_groups=1):
115 | super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
116 | dilation, deformable_groups)
117 |
118 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
119 | self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size,
120 | stride=self.stride, padding=self.padding, bias=True)
121 | self.init_offset()
122 |
123 | def init_offset(self):
124 | self.conv_offset_mask.weight.data.zero_()
125 | self.conv_offset_mask.bias.data.zero_()
126 |
127 | def forward(self, input, fea):
128 | '''input: input features for deformable conv
129 | fea: other features used for generating offsets and mask'''
130 | out = self.conv_offset_mask(fea)
131 | o1, o2, mask = torch.chunk(out, 3, dim=1)
132 | offset = torch.cat((o1, o2), dim=1)
133 |
134 | offset_mean = torch.mean(torch.abs(offset))
135 | if offset_mean > 100:
136 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
137 |
138 | mask = torch.sigmoid(mask)
139 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
140 | self.dilation, self.deformable_groups)
141 |
142 |
143 | class _DCNv2Pooling(Function):
144 | @staticmethod
145 | def forward(ctx, input, rois, offset, spatial_scale, pooled_size, output_dim, no_trans,
146 | group_size=1, part_size=None, sample_per_part=4, trans_std=.0):
147 | ctx.spatial_scale = spatial_scale
148 | ctx.no_trans = int(no_trans)
149 | ctx.output_dim = output_dim
150 | ctx.group_size = group_size
151 | ctx.pooled_size = pooled_size
152 | ctx.part_size = pooled_size if part_size is None else part_size
153 | ctx.sample_per_part = sample_per_part
154 | ctx.trans_std = trans_std
155 |
156 | output, output_count = \
157 | _backend.dcn_v2_psroi_pooling_forward(input, rois, offset,
158 | ctx.no_trans, ctx.spatial_scale,
159 | ctx.output_dim, ctx.group_size,
160 | ctx.pooled_size, ctx.part_size,
161 | ctx.sample_per_part, ctx.trans_std)
162 | ctx.save_for_backward(input, rois, offset, output_count)
163 | return output
164 |
165 | @staticmethod
166 | @once_differentiable
167 | def backward(ctx, grad_output):
168 | input, rois, offset, output_count = ctx.saved_tensors
169 | grad_input, grad_offset = \
170 | _backend.dcn_v2_psroi_pooling_backward(grad_output,
171 | input,
172 | rois,
173 | offset,
174 | output_count,
175 | ctx.no_trans,
176 | ctx.spatial_scale,
177 | ctx.output_dim,
178 | ctx.group_size,
179 | ctx.pooled_size,
180 | ctx.part_size,
181 | ctx.sample_per_part,
182 | ctx.trans_std)
183 |
184 | return grad_input, None, grad_offset, \
185 | None, None, None, None, None, None, None, None
186 |
187 |
188 | dcn_v2_pooling = _DCNv2Pooling.apply
189 |
190 |
191 | class DCNv2Pooling(nn.Module):
192 | def __init__(self, spatial_scale, pooled_size, output_dim, no_trans, group_size=1,
193 | part_size=None, sample_per_part=4, trans_std=.0):
194 | super(DCNv2Pooling, self).__init__()
195 | self.spatial_scale = spatial_scale
196 | self.pooled_size = pooled_size
197 | self.output_dim = output_dim
198 | self.no_trans = no_trans
199 | self.group_size = group_size
200 | self.part_size = pooled_size if part_size is None else part_size
201 | self.sample_per_part = sample_per_part
202 | self.trans_std = trans_std
203 |
204 | def forward(self, input, rois, offset):
205 | assert input.shape[1] == self.output_dim
206 | if self.no_trans:
207 | offset = input.new()
208 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size,
209 | self.output_dim, self.no_trans, self.group_size, self.part_size,
210 | self.sample_per_part, self.trans_std)
211 |
212 |
213 | class DCNPooling(DCNv2Pooling):
214 | def __init__(self, spatial_scale, pooled_size, output_dim, no_trans, group_size=1,
215 | part_size=None, sample_per_part=4, trans_std=.0, deform_fc_dim=1024):
216 | super(DCNPooling, self).__init__(spatial_scale, pooled_size, output_dim, no_trans,
217 | group_size, part_size, sample_per_part, trans_std)
218 |
219 | self.deform_fc_dim = deform_fc_dim
220 |
221 | if not no_trans:
222 | self.offset_mask_fc = nn.Sequential(
223 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim,
224 | self.deform_fc_dim), nn.ReLU(inplace=True),
225 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(inplace=True),
226 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3))
227 | self.offset_mask_fc[4].weight.data.zero_()
228 | self.offset_mask_fc[4].bias.data.zero_()
229 |
230 | def forward(self, input, rois):
231 | offset = input.new()
232 |
233 | if not self.no_trans:
234 |
235 | # do roi_align first
236 | n = rois.shape[0]
237 | roi = dcn_v2_pooling(
238 | input,
239 | rois,
240 | offset,
241 | self.spatial_scale,
242 | self.pooled_size,
243 | self.output_dim,
244 | True, # no trans
245 | self.group_size,
246 | self.part_size,
247 | self.sample_per_part,
248 | self.trans_std)
249 |
250 | # build mask and offset
251 | offset_mask = self.offset_mask_fc(roi.view(n, -1))
252 | offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size)
253 | o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
254 | offset = torch.cat((o1, o2), dim=1)
255 | mask = torch.sigmoid(mask)
256 |
257 | # do pooling with offset and mask
258 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size,
259 | self.output_dim, self.no_trans, self.group_size, self.part_size,
260 | self.sample_per_part, self.trans_std) * mask
261 | # only roi_align
262 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size,
263 | self.output_dim, self.no_trans, self.group_size, self.part_size,
264 | self.sample_per_part, self.trans_std)
265 |
--------------------------------------------------------------------------------
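A hedged usage sketch for the `DCN_sep` module above, which deforms `input` while predicting offsets and masks from a separate feature map `fea`; it assumes the compiled `_ext` extension and a CUDA device:

```python
import torch
from dcn_v2 import DCN_sep

# offsets/masks are predicted from `fea`, so `fea` must have `in_channels` channels
dcn = DCN_sep(64, 64, kernel_size=3, stride=1, padding=1, deformable_groups=8).cuda()
input = torch.randn(2, 64, 32, 32).cuda()
fea = torch.randn(2, 64, 32, 32).cuda()
out = dcn(input, fea)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```
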
/models/DCNv2/make.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # You may need to modify the following paths before compiling.
4 |
5 | # CUDA_HOME=/usr/local/cuda-10.0 \
6 | # CUDNN_INCLUDE_DIR=/usr/local/cuda-10.0/include \
7 | # CUDNN_LIB_DIR=/usr/local/cuda-10.0/lib64 \
8 | python setup.py build develop
9 |
--------------------------------------------------------------------------------
/models/DCNv2/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import glob
5 |
6 | import torch
7 |
8 | from torch.utils.cpp_extension import CUDA_HOME
9 | from torch.utils.cpp_extension import CppExtension
10 | from torch.utils.cpp_extension import CUDAExtension
11 |
12 | from setuptools import find_packages
13 | from setuptools import setup
14 |
15 | requirements = ["torch", "torchvision"]
16 |
17 |
18 | def get_extensions():
19 | this_dir = os.path.dirname(os.path.abspath(__file__))
20 | extensions_dir = os.path.join(this_dir, "src")
21 |
22 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
23 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
24 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
25 |
26 | sources = main_file + source_cpu
27 | extension = CppExtension
28 | extra_compile_args = {"cxx": []}
29 | define_macros = []
30 |
31 | if torch.cuda.is_available() and CUDA_HOME is not None:
32 | extension = CUDAExtension
33 | sources += source_cuda
34 | define_macros += [("WITH_CUDA", None)]
35 | extra_compile_args["nvcc"] = [
36 | "-DCUDA_HAS_FP16=1",
37 | "-D__CUDA_NO_HALF_OPERATORS__",
38 | "-D__CUDA_NO_HALF_CONVERSIONS__",
39 | "-D__CUDA_NO_HALF2_OPERATORS__",
40 | ]
41 | else:
42 | raise NotImplementedError('CUDA is not available')
43 |
44 | sources = [os.path.join(extensions_dir, s) for s in sources]
45 | include_dirs = [extensions_dir]
46 | ext_modules = [
47 | extension(
48 | "_ext",
49 | sources,
50 | include_dirs=include_dirs,
51 | define_macros=define_macros,
52 | extra_compile_args=extra_compile_args,
53 | )
54 | ]
55 | return ext_modules
56 |
57 |
58 | setup(
59 | name="DCNv2",
60 | version="0.1",
61 | author="charlesshang",
62 | url="https://github.com/charlesshang/DCNv2",
63 | description="deformable convolutional networks",
64 | packages=find_packages(exclude=(
65 | "configs",
66 | "tests",
67 | )),
68 | # install_requires=requirements,
69 | ext_modules=get_extensions(),
70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
--------------------------------------------------------------------------------
/models/DCNv2/src/cpu/dcn_v2_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <vector>
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 |
6 |
7 | at::Tensor
8 | dcn_v2_cpu_forward(const at::Tensor &input,
9 | const at::Tensor &weight,
10 | const at::Tensor &bias,
11 | const at::Tensor &offset,
12 | const at::Tensor &mask,
13 | const int kernel_h,
14 | const int kernel_w,
15 | const int stride_h,
16 | const int stride_w,
17 | const int pad_h,
18 | const int pad_w,
19 | const int dilation_h,
20 | const int dilation_w,
21 | const int deformable_group)
22 | {
23 | AT_ERROR("Not implement on cpu");
24 | }
25 |
26 | std::vector<at::Tensor>
27 | dcn_v2_cpu_backward(const at::Tensor &input,
28 | const at::Tensor &weight,
29 | const at::Tensor &bias,
30 | const at::Tensor &offset,
31 | const at::Tensor &mask,
32 | const at::Tensor &grad_output,
33 | int kernel_h, int kernel_w,
34 | int stride_h, int stride_w,
35 | int pad_h, int pad_w,
36 | int dilation_h, int dilation_w,
37 | int deformable_group)
38 | {
39 | AT_ERROR("Not implement on cpu");
40 | }
41 |
42 | std::tuple<at::Tensor, at::Tensor>
43 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
44 | const at::Tensor &bbox,
45 | const at::Tensor &trans,
46 | const int no_trans,
47 | const float spatial_scale,
48 | const int output_dim,
49 | const int group_size,
50 | const int pooled_size,
51 | const int part_size,
52 | const int sample_per_part,
53 | const float trans_std)
54 | {
55 | AT_ERROR("Not implement on cpu");
56 | }
57 |
58 | std::tuple<at::Tensor, at::Tensor>
59 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
60 | const at::Tensor &input,
61 | const at::Tensor &bbox,
62 | const at::Tensor &trans,
63 | const at::Tensor &top_count,
64 | const int no_trans,
65 | const float spatial_scale,
66 | const int output_dim,
67 | const int group_size,
68 | const int pooled_size,
69 | const int part_size,
70 | const int sample_per_part,
71 | const float trans_std)
72 | {
73 | AT_ERROR("Not implement on cpu");
74 | }
--------------------------------------------------------------------------------
/models/DCNv2/src/cpu/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cpu_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cpu_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/models/DCNv2/src/cuda/dcn_v2_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <vector>
2 | #include "cuda/dcn_v2_im2col_cuda.h"
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 |
7 | #include <THC/THC.h>
8 | #include <THC/THCAtomics.cuh>
9 | #include <THC/THCDeviceUtils.cuh>
10 |
11 | extern THCState *state;
12 |
13 | // author: Charles Shang
14 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
15 |
16 | // [batch gemm]
17 | // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
18 |
19 | __global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
20 | float **columns_b, const float **ones_b,
21 | const float **weight_b, const float **bias_b,
22 | float *input, float *output,
23 | float *columns, float *ones,
24 | float *weight, float *bias,
25 | const int input_stride, const int output_stride,
26 | const int columns_stride, const int ones_stride,
27 | const int num_batches)
28 | {
29 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
30 | if (idx < num_batches)
31 | {
32 | input_b[idx] = input + idx * input_stride;
33 | output_b[idx] = output + idx * output_stride;
34 | columns_b[idx] = columns + idx * columns_stride;
35 | ones_b[idx] = ones + idx * ones_stride;
36 | // share weights and bias within a Mini-Batch
37 | weight_b[idx] = weight;
38 | bias_b[idx] = bias;
39 | }
40 | }
41 |
42 | at::Tensor
43 | dcn_v2_cuda_forward(const at::Tensor &input,
44 | const at::Tensor &weight,
45 | const at::Tensor &bias,
46 | const at::Tensor &offset,
47 | const at::Tensor &mask,
48 | const int kernel_h,
49 | const int kernel_w,
50 | const int stride_h,
51 | const int stride_w,
52 | const int pad_h,
53 | const int pad_w,
54 | const int dilation_h,
55 | const int dilation_w,
56 | const int deformable_group)
57 | {
58 | using scalar_t = float;
59 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
60 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
61 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
62 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
63 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
64 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
65 |
66 | const int batch = input.size(0);
67 | const int channels = input.size(1);
68 | const int height = input.size(2);
69 | const int width = input.size(3);
70 |
71 | const int channels_out = weight.size(0);
72 | const int channels_kernel = weight.size(1);
73 | const int kernel_h_ = weight.size(2);
74 | const int kernel_w_ = weight.size(3);
75 |
76 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
77 | // printf("Channels: %d %d\n", channels, channels_kernel);
78 | // printf("Channels: %d %d\n", channels_out, channels_kernel);
79 |
80 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
81 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
82 |
83 | AT_ASSERTM(channels == channels_kernel,
84 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
85 |
86 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
87 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
88 |
89 | auto ones = at::ones({batch, height_out, width_out}, input.options());
90 | auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
91 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
92 |
93 | // prepare for batch-wise computing, which is significantly faster than instance-wise computing
94 | // when batch size is large.
95 | // launch batch threads
96 | int matrices_size = batch * sizeof(float *);
97 | auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
98 | auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
99 | auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
100 | auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
101 | auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
102 | auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
103 |
104 | const int block = 128;
105 | const int grid = (batch + block - 1) / block;
106 |
107 | createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
108 | input_b, output_b,
109 | columns_b, ones_b,
110 | weight_b, bias_b,
111 | input.data<scalar_t>(),
112 | output.data<scalar_t>(),
113 | columns.data<scalar_t>(),
114 | ones.data<scalar_t>(),
115 | weight.data<scalar_t>(),
116 | bias.data<scalar_t>(),
117 | channels * width * height,
118 | channels_out * width_out * height_out,
119 | channels * kernel_h * kernel_w * height_out * width_out,
120 | height_out * width_out,
121 | batch);
122 |
123 | long m_ = channels_out;
124 | long n_ = height_out * width_out;
125 | long k_ = 1;
126 | THCudaBlas_SgemmBatched(state,
127 | 't',
128 | 'n',
129 | n_,
130 | m_,
131 | k_,
132 | 1.0f,
133 | ones_b, k_,
134 | bias_b, k_,
135 | 0.0f,
136 | output_b, n_,
137 | batch);
138 |
139 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
140 | input.data<scalar_t>(),
141 | offset.data<scalar_t>(),
142 | mask.data<scalar_t>(),
143 | batch, channels, height, width,
144 | height_out, width_out, kernel_h, kernel_w,
145 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
146 | deformable_group,
147 | columns.data<scalar_t>());
148 |
149 | long m = channels_out;
150 | long n = height_out * width_out;
151 | long k = channels * kernel_h * kernel_w;
152 | THCudaBlas_SgemmBatched(state,
153 | 'n',
154 | 'n',
155 | n,
156 | m,
157 | k,
158 | 1.0f,
159 | (const float **)columns_b, n,
160 | weight_b, k,
161 | 1.0f,
162 | output_b, n,
163 | batch);
164 |
165 | THCudaFree(state, input_b);
166 | THCudaFree(state, output_b);
167 | THCudaFree(state, columns_b);
168 | THCudaFree(state, ones_b);
169 | THCudaFree(state, weight_b);
170 | THCudaFree(state, bias_b);
171 | return output;
172 | }
173 |
174 | __global__ void createBatchGemmBufferBackward(
175 | float **grad_output_b,
176 | float **columns_b,
177 | float **ones_b,
178 | float **weight_b,
179 | float **grad_weight_b,
180 | float **grad_bias_b,
181 | float *grad_output,
182 | float *columns,
183 | float *ones,
184 | float *weight,
185 | float *grad_weight,
186 | float *grad_bias,
187 | const int grad_output_stride,
188 | const int columns_stride,
189 | const int ones_stride,
190 | const int num_batches)
191 | {
192 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
193 | if (idx < num_batches)
194 | {
195 | grad_output_b[idx] = grad_output + idx * grad_output_stride;
196 | columns_b[idx] = columns + idx * columns_stride;
197 | ones_b[idx] = ones + idx * ones_stride;
198 |
199 | // share weights and bias within a Mini-Batch
200 | weight_b[idx] = weight;
201 | grad_weight_b[idx] = grad_weight;
202 | grad_bias_b[idx] = grad_bias;
203 | }
204 | }
205 |
206 | std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
207 | const at::Tensor &weight,
208 | const at::Tensor &bias,
209 | const at::Tensor &offset,
210 | const at::Tensor &mask,
211 | const at::Tensor &grad_output,
212 | int kernel_h, int kernel_w,
213 | int stride_h, int stride_w,
214 | int pad_h, int pad_w,
215 | int dilation_h, int dilation_w,
216 | int deformable_group)
217 | {
218 |
219 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
220 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
221 |
222 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
223 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
224 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
225 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
226 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
227 |
228 | const int batch = input.size(0);
229 | const int channels = input.size(1);
230 | const int height = input.size(2);
231 | const int width = input.size(3);
232 |
233 | const int channels_out = weight.size(0);
234 | const int channels_kernel = weight.size(1);
235 | const int kernel_h_ = weight.size(2);
236 | const int kernel_w_ = weight.size(3);
237 |
238 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
239 | "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w_, kernel_h, kernel_w);
240 |
241 | AT_ASSERTM(channels == channels_kernel,
242 | "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
243 |
244 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
245 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
246 |
247 | auto ones = at::ones({height_out, width_out}, input.options());
248 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
249 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
250 |
251 | auto grad_input = at::zeros_like(input);
252 | auto grad_weight = at::zeros_like(weight);
253 | auto grad_bias = at::zeros_like(bias);
254 | auto grad_offset = at::zeros_like(offset);
255 | auto grad_mask = at::zeros_like(mask);
256 |
257 | using scalar_t = float;
258 |
259 | for (int b = 0; b < batch; b++)
260 | {
261 | auto input_n = input.select(0, b);
262 | auto offset_n = offset.select(0, b);
263 | auto mask_n = mask.select(0, b);
264 | auto grad_output_n = grad_output.select(0, b);
265 | auto grad_input_n = grad_input.select(0, b);
266 | auto grad_offset_n = grad_offset.select(0, b);
267 | auto grad_mask_n = grad_mask.select(0, b);
268 |
269 | long m = channels * kernel_h * kernel_w;
270 | long n = height_out * width_out;
271 | long k = channels_out;
272 |
273 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
274 | grad_output_n.data<scalar_t>(), n,
275 | weight.data<scalar_t>(), m, 0.0f,
276 | columns.data<scalar_t>(), n);
277 |
278 | // gradient w.r.t. input coordinate data
279 | modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
280 | columns.data<scalar_t>(),
281 | input_n.data<scalar_t>(),
282 | offset_n.data<scalar_t>(),
283 | mask_n.data<scalar_t>(),
284 | 1, channels, height, width,
285 | height_out, width_out, kernel_h, kernel_w,
286 | pad_h, pad_w, stride_h, stride_w,
287 | dilation_h, dilation_w, deformable_group,
288 | grad_offset_n.data<scalar_t>(),
289 | grad_mask_n.data<scalar_t>());
290 | // gradient w.r.t. input data
291 | modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
292 | columns.data<scalar_t>(),
293 | offset_n.data<scalar_t>(),
294 | mask_n.data<scalar_t>(),
295 | 1, channels, height, width,
296 | height_out, width_out, kernel_h, kernel_w,
297 | pad_h, pad_w, stride_h, stride_w,
298 | dilation_h, dilation_w, deformable_group,
299 | grad_input_n.data<scalar_t>());
300 |
301 | // gradient w.r.t. weight, dWeight should accumulate across the batch and group
302 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
303 | input_n.data<scalar_t>(),
304 | offset_n.data<scalar_t>(),
305 | mask_n.data<scalar_t>(),
306 | 1, channels, height, width,
307 | height_out, width_out, kernel_h, kernel_w,
308 | pad_h, pad_w, stride_h, stride_w,
309 | dilation_h, dilation_w, deformable_group,
310 | columns.data<scalar_t>());
311 |
312 | long m_ = channels_out;
313 | long n_ = channels * kernel_h * kernel_w;
314 | long k_ = height_out * width_out;
315 |
316 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
317 | columns.data<scalar_t>(), k_,
318 | grad_output_n.data<scalar_t>(), k_, 1.0f,
319 | grad_weight.data<scalar_t>(), n_);
320 |
321 | // gradient w.r.t. bias
322 | // long m_ = channels_out;
323 | // long k__ = height_out * width_out;
324 | THCudaBlas_Sgemv(state,
325 | 't',
326 | k_, m_, 1.0f,
327 | grad_output_n.data<scalar_t>(), k_,
328 | ones.data<scalar_t>(), 1, 1.0f,
329 | grad_bias.data<scalar_t>(), 1);
330 | }
331 |
332 | return {
333 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias
334 | };
335 | }
--------------------------------------------------------------------------------
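Note on the geometry above: dcn_v2_cuda_forward derives height_out/width_out with the usual convolution output-size formula. A minimal Python sketch of that arithmetic (illustration only, not part of the extension):

def conv_out_size(size, kernel, stride, pad, dilation):
    # matches (size + 2*pad - (dilation*(kernel-1) + 1)) / stride + 1 in dcn_v2_cuda_forward
    return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1

# e.g. a 64x64 map with a 3x3 kernel, stride 1, pad 1, dilation 1 keeps its size
assert conv_out_size(64, kernel=3, stride=1, pad=1, dilation=1) == 64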
/models/DCNv2/src/cuda/dcn_v2_im2col_cuda.h:
--------------------------------------------------------------------------------
1 |
2 | /*!
3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
4 | *
5 | * COPYRIGHT
6 | *
7 | * All contributions by the University of California:
8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
9 | * All rights reserved.
10 | *
11 | * All other contributions:
12 | * Copyright (c) 2014-2017, the respective contributors
13 | * All rights reserved.
14 | *
15 | * Caffe uses a shared copyright model: each contributor holds copyright over
16 | * their contributions to Caffe. The project versioning records all such
17 | * contribution and copyright details. If a contributor wants to further mark
18 | * their specific copyright on a particular contribution, they should indicate
19 | * their copyright solely in the commit message of the change when it is
20 | * committed.
21 | *
22 | * LICENSE
23 | *
24 | * Redistribution and use in source and binary forms, with or without
25 | * modification, are permitted provided that the following conditions are met:
26 | *
27 | * 1. Redistributions of source code must retain the above copyright notice, this
28 | * list of conditions and the following disclaimer.
29 | * 2. Redistributions in binary form must reproduce the above copyright notice,
30 | * this list of conditions and the following disclaimer in the documentation
31 | * and/or other materials provided with the distribution.
32 | *
33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 | *
44 | * CONTRIBUTION AGREEMENT
45 | *
46 | * By contributing to the BVLC/caffe repository through pull-request, comment,
47 | * or otherwise, the contributor releases their content to the
48 | * license and copyright terms herein.
49 | *
50 | ***************** END Caffe Copyright Notice and Disclaimer ********************
51 | *
52 | * Copyright (c) 2018 Microsoft
53 | * Licensed under The MIT License [see LICENSE for details]
54 | * \file modulated_deformable_im2col.h
55 | * \brief Function definitions of converting an image to
56 | * column matrix based on kernel, padding, dilation, and offset.
57 | * These functions are mainly used in deformable convolution operators.
58 | * \ref: https://arxiv.org/abs/1811.11168
59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
60 | */
61 |
62 | /***************** Adapted by Charles Shang *********************/
63 |
64 | #ifndef DCN_V2_IM2COL_CUDA
65 | #define DCN_V2_IM2COL_CUDA
66 |
67 | #ifdef __cplusplus
68 | extern "C"
69 | {
70 | #endif
71 |
72 | void modulated_deformable_im2col_cuda(cudaStream_t stream,
73 | const float *data_im, const float *data_offset, const float *data_mask,
74 | const int batch_size, const int channels, const int height_im, const int width_im,
75 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
77 | const int dilation_h, const int dilation_w,
78 | const int deformable_group, float *data_col);
79 |
80 | void modulated_deformable_col2im_cuda(cudaStream_t stream,
81 | const float *data_col, const float *data_offset, const float *data_mask,
82 | const int batch_size, const int channels, const int height_im, const int width_im,
83 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
85 | const int dilation_h, const int dilation_w,
86 | const int deformable_group, float *grad_im);
87 |
88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
90 | const int batch_size, const int channels, const int height_im, const int width_im,
91 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
93 | const int dilation_h, const int dilation_w,
94 | const int deformable_group,
95 | float *grad_offset, float *grad_mask);
96 |
97 | #ifdef __cplusplus
98 | }
99 | #endif
100 |
101 | #endif
--------------------------------------------------------------------------------
/models/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017 Microsoft
3 | * Licensed under The MIT License [see LICENSE for details]
4 | * \file deformable_psroi_pooling.cu
5 | * \brief
6 | * \author Yi Li, Guodong Zhang, Jifeng Dai
7 | */
8 | /***************** Adapted by Charles Shang *********************/
9 |
10 | #include <cstdio>
11 | #include <cmath>
12 | #include <algorithm>
13 | #include <vector>
14 |
15 | #include <ATen/ATen.h>
16 | #include <ATen/cuda/CUDAContext.h>
17 |
18 | #include <THC/THC.h>
19 | #include <THC/THCAtomics.cuh>
20 | #include <THC/THCDeviceUtils.cuh>
21 |
22 | #define CUDA_KERNEL_LOOP(i, n) \
23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
24 | i < (n); \
25 | i += blockDim.x * gridDim.x)
26 |
27 | const int CUDA_NUM_THREADS = 1024;
28 | inline int GET_BLOCKS(const int N)
29 | {
30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
31 | }
32 |
33 | template <typename T>
34 | __device__ T bilinear_interp(
35 | const T *data,
36 | const T x,
37 | const T y,
38 | const int width,
39 | const int height)
40 | {
41 | int x1 = floor(x);
42 | int x2 = ceil(x);
43 | int y1 = floor(y);
44 | int y2 = ceil(y);
45 | T dist_x = static_cast<T>(x - x1);
46 | T dist_y = static_cast<T>(y - y1);
47 | T value11 = data[y1 * width + x1];
48 | T value12 = data[y2 * width + x1];
49 | T value21 = data[y1 * width + x2];
50 | T value22 = data[y2 * width + x2];
51 | T value = (1 - dist_x) * (1 - dist_y) * value11 +
52 | (1 - dist_x) * dist_y * value12 +
53 | dist_x * (1 - dist_y) * value21 +
54 | dist_x * dist_y * value22;
55 | return value;
56 | }
57 |
58 | template <typename T>
59 | __global__ void DeformablePSROIPoolForwardKernel(
60 | const int count,
61 | const T *bottom_data,
62 | const T spatial_scale,
63 | const int channels,
64 | const int height, const int width,
65 | const int pooled_height, const int pooled_width,
66 | const T *bottom_rois, const T *bottom_trans,
67 | const int no_trans,
68 | const T trans_std,
69 | const int sample_per_part,
70 | const int output_dim,
71 | const int group_size,
72 | const int part_size,
73 | const int num_classes,
74 | const int channels_each_class,
75 | T *top_data,
76 | T *top_count)
77 | {
78 | CUDA_KERNEL_LOOP(index, count)
79 | {
80 | // The output is in order (n, ctop, ph, pw)
81 | int pw = index % pooled_width;
82 | int ph = (index / pooled_width) % pooled_height;
83 | int ctop = (index / pooled_width / pooled_height) % output_dim;
84 | int n = index / pooled_width / pooled_height / output_dim;
85 |
86 | // [start, end) interval for spatial sampling
87 | const T *offset_bottom_rois = bottom_rois + n * 5;
88 | int roi_batch_ind = offset_bottom_rois[0];
89 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
90 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
91 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
92 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
93 |
94 | // Force too small ROIs to be 1x1
95 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
96 | T roi_height = max(roi_end_h - roi_start_h, 0.1);
97 |
98 | // Compute w and h at bottom
99 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
100 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
101 |
102 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
103 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
104 |
105 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
106 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
107 | int class_id = ctop / channels_each_class;
108 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
109 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
110 |
111 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
112 | wstart += trans_x * roi_width;
113 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
114 | hstart += trans_y * roi_height;
115 |
116 | T sum = 0;
117 | int count = 0;
118 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
119 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
120 | gw = min(max(gw, 0), group_size - 1);
121 | gh = min(max(gh, 0), group_size - 1);
122 |
123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
124 | for (int ih = 0; ih < sample_per_part; ih++)
125 | {
126 | for (int iw = 0; iw < sample_per_part; iw++)
127 | {
128 | T w = wstart + iw * sub_bin_size_w;
129 | T h = hstart + ih * sub_bin_size_h;
130 | // bilinear interpolation
131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
132 | {
133 | continue;
134 | }
135 | w = min(max(w, 0.), width - 1.);
136 | h = min(max(h, 0.), height - 1.);
137 | int c = (ctop * group_size + gh) * group_size + gw;
138 | T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
139 | sum += val;
140 | count++;
141 | }
142 | }
143 | top_data[index] = count == 0 ? static_cast<T>(0) : sum / count;
144 | top_count[index] = count;
145 | }
146 | }
147 |
148 | template <typename T>
149 | __global__ void DeformablePSROIPoolBackwardAccKernel(
150 | const int count,
151 | const T *top_diff,
152 | const T *top_count,
153 | const int num_rois,
154 | const T spatial_scale,
155 | const int channels,
156 | const int height, const int width,
157 | const int pooled_height, const int pooled_width,
158 | const int output_dim,
159 | T *bottom_data_diff, T *bottom_trans_diff,
160 | const T *bottom_data,
161 | const T *bottom_rois,
162 | const T *bottom_trans,
163 | const int no_trans,
164 | const T trans_std,
165 | const int sample_per_part,
166 | const int group_size,
167 | const int part_size,
168 | const int num_classes,
169 | const int channels_each_class)
170 | {
171 | CUDA_KERNEL_LOOP(index, count)
172 | {
173 | // The output is in order (n, ctop, ph, pw)
174 | int pw = index % pooled_width;
175 | int ph = (index / pooled_width) % pooled_height;
176 | int ctop = (index / pooled_width / pooled_height) % output_dim;
177 | int n = index / pooled_width / pooled_height / output_dim;
178 |
179 | // [start, end) interval for spatial sampling
180 | const T *offset_bottom_rois = bottom_rois + n * 5;
181 | int roi_batch_ind = offset_bottom_rois[0];
182 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
183 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
184 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
185 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
186 |
187 | // Force too small ROIs to be 1x1
188 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
189 | T roi_height = max(roi_end_h - roi_start_h, 0.1);
190 |
191 | // Compute w and h at bottom
192 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
193 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
194 |
195 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
196 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
197 |
198 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
199 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
200 | int class_id = ctop / channels_each_class;
201 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
202 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
203 |
204 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
205 | wstart += trans_x * roi_width;
206 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
207 | hstart += trans_y * roi_height;
208 |
209 | if (top_count[index] <= 0)
210 | {
211 | continue;
212 | }
213 | T diff_val = top_diff[index] / top_count[index];
214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
216 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
217 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
218 | gw = min(max(gw, 0), group_size - 1);
219 | gh = min(max(gh, 0), group_size - 1);
220 |
221 | for (int ih = 0; ih < sample_per_part; ih++)
222 | {
223 | for (int iw = 0; iw < sample_per_part; iw++)
224 | {
225 | T w = wstart + iw * sub_bin_size_w;
226 | T h = hstart + ih * sub_bin_size_h;
227 | // bilinear interpolation
228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
229 | {
230 | continue;
231 | }
232 | w = min(max(w, 0.), width - 1.);
233 | h = min(max(h, 0.), height - 1.);
234 | int c = (ctop * group_size + gh) * group_size + gw;
235 | // backward on feature
236 | int x0 = floor(w);
237 | int x1 = ceil(w);
238 | int y0 = floor(h);
239 | int y1 = ceil(h);
240 | T dist_x = w - x0, dist_y = h - y0;
241 | T q00 = (1 - dist_x) * (1 - dist_y);
242 | T q01 = (1 - dist_x) * dist_y;
243 | T q10 = dist_x * (1 - dist_y);
244 | T q11 = dist_x * dist_y;
245 | int bottom_index_base = c * height * width;
246 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
250 |
251 | if (no_trans)
252 | {
253 | continue;
254 | }
255 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
256 | T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
257 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
258 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
259 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
260 | diff_x *= roi_width;
261 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
262 | diff_y *= roi_height;
263 |
264 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
265 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
266 | }
267 | }
268 | }
269 | }
270 |
271 | std::tuple<at::Tensor, at::Tensor>
272 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
273 | const at::Tensor &bbox,
274 | const at::Tensor &trans,
275 | const int no_trans,
276 | const float spatial_scale,
277 | const int output_dim,
278 | const int group_size,
279 | const int pooled_size,
280 | const int part_size,
281 | const int sample_per_part,
282 | const float trans_std)
283 | {
284 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
285 | AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
286 | AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
287 |
288 | const int batch = input.size(0);
289 | const int channels = input.size(1);
290 | const int height = input.size(2);
291 | const int width = input.size(3);
292 | const int channels_trans = no_trans ? 2 : trans.size(1);
293 | const int num_bbox = bbox.size(0);
294 |
295 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
296 | auto pooled_height = pooled_size;
297 | auto pooled_width = pooled_size;
298 |
299 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
300 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
301 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
302 |
303 | const int num_classes = no_trans ? 1 : channels_trans / 2;
304 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
305 |
306 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
307 |
308 | if (out.numel() == 0)
309 | {
310 | THCudaCheck(cudaGetLastError());
311 | return std::make_tuple(out, top_count);
312 | }
313 |
314 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
315 | dim3 block(512);
316 |
317 | AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
318 | DeformablePSROIPoolForwardKernel<scalar_t><<<grid, block, 0, stream>>>(
319 | out_size,
320 | input.contiguous().data<scalar_t>(),
321 | spatial_scale,
322 | channels,
323 | height, width,
324 | pooled_height,
325 | pooled_width,
326 | bbox.contiguous().data<scalar_t>(),
327 | trans.contiguous().data<scalar_t>(),
328 | no_trans,
329 | trans_std,
330 | sample_per_part,
331 | output_dim,
332 | group_size,
333 | part_size,
334 | num_classes,
335 | channels_each_class,
336 | out.data<scalar_t>(),
337 | top_count.data<scalar_t>());
338 | });
339 | THCudaCheck(cudaGetLastError());
340 | return std::make_tuple(out, top_count);
341 | }
342 |
343 | std::tuple<at::Tensor, at::Tensor>
344 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
345 | const at::Tensor &input,
346 | const at::Tensor &bbox,
347 | const at::Tensor &trans,
348 | const at::Tensor &top_count,
349 | const int no_trans,
350 | const float spatial_scale,
351 | const int output_dim,
352 | const int group_size,
353 | const int pooled_size,
354 | const int part_size,
355 | const int sample_per_part,
356 | const float trans_std)
357 | {
358 | AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
359 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
360 | AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
361 | AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
362 | AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");
363 |
364 | const int batch = input.size(0);
365 | const int channels = input.size(1);
366 | const int height = input.size(2);
367 | const int width = input.size(3);
368 | const int channels_trans = no_trans ? 2 : trans.size(1);
369 | const int num_bbox = bbox.size(0);
370 |
371 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
372 | auto pooled_height = pooled_size;
373 | auto pooled_width = pooled_size;
374 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
375 | const int num_classes = no_trans ? 1 : channels_trans / 2;
376 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
377 |
378 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
379 | auto trans_grad = at::zeros_like(trans);
380 |
381 | if (input_grad.numel() == 0)
382 | {
383 | THCudaCheck(cudaGetLastError());
384 | return std::make_tuple(input_grad, trans_grad);
385 | }
386 |
387 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
388 | dim3 block(512);
389 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
390 |
391 | AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
392 | DeformablePSROIPoolBackwardAccKernel<scalar_t><<<grid, block, 0, stream>>>(
393 | out_size,
394 | out_grad.contiguous().data<scalar_t>(),
395 | top_count.contiguous().data<scalar_t>(),
396 | num_bbox,
397 | spatial_scale,
398 | channels,
399 | height,
400 | width,
401 | pooled_height,
402 | pooled_width,
403 | output_dim,
404 | input_grad.contiguous().data<scalar_t>(),
405 | trans_grad.contiguous().data<scalar_t>(),
406 | input.contiguous().data<scalar_t>(),
407 | bbox.contiguous().data<scalar_t>(),
408 | trans.contiguous().data<scalar_t>(),
409 | no_trans,
410 | trans_std,
411 | sample_per_part,
412 | group_size,
413 | part_size,
414 | num_classes,
415 | channels_each_class);
416 | });
417 | THCudaCheck(cudaGetLastError());
418 | return std::make_tuple(input_grad, trans_grad);
419 | }
--------------------------------------------------------------------------------
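As the indexing bottom_rois + n * 5 in both kernels above implies, each ROI row is laid out as [batch_index, x1, y1, x2, y2], and the forward pass returns the pooled features together with a per-bin sampling-count map. A shape sketch under those assumptions (hypothetical tensors only, nothing here calls the extension):

import torch

num_bbox, channels, H, W = 4, 16, 64, 64
pooled_size = 7
output_dim = channels                       # the forward asserts channels == output_dim
rois = torch.zeros(num_bbox, 5)             # each row: [batch_index, x1, y1, x2, y2]
# dcn_v2_psroi_pooling_cuda_forward returns (out, top_count), both of shape
# (num_bbox, output_dim, pooled_size, pooled_size)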
/models/DCNv2/src/cuda/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cuda_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cuda_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/models/DCNv2/src/dcn_v2.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "cpu/vision.h"
4 |
5 | #ifdef WITH_CUDA
6 | #include "cuda/vision.h"
7 | #endif
8 |
9 | at::Tensor
10 | dcn_v2_forward(const at::Tensor &input,
11 | const at::Tensor &weight,
12 | const at::Tensor &bias,
13 | const at::Tensor &offset,
14 | const at::Tensor &mask,
15 | const int kernel_h,
16 | const int kernel_w,
17 | const int stride_h,
18 | const int stride_w,
19 | const int pad_h,
20 | const int pad_w,
21 | const int dilation_h,
22 | const int dilation_w,
23 | const int deformable_group)
24 | {
25 | if (input.type().is_cuda())
26 | {
27 | #ifdef WITH_CUDA
28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
29 | kernel_h, kernel_w,
30 | stride_h, stride_w,
31 | pad_h, pad_w,
32 | dilation_h, dilation_w,
33 | deformable_group);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | dcn_v2_backward(const at::Tensor &input,
43 | const at::Tensor &weight,
44 | const at::Tensor &bias,
45 | const at::Tensor &offset,
46 | const at::Tensor &mask,
47 | const at::Tensor &grad_output,
48 | int kernel_h, int kernel_w,
49 | int stride_h, int stride_w,
50 | int pad_h, int pad_w,
51 | int dilation_h, int dilation_w,
52 | int deformable_group)
53 | {
54 | if (input.type().is_cuda())
55 | {
56 | #ifdef WITH_CUDA
57 | return dcn_v2_cuda_backward(input,
58 | weight,
59 | bias,
60 | offset,
61 | mask,
62 | grad_output,
63 | kernel_h, kernel_w,
64 | stride_h, stride_w,
65 | pad_h, pad_w,
66 | dilation_h, dilation_w,
67 | deformable_group);
68 | #else
69 | AT_ERROR("Not compiled with GPU support");
70 | #endif
71 | }
72 | AT_ERROR("Not implemented on the CPU");
73 | }
74 |
75 | std::tuple<at::Tensor, at::Tensor>
76 | dcn_v2_psroi_pooling_forward(const at::Tensor &input,
77 | const at::Tensor &bbox,
78 | const at::Tensor &trans,
79 | const int no_trans,
80 | const float spatial_scale,
81 | const int output_dim,
82 | const int group_size,
83 | const int pooled_size,
84 | const int part_size,
85 | const int sample_per_part,
86 | const float trans_std)
87 | {
88 | if (input.type().is_cuda())
89 | {
90 | #ifdef WITH_CUDA
91 | return dcn_v2_psroi_pooling_cuda_forward(input,
92 | bbox,
93 | trans,
94 | no_trans,
95 | spatial_scale,
96 | output_dim,
97 | group_size,
98 | pooled_size,
99 | part_size,
100 | sample_per_part,
101 | trans_std);
102 | #else
103 | AT_ERROR("Not compiled with GPU support");
104 | #endif
105 | }
106 | AT_ERROR("Not implemented on the CPU");
107 | }
108 |
109 | std::tuple<at::Tensor, at::Tensor>
110 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
111 | const at::Tensor &input,
112 | const at::Tensor &bbox,
113 | const at::Tensor &trans,
114 | const at::Tensor &top_count,
115 | const int no_trans,
116 | const float spatial_scale,
117 | const int output_dim,
118 | const int group_size,
119 | const int pooled_size,
120 | const int part_size,
121 | const int sample_per_part,
122 | const float trans_std)
123 | {
124 | if (input.type().is_cuda())
125 | {
126 | #ifdef WITH_CUDA
127 | return dcn_v2_psroi_pooling_cuda_backward(out_grad,
128 | input,
129 | bbox,
130 | trans,
131 | top_count,
132 | no_trans,
133 | spatial_scale,
134 | output_dim,
135 | group_size,
136 | pooled_size,
137 | part_size,
138 | sample_per_part,
139 | trans_std);
140 | #else
141 | AT_ERROR("Not compiled with GPU support");
142 | #endif
143 | }
144 | AT_ERROR("Not implemented on the CPU");
145 | }
--------------------------------------------------------------------------------
/models/DCNv2/src/vision.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "dcn_v2.h"
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
9 | }
10 |
--------------------------------------------------------------------------------
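vision.cpp is what makes the four CUDA entry points callable from Python. A minimal sketch of driving the raw dcn_v2_forward binding, assuming the compiled extension is importable as _ext (the actual module name is whatever models/DCNv2/setup.py registers) and that offset/mask follow the usual DCNv2 channel layout (2*deformable_group*kh*kw and deformable_group*kh*kw channels respectively); both of these are assumptions, not statements about this repository:

import torch
import _ext  # assumed module name of the built extension

N, C_in, C_out, H, W = 1, 64, 64, 32, 32
kh = kw = 3
stride = pad = dilation = 1
deformable_group = 1

x      = torch.randn(N, C_in, H, W, device='cuda')
weight = torch.randn(C_out, C_in, kh, kw, device='cuda')
bias   = torch.zeros(C_out, device='cuda')
offset = torch.zeros(N, 2 * deformable_group * kh * kw, H, W, device='cuda')  # assumed layout
mask   = torch.ones(N, deformable_group * kh * kw, H, W, device='cuda')       # assumed layout

# argument order follows dcn_v2_forward in src/dcn_v2.h
out = _ext.dcn_v2_forward(x, weight, bias, offset, mask,
                          kh, kw, stride, stride, pad, pad,
                          dilation, dilation, deformable_group)
print(out.shape)  # (1, 64, 32, 32) with these settings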
/models/ResBlock.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class GlobalGenerator(nn.Module):
5 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
6 | padding_type='reflect'):
7 | assert (n_blocks >= 0)
8 | super(GlobalGenerator, self).__init__()
9 | activation = nn.ReLU(True)
10 |
11 | model = []
12 |
13 | ### resnet blocks
14 | mult = 2 ** n_downsampling
15 | for i in range(n_blocks):
16 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation)]
17 |
18 | ### upsample
19 | for i in range(n_downsampling):
20 | mult = 2 ** (n_downsampling - i)
21 | # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
22 | # output_padding=1),
23 | # norm_layer(int(ngf * mult / 2)), activation]
24 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
25 | output_padding=1), activation]
26 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
27 | self.model = nn.Sequential(*model)
28 |
29 | def forward(self, input):
30 | return self.model(input)
31 |
32 |
33 | class ResnetBlock(nn.Module):
34 | def __init__(self, inchannel, padding_type='reflect', activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False):
35 | super(ResnetBlock, self).__init__()
36 | self.conv_block = self.build_conv_block(inchannel, padding_type, activation, use_dropout)
37 |
38 | def build_conv_block(self, inchannel, padding_type, activation, use_dropout):
39 | conv_block = []
40 | p = 0
41 | if padding_type == 'reflect':
42 | conv_block += [nn.ReflectionPad2d(1)]
43 | elif padding_type == 'replicate':
44 | conv_block += [nn.ReplicationPad2d(1)]
45 | elif padding_type == 'zero':
46 | p = 1
47 | else:
48 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
49 |
50 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p),
51 | activation]
52 | if use_dropout:
53 | conv_block += [nn.Dropout(0.5)]
54 |
55 | p = 0
56 | if padding_type == 'reflect':
57 | conv_block += [nn.ReflectionPad2d(1)]
58 | elif padding_type == 'replicate':
59 | conv_block += [nn.ReplicationPad2d(1)]
60 | elif padding_type == 'zero':
61 | p = 1
62 | else:
63 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
64 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), activation]
65 |
66 | return nn.Sequential(*conv_block)
67 |
68 | def forward(self, x):
69 | out = x + self.conv_block(x)
70 | return out
71 |
72 |
73 | class ResnetBlock_bn(nn.Module):
74 | def __init__(self, inchannel, padding_type='reflect', norm_layer=nn.BatchNorm2d, activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False):
75 | super(ResnetBlock_bn, self).__init__()
76 | self.conv_block = self.build_conv_block(inchannel, padding_type, norm_layer, activation, use_dropout)
77 |
78 | def build_conv_block(self, inchannel, padding_type, norm_layer, activation, use_dropout):
79 | conv_block = []
80 | p = 0
81 | if padding_type == 'reflect':
82 | conv_block += [nn.ReflectionPad2d(1)]
83 | elif padding_type == 'replicate':
84 | conv_block += [nn.ReplicationPad2d(1)]
85 | elif padding_type == 'zero':
86 | p = 1
87 | else:
88 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
89 |
90 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p),
91 | norm_layer(inchannel),
92 | activation]
93 | if use_dropout:
94 | conv_block += [nn.Dropout(0.5)]
95 |
96 | p = 0
97 | if padding_type == 'reflect':
98 | conv_block += [nn.ReflectionPad2d(1)]
99 | elif padding_type == 'replicate':
100 | conv_block += [nn.ReplicationPad2d(1)]
101 | elif padding_type == 'zero':
102 | p = 1
103 | else:
104 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
105 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p),
106 | norm_layer(inchannel)]
107 |
108 | return nn.Sequential(*conv_block)
109 |
110 | def forward(self, x):
111 | out = x + self.conv_block(x)
112 | return out
113 |
114 |
115 | class SemiResnetBlock(nn.Module):
116 | def __init__(self, inchannel, outchannel, padding_type='reflect', activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False, end=False):
117 | super(SemiResnetBlock, self).__init__()
118 | self.conv_block1 = self.build_conv_block1(inchannel, padding_type, activation, use_dropout)
119 | self.conv_block2 = self.build_conv_block2(inchannel, outchannel, padding_type, activation,
120 | use_dropout, end)
121 |
122 | def build_conv_block1(self, inchannel, padding_type, activation, use_dropout):
123 | conv_block = []
124 | p = 0
125 | if padding_type == 'reflect':
126 | conv_block += [nn.ReflectionPad2d(1)]
127 | elif padding_type == 'replicate':
128 | conv_block += [nn.ReplicationPad2d(1)]
129 | elif padding_type == 'zero':
130 | p = 1
131 | else:
132 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
133 |
134 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p),
135 | activation]
136 | if use_dropout:
137 | conv_block += [nn.Dropout(0.5)]
138 |
139 | return nn.Sequential(*conv_block)
140 |
141 | def build_conv_block2(self, inchannel, outchannel, padding_type, activation, use_dropout, end):
142 | conv_block = []
143 | p = 0
144 | if padding_type == 'reflect':
145 | conv_block += [nn.ReflectionPad2d(1)]
146 | elif padding_type == 'replicate':
147 | conv_block += [nn.ReplicationPad2d(1)]
148 | elif padding_type == 'zero':
149 | p = 1
150 | else:
151 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
152 | if end:
153 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p)]
154 | else:
155 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p),
156 | activation]
157 | return nn.Sequential(*conv_block)
158 |
159 | def forward(self, x):
160 | x = self.conv_block1(x) + x
161 | out = self.conv_block2(x)
162 | return out
163 |
164 | class SemiResnetBlock_bn(nn.Module):
165 | def __init__(self, inchannel, outchannel, padding_type='reflect', norm_layer=nn.BatchNorm2d, activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False, end=False):
166 | super(SemiResnetBlock_bn, self).__init__()
167 | self.conv_block1 = self.build_conv_block1(inchannel, padding_type, norm_layer, activation, use_dropout)
168 | self.conv_block2 = self.build_conv_block2(inchannel, outchannel, padding_type, norm_layer, activation,
169 | use_dropout, end)
170 |
171 | def build_conv_block1(self, inchannel, padding_type, norm_layer, activation, use_dropout):
172 | conv_block = []
173 | p = 0
174 | if padding_type == 'reflect':
175 | conv_block += [nn.ReflectionPad2d(1)]
176 | elif padding_type == 'replicate':
177 | conv_block += [nn.ReplicationPad2d(1)]
178 | elif padding_type == 'zero':
179 | p = 1
180 | else:
181 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
182 |
183 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p),
184 | norm_layer(inchannel),
185 | activation]
186 | if use_dropout:
187 | conv_block += [nn.Dropout(0.5)]
188 |
189 | return nn.Sequential(*conv_block)
190 |
191 | def build_conv_block2(self, inchannel, outchannel, padding_type, norm_layer, activation, use_dropout, end):
192 | conv_block = []
193 | p = 0
194 | if padding_type == 'reflect':
195 | conv_block += [nn.ReflectionPad2d(1)]
196 | elif padding_type == 'replicate':
197 | conv_block += [nn.ReplicationPad2d(1)]
198 | elif padding_type == 'zero':
199 | p = 1
200 | else:
201 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
202 | if end:
203 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p)]
204 | else:
205 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p),
206 | norm_layer(outchannel), activation]
207 | return nn.Sequential(*conv_block)
208 |
209 | def forward(self, x):
210 | x = self.conv_block1(x) + x
211 | out = self.conv_block2(x)
212 | return out
213 |
--------------------------------------------------------------------------------
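A short usage sketch for the blocks above (illustration only; the tensor sizes are arbitrary). With the defaults ngf=64 and n_downsampling=3, GlobalGenerator's residual stack runs at 512 channels and the three transposed convolutions upsample by 8x before the final 7x7 convolution and Tanh:

import torch
from models.ResBlock import GlobalGenerator, ResnetBlock

block = ResnetBlock(64)                      # residual block, channel count preserved
y = block(torch.randn(1, 64, 32, 32))        # -> (1, 64, 32, 32)

gen = GlobalGenerator(input_nc=512, output_nc=3)
img = gen(torch.randn(1, 512, 16, 16))       # -> (1, 3, 128, 128)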
/models/Unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class down(nn.Module):
7 | """
8 | A class for creating neural network blocks containing layers:
9 |
10 | Average Pooling --> Convolution + Leaky ReLU --> Convolution + Leaky ReLU
11 |
12 | This is used in the UNet Class to create a UNet like NN architecture.
13 |
14 | ...
15 |
16 | Methods
17 | -------
18 | forward(x)
19 | Returns output tensor after passing input `x` to the neural network
20 | block.
21 | """
22 |
23 |
24 | def __init__(self, inChannels, outChannels, filterSize):
25 | """
26 | Parameters
27 | ----------
28 | inChannels : int
29 | number of input channels for the first convolutional layer.
30 | outChannels : int
31 | number of output channels for the first convolutional layer.
32 | This is also used as input and output channels for the
33 | second convolutional layer.
34 | filterSize : int
35 | filter size for the convolution filter. input N would create
36 | a N x N filter.
37 | """
38 |
39 |
40 | super(down, self).__init__()
41 | # Initialize convolutional layers.
42 | self.conv1 = nn.Conv2d(inChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2))
43 | self.conv2 = nn.Conv2d(outChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2))
44 |
45 | def forward(self, x):
46 | """
47 | Returns output tensor after passing input `x` to the neural network
48 | block.
49 |
50 | Parameters
51 | ----------
52 | x : tensor
53 | input to the NN block.
54 |
55 | Returns
56 | -------
57 | tensor
58 | output of the NN block.
59 | """
60 |
61 |
62 | # Average pooling with kernel size 2 (2 x 2).
63 | x = F.avg_pool2d(x, 2)
64 | # Convolution + Leaky ReLU
65 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
66 | # Convolution + Leaky ReLU
67 | x = F.leaky_relu(self.conv2(x), negative_slope = 0.1)
68 | return x
69 |
70 | class up(nn.Module):
71 | """
72 | A class for creating neural network blocks containing layers:
73 |
74 | Bilinear interpolation --> Convolution + Leaky ReLU --> Convolution + Leaky ReLU
75 |
76 | This is used in the UNet Class to create a UNet like NN architecture.
77 |
78 | ...
79 |
80 | Methods
81 | -------
82 | forward(x, skpCn)
83 | Returns output tensor after passing input `x` to the neural network
84 | block.
85 | """
86 |
87 |
88 | def __init__(self, inChannels, outChannels):
89 | """
90 | Parameters
91 | ----------
92 | inChannels : int
93 | number of input channels for the first convolutional layer.
94 | outChannels : int
95 | number of output channels for the first convolutional layer.
96 | This is also used for setting input and output channels for
97 | the second convolutional layer.
98 | """
99 |
100 |
101 | super(up, self).__init__()
102 | # Initialize convolutional layers.
103 | self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
104 | # (2 * outChannels) is used for accommodating skip connection.
105 | self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)
106 |
107 | def forward(self, x, skpCn):
108 | """
109 | Returns output tensor after passing input `x` to the neural network
110 | block.
111 |
112 | Parameters
113 | ----------
114 | x : tensor
115 | input to the NN block.
116 | skpCn : tensor
117 | skip connection input to the NN block.
118 |
119 | Returns
120 | -------
121 | tensor
122 | output of the NN block.
123 | """
124 |
125 | # Bilinear interpolation with scaling 2.
126 | x = F.interpolate(x, scale_factor=2, mode='bilinear')
127 | # Convolution + Leaky ReLU
128 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
129 | # Convolution + Leaky ReLU on (`x`, `skpCn`)
130 | x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1)
131 | return x
132 |
133 |
134 |
135 | class UNet(nn.Module):
136 | """
137 | A class for creating UNet like architecture as specified by the
138 | Super SloMo paper.
139 |
140 | ...
141 |
142 | Methods
143 | -------
144 | forward(x)
145 | Returns output tensor after passing input `x` to the neural network
146 | block.
147 | """
148 |
149 |
150 | def __init__(self, inChannels, outChannels):
151 | """
152 | Parameters
153 | ----------
154 | inChannels : int
155 | number of input channels for the UNet.
156 | outChannels : int
157 | number of output channels for the UNet.
158 | """
159 |
160 |
161 | super(UNet, self).__init__()
162 | # Initialize neural network blocks.
163 | self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3)
164 | self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3)
165 | self.down1 = down(32, 64, 5)
166 | self.down2 = down(64, 128, 3)
167 | self.down3 = down(128, 256, 3)
168 | self.down4 = down(256, 512, 3)
169 | self.down5 = down(512, 512, 3)
170 | self.up1 = up(512, 512)
171 | self.up2 = up(512, 256)
172 | self.up3 = up(256, 128)
173 | self.up4 = up(128, 64)
174 | self.up5 = up(64, 32)
175 | self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)
176 |
177 | def forward(self, x):
178 | """
179 | Returns output tensor after passing input `x` to the neural network.
180 |
181 | Parameters
182 | ----------
183 | x : tensor
184 | input to the UNet.
185 |
186 | Returns
187 | -------
188 | tensor
189 | output of the UNet.
190 | """
191 |
192 |
193 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
194 | s1 = F.leaky_relu(self.conv2(x), negative_slope = 0.1)
195 | s2 = self.down1(s1)
196 | s3 = self.down2(s2)
197 | s4 = self.down3(s3)
198 | s5 = self.down4(s4)
199 | x = self.down5(s5)
200 | x = self.up1(x, s5)
201 | x = self.up2(x, s4)
202 | x = self.up3(x, s3)
203 | x = self.up4(x, s2)
204 | x = self.up5(x, s1)
205 | x = F.leaky_relu(self.conv3(x), negative_slope = 0.1)
206 | return x
--------------------------------------------------------------------------------
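A usage sketch for this UNet (illustration only; the channel counts here are arbitrary). Because the five down blocks each halve the resolution, input height and width should be divisible by 32:

import torch
from models.Unet import UNet

net = UNet(inChannels=6, outChannels=3)
x = torch.randn(1, 6, 256, 448)
out = net(x)
print(out.shape)  # torch.Size([1, 3, 256, 448])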
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/models/__init__.py
--------------------------------------------------------------------------------
/models/bdcn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/models/bdcn/__init__.py
--------------------------------------------------------------------------------
/models/bdcn/bdcn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | import models.bdcn.vgg16_c as vgg16_c
6 |
7 | def crop(data1, data2, crop_h, crop_w):
8 | _, _, h1, w1 = data1.size()
9 | _, _, h2, w2 = data2.size()
10 | assert(h2 <= h1 and w2 <= w1)
11 | data = data1[:, :, crop_h:crop_h+h2, crop_w:crop_w+w2]
12 | return data
13 |
14 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
15 | """Make a 2D bilinear kernel suitable for upsampling"""
16 | factor = (kernel_size + 1) // 2
17 | if kernel_size % 2 == 1:
18 | center = factor - 1
19 | else:
20 | center = factor - 0.5
21 | og = np.ogrid[:kernel_size, :kernel_size]
22 | filt = (1 - abs(og[0] - center) / factor) * \
23 | (1 - abs(og[1] - center) / factor)
24 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
25 | dtype=np.float64)
26 | weight[range(in_channels), range(out_channels), :, :] = filt
27 | return torch.from_numpy(weight).float()
28 |
29 | class MSBlock(nn.Module):
30 | def __init__(self, c_in, rate=4):
31 | super(MSBlock, self).__init__()
32 | c_out = c_in
33 | self.rate = rate
34 |
35 | self.conv = nn.Conv2d(c_in, 32, 3, stride=1, padding=1)
36 | self.relu = nn.ReLU(inplace=True)
37 | dilation = self.rate*1 if self.rate >= 1 else 1
38 | self.conv1 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation)
39 | self.relu1 = nn.ReLU(inplace=True)
40 | dilation = self.rate*2 if self.rate >= 1 else 1
41 | self.conv2 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation)
42 | self.relu2 = nn.ReLU(inplace=True)
43 | dilation = self.rate*3 if self.rate >= 1 else 1
44 | self.conv3 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation)
45 | self.relu3 = nn.ReLU(inplace=True)
46 |
47 | self._initialize_weights()
48 |
49 | def forward(self, x):
50 | o = self.relu(self.conv(x))
51 | o1 = self.relu1(self.conv1(o))
52 | o2 = self.relu2(self.conv2(o))
53 | o3 = self.relu3(self.conv3(o))
54 | out = o + o1 + o2 + o3
55 | return out
56 |
57 | def _initialize_weights(self):
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | m.weight.data.normal_(0, 0.01)
61 | if m.bias is not None:
62 | m.bias.data.zero_()
63 |
64 |
65 | class BDCN(nn.Module):
66 | def __init__(self, pretrain=None, logger=None, rate=4):
67 | super(BDCN, self).__init__()
68 | self.pretrain = pretrain
69 | t = 1
70 |
71 | self.features = vgg16_c.VGG16_C(pretrain, logger)
72 | self.msblock1_1 = MSBlock(64, rate)
73 | self.msblock1_2 = MSBlock(64, rate)
74 | self.conv1_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
75 | self.conv1_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
76 | self.score_dsn1 = nn.Conv2d(21, 1, (1, 1), stride=1)
77 | self.score_dsn1_1 = nn.Conv2d(21, 1, 1, stride=1)
78 | self.msblock2_1 = MSBlock(128, rate)
79 | self.msblock2_2 = MSBlock(128, rate)
80 | self.conv2_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
81 | self.conv2_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
82 | self.score_dsn2 = nn.Conv2d(21, 1, (1, 1), stride=1)
83 | self.score_dsn2_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
84 | self.msblock3_1 = MSBlock(256, rate)
85 | self.msblock3_2 = MSBlock(256, rate)
86 | self.msblock3_3 = MSBlock(256, rate)
87 | self.conv3_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
88 | self.conv3_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
89 | self.conv3_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
90 | self.score_dsn3 = nn.Conv2d(21, 1, (1, 1), stride=1)
91 | self.score_dsn3_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
92 | self.msblock4_1 = MSBlock(512, rate)
93 | self.msblock4_2 = MSBlock(512, rate)
94 | self.msblock4_3 = MSBlock(512, rate)
95 | self.conv4_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
96 | self.conv4_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
97 | self.conv4_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
98 | self.score_dsn4 = nn.Conv2d(21, 1, (1, 1), stride=1)
99 | self.score_dsn4_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
100 | self.msblock5_1 = MSBlock(512, rate)
101 | self.msblock5_2 = MSBlock(512, rate)
102 | self.msblock5_3 = MSBlock(512, rate)
103 | self.conv5_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
104 | self.conv5_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
105 | self.conv5_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
106 | self.score_dsn5 = nn.Conv2d(21, 1, (1, 1), stride=1)
107 | self.score_dsn5_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
108 | self.upsample_2 = nn.ConvTranspose2d(1, 1, 4, stride=2, bias=False)
109 | self.upsample_4 = nn.ConvTranspose2d(1, 1, 8, stride=4, bias=False)
110 | self.upsample_8 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False)
111 | self.upsample_8_5 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False)
112 | self.fuse = nn.Conv2d(10, 1, 1, stride=1)
113 |
114 | self._initialize_weights(logger)
115 |
116 | def forward(self, x):
117 | features = self.features(x)
118 | sum1 = self.conv1_1_down(self.msblock1_1(features[0])) + \
119 | self.conv1_2_down(self.msblock1_2(features[1]))
120 | s1 = self.score_dsn1(sum1)
121 | s11 = self.score_dsn1_1(sum1)
122 | # print(s1.data.shape, s11.data.shape)
123 | sum2 = self.conv2_1_down(self.msblock2_1(features[2])) + \
124 | self.conv2_2_down(self.msblock2_2(features[3]))
125 | s2 = self.score_dsn2(sum2)
126 | s21 = self.score_dsn2_1(sum2)
127 | s2 = self.upsample_2(s2)
128 | s21 = self.upsample_2(s21)
129 | # print(s2.data.shape, s21.data.shape)
130 | s2 = crop(s2, x, 1, 1)
131 | s21 = crop(s21, x, 1, 1)
132 | sum3 = self.conv3_1_down(self.msblock3_1(features[4])) + \
133 | self.conv3_2_down(self.msblock3_2(features[5])) + \
134 | self.conv3_3_down(self.msblock3_3(features[6]))
135 | s3 = self.score_dsn3(sum3)
136 | s3 = self.upsample_4(s3)
137 | # print(s3.data.shape)
138 | s3 = crop(s3, x, 2, 2)
139 | s31 = self.score_dsn3_1(sum3)
140 | s31 = self.upsample_4(s31)
141 | # print(s31.data.shape)
142 | s31 = crop(s31, x, 2, 2)
143 | sum4 = self.conv4_1_down(self.msblock4_1(features[7])) + \
144 | self.conv4_2_down(self.msblock4_2(features[8])) + \
145 | self.conv4_3_down(self.msblock4_3(features[9]))
146 | s4 = self.score_dsn4(sum4)
147 | s4 = self.upsample_8(s4)
148 | # print(s4.data.shape)
149 | s4 = crop(s4, x, 4, 4)
150 | s41 = self.score_dsn4_1(sum4)
151 | s41 = self.upsample_8(s41)
152 | # print(s41.data.shape)
153 | s41 = crop(s41, x, 4, 4)
154 | sum5 = self.conv5_1_down(self.msblock5_1(features[10])) + \
155 | self.conv5_2_down(self.msblock5_2(features[11])) + \
156 | self.conv5_3_down(self.msblock5_3(features[12]))
157 | s5 = self.score_dsn5(sum5)
158 | s5 = self.upsample_8_5(s5)
159 | # print(s5.data.shape)
160 | s5 = crop(s5, x, 0, 0)
161 | s51 = self.score_dsn5_1(sum5)
162 | s51 = self.upsample_8_5(s51)
163 | # print(s51.data.shape)
164 | s51 = crop(s51, x, 0, 0)
165 | o1, o2, o3, o4, o5 = s1.detach(), s2.detach(), s3.detach(), s4.detach(), s5.detach()
166 | o11, o21, o31, o41, o51 = s11.detach(), s21.detach(), s31.detach(), s41.detach(), s51.detach()
167 | p1_1 = s1
168 | p2_1 = s2 + o1
169 | p3_1 = s3 + o2 + o1
170 | p4_1 = s4 + o3 + o2 + o1
171 | p5_1 = s5 + o4 + o3 + o2 + o1
172 | p1_2 = s11 + o21 + o31 + o41 + o51
173 | p2_2 = s21 + o31 + o41 + o51
174 | p3_2 = s31 + o41 + o51
175 | p4_2 = s41 + o51
176 | p5_2 = s51
177 |
178 | fuse = self.fuse(torch.cat([p1_1, p2_1, p3_1, p4_1, p5_1, p1_2, p2_2, p3_2, p4_2, p5_2], 1))
179 |
180 | return [p1_1, p2_1, p3_1, p4_1, p5_1, p1_2, p2_2, p3_2, p4_2, p5_2, fuse]
181 |
182 | def _initialize_weights(self, logger=None):
183 | for name, param in self.state_dict().items():
184 | if self.pretrain and 'features' in name:
185 | continue
186 | # elif 'down' in name:
187 | # param.zero_()
188 | elif 'upsample' in name:
189 | if logger:
190 | logger.info('init upsample layer %s ' % name)
191 | k = int(name.split('.')[0].split('_')[1])
192 | param.copy_(get_upsampling_weight(1, 1, k*2))
193 | elif 'fuse' in name:
194 | if logger:
195 | logger.info('init params %s ' % name)
196 | if 'bias' in name:
197 | param.zero_()
198 | else:
199 |                     nn.init.constant_(param, 0.080)
200 | else:
201 | if logger:
202 | logger.info('init params %s ' % name)
203 | if 'bias' in name:
204 | param.zero_()
205 | else:
206 | param.normal_(0, 0.01)
207 | # print self.conv1_1_down.weight
208 |
209 | if __name__ == '__main__':
210 | model = BDCN('./caffemodel2pytorch/vgg16.pth')
211 | a=torch.rand((2,3,100,100))
212 | a=torch.autograd.Variable(a)
213 | for x in model(a):
214 | print(x.data.shape)
215 | # for name, param in model.state_dict().items():
216 | # print name, param
217 |
--------------------------------------------------------------------------------
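Note on the transposed-convolution initialisation above: `get_upsampling_weight(1, 1, k*2)` is defined elsewhere in bdcn.py (outside this excerpt). A minimal sketch of the standard bilinear-kernel construction it presumably follows (an assumption for illustration, not the repository's exact code):

    import numpy as np
    import torch

    def bilinear_upsampling_weight(in_channels, out_channels, kernel_size):
        # Standard FCN-style bilinear kernel used to initialise nn.ConvTranspose2d.
        # Assumes in_channels == out_channels (both are 1 for the upsample_* layers above).
        factor = (kernel_size + 1) // 2
        center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
        og = np.ogrid[:kernel_size, :kernel_size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
        weight[range(in_channels), range(out_channels), :, :] = filt
        return torch.from_numpy(weight).float()

    # e.g. for self.upsample_2 (kernel 4, stride 2): bilinear_upsampling_weight(1, 1, 4)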
/models/bdcn/vgg16_c.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision
4 | import torch.nn as nn
5 | import math
6 |
7 | class VGG16_C(nn.Module):
8 | """"""
9 | def __init__(self, pretrain=None, logger=None):
10 | super(VGG16_C, self).__init__()
11 | self.conv1_1 = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1)
12 | self.relu1_1 = nn.ReLU(inplace=True)
13 | self.conv1_2 = nn.Conv2d(64, 64, (3, 3), stride=1, padding=1)
14 | self.relu1_2 = nn.ReLU(inplace=True)
15 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
16 | self.conv2_1 = nn.Conv2d(64, 128, (3, 3), stride=1, padding=1)
17 | self.relu2_1 = nn.ReLU(inplace=True)
18 | self.conv2_2 = nn.Conv2d(128, 128, (3, 3), stride=1, padding=1)
19 | self.relu2_2 = nn.ReLU(inplace=True)
20 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
21 | self.conv3_1 = nn.Conv2d(128, 256, (3, 3), stride=1, padding=1)
22 | self.relu3_1 = nn.ReLU(inplace=True)
23 | self.conv3_2 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1)
24 | self.relu3_2 = nn.ReLU(inplace=True)
25 | self.conv3_3 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1)
26 | self.relu3_3 = nn.ReLU(inplace=True)
27 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
28 | self.conv4_1 = nn.Conv2d(256, 512, (3, 3), stride=1, padding=1)
29 | self.relu4_1 = nn.ReLU(inplace=True)
30 | self.conv4_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1)
31 | self.relu4_2 = nn.ReLU(inplace=True)
32 | self.conv4_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1)
33 | self.relu4_3 = nn.ReLU(inplace=True)
34 | self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)
35 | self.conv5_1 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2)
36 | self.relu5_1 = nn.ReLU(inplace=True)
37 | self.conv5_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2)
38 | self.relu5_2 = nn.ReLU(inplace=True)
39 | self.conv5_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2)
40 | self.relu5_3 = nn.ReLU(inplace=True)
41 | if pretrain:
42 | if '.npy' in pretrain:
43 |                 state_dict = np.load(pretrain, allow_pickle=True).item()
44 | for k in state_dict:
45 | state_dict[k] = torch.from_numpy(state_dict[k])
46 | else:
47 | state_dict = torch.load(pretrain)
48 | own_state_dict = self.state_dict()
49 | for name, param in own_state_dict.items():
50 | if name in state_dict:
51 | if logger:
52 | logger.info('copy the weights of %s from pretrained model' % name)
53 | param.copy_(state_dict[name])
54 | else:
55 | if logger:
56 | logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\
57 | % name)
58 | if 'bias' in name:
59 | param.zero_()
60 | else:
61 | param.normal_(0, 0.01)
62 | else:
63 | self._initialize_weights(logger)
64 |
65 | def forward(self, x):
66 | conv1_1 = self.relu1_1(self.conv1_1(x))
67 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
68 | pool1 = self.pool1(conv1_2)
69 | conv2_1 = self.relu2_1(self.conv2_1(pool1))
70 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
71 | pool2 = self.pool2(conv2_2)
72 | conv3_1 = self.relu3_1(self.conv3_1(pool2))
73 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
74 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
75 | pool3 = self.pool3(conv3_3)
76 | conv4_1 = self.relu4_1(self.conv4_1(pool3))
77 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
78 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
79 | pool4 = self.pool4(conv4_3)
80 | # pool4 = conv4_3
81 | conv5_1 = self.relu5_1(self.conv5_1(pool4))
82 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
83 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
84 |
85 | side = [conv1_1, conv1_2, conv2_1, conv2_2,
86 | conv3_1, conv3_2, conv3_3, conv4_1,
87 | conv4_2, conv4_3, conv5_1, conv5_2, conv5_3]
88 | return side
89 |
90 | def _initialize_weights(self, logger=None):
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv2d):
93 | if logger:
94 |                     logger.info('init the weights of %s with a kaiming-style gaussian (std=sqrt(2/n))'\
95 |                         % m)
96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
97 | m.weight.data.normal_(0, math.sqrt(2. / n))
98 | if m.bias is not None:
99 | m.bias.data.zero_()
100 | elif isinstance(m, nn.BatchNorm2d):
101 | m.weight.data.fill_(1)
102 | m.bias.data.zero_()
103 | elif isinstance(m, nn.Linear):
104 | m.weight.data.normal_(0, 0.01)
105 | m.bias.data.zero_()
106 |
107 | if __name__ == '__main__':
108 | model = VGG16_C()
109 | # im = np.zeros((1,3,100,100))
110 | # out = model(Variable(torch.from_numpy(im)))
111 |
112 |
113 |
--------------------------------------------------------------------------------
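VGG16_C.forward returns the 13 side activations (conv1_1 ... conv5_3) that BDCN consumes as features[0] ... features[12]. A quick smoke test with randomly initialised weights (no pretrained checkpoint), assuming the repository root is on PYTHONPATH:

    import torch
    from models.bdcn.vgg16_c import VGG16_C

    model = VGG16_C()          # pretrain=None: weights are randomly initialised
    model.eval()
    with torch.no_grad():
        side = model(torch.rand(1, 3, 128, 128))
    for i, feat in enumerate(side):
        # channel count grows 64 -> 128 -> 256 -> 512 across the five stages
        print(i, tuple(feat.shape))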
/models/warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class backWarp(nn.Module):
8 | """
9 | A class for creating a backwarping object.
10 |
11 | This is used for backwarping to an image:
12 |
13 | Given optical flow from frame I0 to I1 --> F_0_1 and frame I1,
14 | it generates I0 <-- backwarp(F_0_1, I1).
15 |
16 | ...
17 |
18 | Methods
19 | -------
20 | forward(x)
21 | Returns output tensor after passing input `img` and `flow` to the backwarping
22 | block.
23 | """
24 |
25 |
26 | def __init__(self, H, W):
27 | """
28 | Parameters
29 | ----------
30 |         H : int
31 |             height of the image.
32 |         W : int
33 |             width of the image.
34 |         Note: the sampling grid is registered as non-trainable
35 |             parameters, so it moves to the module's device (cpu/cuda).
36 | """
37 |
38 |
39 | super(backWarp, self).__init__()
40 | # create a grid
41 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
42 | self.W = W
43 | self.H = H
44 | self.gridX = torch.nn.Parameter(torch.tensor(gridX), requires_grad=False)
45 | self.gridY = torch.nn.Parameter(torch.tensor(gridY), requires_grad=False)
46 |
47 | def forward(self, img, flow):
48 | """
49 | Returns output tensor after passing input `img` and `flow` to the backwarping
50 | block.
51 | I0 = backwarp(I1, F_0_1)
52 |
53 | Parameters
54 | ----------
55 | img : tensor
56 | frame I1.
57 | flow : tensor
58 |             optical flow from I0 to I1: F_0_1.
59 |
60 | Returns
61 | -------
62 | tensor
63 | frame I0.
64 | """
65 |
66 |
67 | # Extract horizontal and vertical flows.
68 | u = flow[:, 0, :, :]
69 | v = flow[:, 1, :, :]
70 | x = self.gridX.unsqueeze(0).expand_as(u).float() + u
71 | y = self.gridY.unsqueeze(0).expand_as(v).float() + v
72 | # range -1 to 1
73 | x = 2*(x/self.W - 0.5)
74 | y = 2*(y/self.H - 0.5)
75 | # stacking X and Y
76 | grid = torch.stack((x,y), dim=3)
77 | # Sample pixels using bilinear interpolation.
78 | imgOut = torch.nn.functional.grid_sample(img, grid, padding_mode='border')
79 | return imgOut
80 |
81 |
82 | # Creating an array of `t` values for the 7 intermediate frames between
83 | # reference frames I0 and I1.
84 | class Coeff(nn.Module):
85 |
86 | def __init__(self):
87 | super(Coeff, self).__init__()
88 | self.t = torch.nn.Parameter(torch.FloatTensor(np.linspace(0.125, 0.875, 7)), requires_grad=False)
89 |
90 | def getFlowCoeff (self, indices):
91 | """
92 | Gets flow coefficients used for calculating intermediate optical
93 | flows from optical flows between I0 and I1: F_0_1 and F_1_0.
94 |
95 | F_t_0 = C00 x F_0_1 + C01 x F_1_0
96 | F_t_1 = C10 x F_0_1 + C11 x F_1_0
97 |
98 | where,
99 | C00 = -(1 - t) x t
100 | C01 = t x t
101 | C10 = (1 - t) x (1 - t)
102 | C11 = -t x (1 - t)
103 |
104 | Parameters
105 | ----------
106 | indices : tensor
107 | indices corresponding to the intermediate frame positions
108 | of all samples in the batch.
109 | device : device
110 | computation device (cpu/cuda).
111 |
112 | Returns
113 | -------
114 | tensor
115 | coefficients C00, C01, C10, C11.
116 | """
117 |
118 |
119 |         # Detach indices so coefficient lookup is not tracked by autograd
120 | ind = indices.detach()
121 | C11 = C00 = - (1 - (self.t[ind])) * (self.t[ind])
122 | C01 = (self.t[ind]) * (self.t[ind])
123 | C10 = (1 - (self.t[ind])) * (1 - (self.t[ind]))
124 | return C00[None, None, None, :].permute(3, 0, 1, 2), C01[None, None, None, :].permute(3, 0, 1, 2), C10[None, None, None, :].permute(3, 0, 1, 2), C11[None, None, None, :].permute(3, 0, 1, 2)
125 |
126 | def getWarpCoeff (self, indices):
127 | """
128 | Gets coefficients used for calculating final intermediate
129 | frame `It_gen` from backwarped images using flows F_t_0 and F_t_1.
130 |
131 | It_gen = (C0 x V_t_0 x g_I_0_F_t_0 + C1 x V_t_1 x g_I_1_F_t_1) / (C0 x V_t_0 + C1 x V_t_1)
132 |
133 | where,
134 | C0 = 1 - t
135 | C1 = t
136 |
137 | V_t_0, V_t_1 --> visibility maps
138 | g_I_0_F_t_0, g_I_1_F_t_1 --> backwarped intermediate frames
139 |
140 | Parameters
141 | ----------
142 | indices : tensor
143 | indices corresponding to the intermediate frame positions
144 | of all samples in the batch.
145 | device : device
146 | computation device (cpu/cuda).
147 |
148 | Returns
149 | -------
150 | tensor
151 | coefficients C0 and C1.
152 | """
153 |
154 |
155 |         # Detach indices so coefficient lookup is not tracked by autograd
156 | ind = indices.detach()
157 | C0 = 1 - self.t[ind]
158 | C1 = self.t[ind]
159 | return C0[None, None, None, :].permute(3, 0, 1, 2), C1[None, None, None, :].permute(3, 0, 1, 2)
160 |
161 | def set_t(self, factor):
162 | ti = 1 / factor
163 | self.t = torch.nn.Parameter(torch.FloatTensor(np.linspace(ti, 1 - ti, factor - 1)), requires_grad=False)
--------------------------------------------------------------------------------
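A small usage sketch for backWarp (values are illustrative): given frame I1 and the flow F_0_1, the module builds a sampling grid and resamples I1 with grid_sample to approximate I0.

    import torch
    from models.warp import backWarp

    H, W = 256, 448
    warper = backWarp(H, W)            # grid stored as non-trainable parameters
    I1 = torch.rand(1, 3, H, W)        # frame at time 1
    F_0_1 = torch.zeros(1, 2, H, W)    # flow from I0 to I1 (zero flow for this sketch)
    I0_hat = warper(I1, F_0_1)         # with zero flow this simply resamples I1 in place
    print(I0_hat.shape)                # torch.Size([1, 3, 256, 448])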
/paper/FeatureFlow.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/paper/FeatureFlow.pdf
--------------------------------------------------------------------------------
/paper/Supp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/paper/Supp.pdf
--------------------------------------------------------------------------------
/pure_run.py:
--------------------------------------------------------------------------------
1 | # SeDraw
2 | import argparse
3 | import os
4 | import torch
5 | import cv2
6 | import torchvision.transforms as transforms
7 | from skimage.measure import compare_psnr
8 | from PIL import Image
9 | import src.pure_network as layers
10 | from tqdm import tqdm
11 | import numpy as np
12 | import math
13 | import models.bdcn.bdcn as bdcn
14 |
15 | # For parsing commandline arguments
16 | def str2bool(v):
17 | if isinstance(v, bool):
18 | return v
19 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
20 | return True
21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
22 | return False
23 | else:
24 | raise argparse.ArgumentTypeError('Boolean value expected.')
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model')
28 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3')
29 | parser.add_argument('--bdcn_model', type=str, default='/home/visiting/Projects/citrine/SeDraw/models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth')
30 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.')
31 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint')
32 | parser.add_argument('--imgpath', type=str, required=True)
33 | parser.add_argument('--first', type=str, required=True)
34 | parser.add_argument('--second', type=str, required=True)
35 | parser.add_argument('--gt', type=str, required=True)
36 | args = parser.parse_args()
37 |
38 |
39 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0):
40 |
41 |
42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
43 | with open(path, 'rb') as f:
44 | img = Image.open(f)
45 | # Resize image if specified.
46 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img
47 | # Crop image if crop area specified.
48 |             cropped_img = resized_img.crop(cropArea) if (cropArea != None) else resized_img
49 | # Flip image horizontally if specified.
50 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img
51 | return flipped_img.convert('RGB')
52 |
53 |
54 |
55 | bdcn = bdcn.BDCN()
56 | bdcn.cuda()
57 | structure_gen = layers.StructureGen(feature_level=args.feature_level)
58 | structure_gen.cuda()
59 | detail_enhance = layers.DetailEnhance()
60 | detail_enhance.cuda()
61 |
62 |
63 | # Normalization constants that map RGB values to the [-1, 1] range
64 | mean = [0.5, 0.5, 0.5]
65 | std = [0.5, 0.5, 0.5]
66 | normalize = transforms.Normalize(mean=mean,
67 | std=std)
68 | transform = transforms.Compose([transforms.ToTensor(), normalize])
69 |
70 | negmean = [-1 for x in mean]
71 | restd = [2, 2, 2]
72 | revNormalize = transforms.Normalize(mean=negmean, std=restd)
73 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()])
74 |
75 |
76 | def ToImage(frame0, frame1):
77 |
78 | with torch.no_grad():
79 |
80 | img0 = frame0.cuda()
81 | img1 = frame1.cuda()
82 |
83 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
84 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
85 | ref_imgt, _ = structure_gen((img0_e, img1_e))
86 | imgt = detail_enhance((img0, img1, ref_imgt))
87 | # imgt = detail_enhance((img0, img1, imgt))
88 | imgt = torch.clamp(imgt, max=1., min=-1.)
89 |
90 | return imgt
91 |
92 |
93 | def main():
94 | # initial
95 |
96 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
97 | dict1 = torch.load(args.checkpoint)
98 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
99 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
100 |
101 | bdcn.eval()
102 | structure_gen.eval()
103 | detail_enhance.eval()
104 |
105 | IE = 0
106 | PSNR = 0
107 | count = 0
108 | for folder in tqdm(os.listdir(args.imgpath)):
109 | triple_path = os.path.join(args.imgpath, folder)
110 | if not (os.path.isdir(triple_path)):
111 | continue
112 | X0 = transform(_pil_loader('%s/%s' % (triple_path, args.first))).unsqueeze(0)
113 | X1 = transform(_pil_loader('%s/%s' % (triple_path, args.second))).unsqueeze(0)
114 |
115 | assert (X0.size(2) == X1.size(2))
116 | assert (X0.size(3) == X1.size(3))
117 |
118 | intWidth = X0.size(3)
119 | intHeight = X0.size(2)
120 | channel = X0.size(1)
121 | if not channel == 3:
122 | print('Not RGB image')
123 | continue
124 | count += 1
125 |
126 | # if intWidth != ((intWidth >> 4) << 4):
127 | # intWidth_pad = (((intWidth >> 4) + 1) << 4) # more than necessary
128 | # intPaddingLeft = int((intWidth_pad - intWidth) / 2)
129 | # intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
130 | # else:
131 | # intWidth_pad = intWidth
132 | # intPaddingLeft = 0
133 | # intPaddingRight = 0
134 | #
135 | # if intHeight != ((intHeight >> 4) << 4):
136 | # intHeight_pad = (((intHeight >> 4) + 1) << 4) # more than necessary
137 | # intPaddingTop = int((intHeight_pad - intHeight) / 2)
138 | # intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
139 | # else:
140 | # intHeight_pad = intHeight
141 | # intPaddingTop = 0
142 | # intPaddingBottom = 0
143 | #
144 | # pader = torch.nn.ReflectionPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])
145 |
146 | # first, second = pader(X0), pader(X1)
147 | first, second = X0, X1
148 | imgt = ToImage(first, second)
149 |
150 | imgt_np = imgt.squeeze(0).cpu().numpy()#[:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
151 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
152 | if not os.path.isdir(triple_path):
153 | os.system('mkdir -p %s' % triple_path)
154 | cv2.imwrite(triple_path + '/SeDraw.png', imgt_png)
155 |
156 | rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png')))
157 | gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt)))
158 |
159 |         diff_rgb = rec_rgb.astype(np.float64) - gt_rgb.astype(np.float64)
160 | avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2))
161 |
162 | mse = np.mean((diff_rgb) ** 2)
163 |
164 | PIXEL_MAX = 255.0
165 | psnr = compare_psnr(gt_rgb, rec_rgb, 255)
166 | print(folder, psnr)
167 |
168 | IE += avg_interp_error_abs
169 | PSNR += psnr
170 |
171 | # print(triple_path, ': IE/PSNR:', avg_interp_error_abs, psnr)
172 |
173 | IE = IE / count
174 | PSNR = PSNR / count
175 | print('Average IE/PSNR:', IE, PSNR)
176 |
177 | main()
178 |
179 |
180 |
--------------------------------------------------------------------------------
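An illustrative invocation of pure_run.py (all paths and frame names are placeholders): --imgpath should point at a directory of triplet sub-folders, each containing the two input frames and the ground-truth middle frame.

    python pure_run.py --checkpoint checkpoints/FeFlow.ckpt --bdcn_model models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth --imgpath /path/to/triplets --first im1.png --second im3.png --gt im2.png

The script writes SeDraw.png into every triplet folder, prints the per-folder PSNR, and reports the averaged IE/PSNR at the end.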
/requirements.txt:
--------------------------------------------------------------------------------
1 | addict==2.2.1
2 | apipkg==1.5
3 | asn1crypto==1.3.0
4 | attrs==19.3.0
5 | backcall==0.1.0
6 | bleach==3.1.4
7 | blis==0.4.1
8 | boto==2.49.0
9 | boto3==1.12.12
10 | botocore==1.15.12
11 | bz2file==0.98
12 | captum==0.2.0
13 | catalogue==1.0.0
14 | certifi==2020.4.5.1
15 | cffi==1.14.0
16 | chardet==3.0.4
17 | Click==7.0
18 | correlation-cuda==0.0.0
19 | cryptography==2.8
20 | cupy==7.2.0
21 | cycler==0.10.0
22 | cymem==2.0.3
23 | Cython==0.29.14
24 | decorator==4.4.2
25 | defusedxml==0.6.0
26 | depthflowprojection-cuda==0.0.0
27 | docutils==0.15.2
28 | en-core-web-sm==2.2.5
29 | entrypoints==0.3
30 | execnet==1.7.1
31 | fastrlock==0.4
32 | filelock==3.0.12
33 | filterinterpolation-cuda==0.0.0
34 | flowprojection-cuda==0.0.0
35 | gensim==3.8.0
36 | idna==2.9
37 | imageio==2.6.1
38 | imageio-ffmpeg==0.4.1
39 | importlib-metadata==1.5.0
40 | interpolation-cuda==0.0.0
41 | interpolationch-cuda==0.0.0
42 | ipykernel==5.1.4
43 | ipython==7.13.0
44 | ipython-genutils==0.2.0
45 | ipywidgets==7.5.1
46 | jedi==0.16.0
47 | jieba==0.42.1
48 | Jinja2==2.11.1
49 | jmespath==0.9.4
50 | joblib==0.14.1
51 | jsonpatch==1.25
52 | jsonpointer==2.0
53 | jsonschema==3.2.0
54 | jupyter==1.0.0
55 | jupyter-client==6.0.0
56 | jupyter-console==6.1.0
57 | jupyter-core==4.6.3
58 | kiwisolver==1.1.0
59 | lxml==4.5.0
60 | MarkupSafe==1.1.1
61 | matplotlib==3.1.3
62 | mindepthflowprojection-cuda==0.0.0
63 | mistune==0.8.4
64 | mkl-fft==1.0.15
65 | mkl-random==1.1.0
66 | mkl-service==2.3.0
67 | mmcv==0.2.16
68 | -e git+https://github.com/open-mmlab/mmdetection.git@93bed07bb65b24ae040d797a7696ecb994acf9f5#egg=mmdet
69 | more-itertools==8.2.0
70 | moviepy==1.0.1
71 | murmurhash==1.0.2
72 | nbconvert==5.6.1
73 | nbformat==5.0.4
74 | networkx==2.4
75 | nltk==3.4.5
76 | notebook==6.0.3
77 | numpy==1.18.1
78 | olefile==0.46
79 | opencv-python==4.2.0.32
80 | packaging==20.3
81 | pandocfilters==1.4.2
82 | parso==0.6.2
83 | pexpect==4.8.0
84 | pickleshare==0.7.5
85 | Pillow==6.2.2
86 | plac==1.1.3
87 | pluggy==0.13.1
88 | prefetch-generator==1.0.1
89 | preshed==3.0.2
90 | proglog==0.1.9
91 | prometheus-client==0.7.1
92 | prompt-toolkit==3.0.4
93 | ptyprocess==0.6.0
94 | py==1.8.1
95 | pycocotools==2.0
96 | pycparser==2.20
97 | Pygments==2.6.1
98 | pyOpenSSL==19.1.0
99 | pyparsing==2.4.6
100 | pyrsistent==0.15.7
101 | PySocks==1.7.1
102 | pytest==5.4.1
103 | pytest-forked==1.1.3
104 | pytest-xdist==1.31.0
105 | python-dateutil==2.8.1
106 | PyWavelets==1.1.1
107 | PyYAML==5.3
108 | pyzmq==18.1.1
109 | qtconsole==4.7.1
110 | QtPy==1.9.0
111 | regex==2020.1.8
112 | requests==2.23.0
113 | s3transfer==0.3.3
114 | sacremoses==0.0.38
115 | scikit-image==0.16.2
116 | scipy==1.4.1
117 | Send2Trash==1.5.0
118 | sentencepiece==0.1.85
119 | separableconv-cuda==0.0.0
120 | separableconvflow-cuda==0.0.0
121 | simplejson==3.17.0
122 | six==1.14.0
123 | smart-open==1.9.0
124 | spacy==2.2.4
125 | srsly==1.0.2
126 | terminado==0.8.3
127 | terminaltables==3.1.0
128 | testpath==0.4.4
129 | thinc==7.4.0
130 | tokenizers==0.5.2
131 | torch==1.2.0
132 | torchfile==0.1.0
133 | torchtext==0.5.0
134 | torchvision==0.4.0a0+6b959ee
135 | tornado==6.0.4
136 | tqdm==4.42.1
137 | traitlets==4.3.3
138 | -e git+https://github.com/huggingface/transformers@1789c7daf1b8013006b0aef6cb1b8f80573031c5#egg=transformers
139 | urllib3==1.25.8
140 | visdom==0.1.8.9
141 | wasabi==0.6.0
142 | wcwidth==0.1.8
143 | webencodings==0.5.1
144 | websocket-client==0.57.0
145 | widgetsnbextension==3.5.1
146 | zipp==3.1.0
147 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # SeDraw
2 | import argparse
3 | import os
4 | import torch
5 | import cv2
6 | import torchvision.transforms as transforms
7 | from skimage.measure import compare_psnr
8 | from PIL import Image
9 | import src.pure_network as layers
10 | from tqdm import tqdm
11 | import numpy as np
12 | import math
13 | import models.bdcn.bdcn as bdcn
14 |
15 | # For parsing commandline arguments
16 | def str2bool(v):
17 | if isinstance(v, bool):
18 | return v
19 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
20 | return True
21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
22 | return False
23 | else:
24 | raise argparse.ArgumentTypeError('Boolean value expected.')
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model')
28 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3')
29 | parser.add_argument('--bdcn_model', type=str, default='/home/visiting/Projects/citrine/SeDraw/models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth')
30 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.')
31 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint')
32 | parser.add_argument('--imgpath', type=str, required=True)
33 | parser.add_argument('--first', type=str, required=True)
34 | parser.add_argument('--second', type=str, required=True)
35 | parser.add_argument('--gt', type=str, required=True)
36 | args = parser.parse_args()
37 |
38 |
39 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0):
40 |
41 |
42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
43 | with open(path, 'rb') as f:
44 | img = Image.open(f)
45 | # Resize image if specified.
46 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img
47 | # Crop image if crop area specified.
48 |             cropped_img = resized_img.crop(cropArea) if (cropArea != None) else resized_img
49 | # Flip image horizontally if specified.
50 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img
51 | return flipped_img.convert('RGB')
52 |
53 |
54 |
55 | bdcn = bdcn.BDCN()
56 | bdcn.cuda()
57 | structure_gen = layers.StructureGen(feature_level=args.feature_level)
58 | structure_gen.cuda()
59 | detail_enhance = layers.DetailEnhance()
60 | detail_enhance.cuda()
61 |
62 |
63 | # Normalization constants that map RGB values to the [-1, 1] range
64 | mean = [0.5, 0.5, 0.5]
65 | std = [0.5, 0.5, 0.5]
66 | normalize = transforms.Normalize(mean=mean,
67 | std=std)
68 | transform = transforms.Compose([transforms.ToTensor(), normalize])
69 |
70 | negmean = [-1 for x in mean]
71 | restd = [2, 2, 2]
72 | revNormalize = transforms.Normalize(mean=negmean, std=restd)
73 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()])
74 |
75 |
76 | def ToImage(frame0, frame1):
77 |
78 | with torch.no_grad():
79 |
80 | img0 = frame0.cuda()
81 | img1 = frame1.cuda()
82 |
83 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
84 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
85 | ref_imgt, edge_ref_imgt = structure_gen((img0_e, img1_e))
86 | imgt = detail_enhance((img0, img1, ref_imgt))
87 | # imgt = detail_enhance((img0, img1, imgt))
88 | imgt = torch.clamp(imgt, max=1., min=-1.)
89 |
90 | return imgt, ref_imgt, edge_ref_imgt, img0_e[:, 3:, :, :].repeat(1, 3, 1, 1), img1_e[:, 3:, :, :].repeat(1, 3, 1, 1)
91 |
92 |
93 | def main():
94 | # initial
95 |
96 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
97 | dict1 = torch.load(args.checkpoint)
98 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
99 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
100 |
101 | bdcn.eval()
102 | structure_gen.eval()
103 | detail_enhance.eval()
104 |
105 | IE = 0
106 | PSNR = 0
107 | count = 0
108 | for folder in tqdm(os.listdir(args.imgpath)):
109 | triple_path = os.path.join(args.imgpath, folder)
110 | if not (os.path.isdir(triple_path)):
111 | continue
112 | X0 = transform(_pil_loader('%s/%s' % (triple_path, args.first))).unsqueeze(0)
113 | X1 = transform(_pil_loader('%s/%s' % (triple_path, args.second))).unsqueeze(0)
114 |
115 | assert (X0.size(2) == X1.size(2))
116 | assert (X0.size(3) == X1.size(3))
117 |
118 | intWidth = X0.size(3)
119 | intHeight = X0.size(2)
120 | channel = X0.size(1)
121 | if not channel == 3:
122 | print('Not RGB image')
123 | continue
124 | count += 1
125 |
126 | # if intWidth != ((intWidth >> 4) << 4):
127 | # intWidth_pad = (((intWidth >> 4) + 1) << 4) # more than necessary
128 | # intPaddingLeft = int((intWidth_pad - intWidth) / 2)
129 | # intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
130 | # else:
131 | # intWidth_pad = intWidth
132 | # intPaddingLeft = 0
133 | # intPaddingRight = 0
134 | #
135 | # if intHeight != ((intHeight >> 4) << 4):
136 | # intHeight_pad = (((intHeight >> 4) + 1) << 4) # more than necessary
137 | # intPaddingTop = int((intHeight_pad - intHeight) / 2)
138 | # intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
139 | # else:
140 | # intHeight_pad = intHeight
141 | # intPaddingTop = 0
142 | # intPaddingBottom = 0
143 | #
144 | # pader = torch.nn.ReflectionPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])
145 |
146 | # first, second = pader(X0), pader(X1)
147 | first, second = X0, X1
148 | imgt, ref_imgt, edge_ref_imgt, edge0, edge1 = ToImage(first, second)
149 |
150 | imgt_np = imgt.squeeze(0).cpu().numpy()#[:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
151 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
152 | ref_imgt_np = ref_imgt.squeeze(
153 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
154 | ref_imgt_png = np.uint8(((ref_imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
155 | edge_ref_imgt_np = edge_ref_imgt.squeeze(
156 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
157 | edge_ref_imgt_png = np.uint8(((edge_ref_imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
158 | edge0_np = edge0.squeeze(
159 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
160 | edge0_png = np.uint8(((edge0_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
161 | edge1_np = edge1.squeeze(
162 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
163 | edge1_png = np.uint8(((edge1_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
164 | if not os.path.isdir(triple_path):
165 | os.system('mkdir -p %s' % triple_path)
166 | cv2.imwrite(triple_path + '/SeDraw.png', imgt_png)
167 | cv2.imwrite(triple_path + '/SeDraw_ref.png', ref_imgt_png)
168 | cv2.imwrite(triple_path + '/SeDraw_edge_ref.png', edge_ref_imgt_png)
169 | cv2.imwrite(triple_path + '/SeDraw_edge0.png', edge0_png)
170 | cv2.imwrite(triple_path + '/SeDraw_edge1.png', edge1_png)
171 |
172 | rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png')))
173 | gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt)))
174 |
175 |         diff_rgb = rec_rgb.astype(np.float64) - gt_rgb.astype(np.float64)
176 | avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2))
177 |
178 | mse = np.mean((diff_rgb) ** 2)
179 |
180 | PIXEL_MAX = 255.0
181 | psnr = compare_psnr(gt_rgb, rec_rgb, 255)
182 | print(folder, psnr)
183 |
184 | IE += avg_interp_error_abs
185 | PSNR += psnr
186 |
187 | # print(triple_path, ': IE/PSNR:', avg_interp_error_abs, psnr)
188 |
189 | IE = IE / count
190 | PSNR = PSNR / count
191 | print('Average IE/PSNR:', IE, PSNR)
192 |
193 | main()
194 |
195 |
196 |
--------------------------------------------------------------------------------
/sequence_run.py:
--------------------------------------------------------------------------------
1 | # SeDraw
2 | import re
3 | import argparse
4 | import os
5 | import torch
6 | import cv2
7 | import torchvision.transforms as transforms
8 | from skimage.measure import compare_psnr
9 | from PIL import Image
10 | import src.pure_network as layers
11 | from tqdm import tqdm
12 | import numpy as np
13 | import math
14 | import models.bdcn.bdcn as bdcn
15 |
16 | # For parsing commandline arguments
17 | def str2bool(v):
18 | if isinstance(v, bool):
19 | return v
20 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
21 | return True
22 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
23 | return False
24 | else:
25 | raise argparse.ArgumentTypeError('Boolean value expected.')
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model')
29 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3')
30 | parser.add_argument('--bdcn_model', type=str, default='./models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth')
31 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.')
32 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint')
33 | parser.add_argument('--video_path', type=str, required=True)
34 | parser.add_argument('--t_interp', type=int, default=4, help='interpolation factor; must be a power of 2')
35 | parser.add_argument('--fps', type=int, default=-1, help='specify the fps.')
36 | parser.add_argument('--slow_motion', action='store_true', help='using this flag if you want to slow down the video and maintain fps.')
37 |
38 | args = parser.parse_args()
39 |
40 |
41 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0):
42 |
43 |
44 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
45 | with open(path, 'rb') as f:
46 | img = Image.open(f)
47 | # Resize image if specified.
48 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img
49 | # Crop image if crop area specified.
50 |             cropped_img = resized_img.crop(cropArea) if (cropArea != None) else resized_img
51 | # Flip image horizontally if specified.
52 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img
53 | return flipped_img.convert('RGB')
54 |
55 |
56 |
57 | bdcn = bdcn.BDCN()
58 | bdcn.cuda()
59 | structure_gen = layers.StructureGen(feature_level=args.feature_level)
60 | structure_gen.cuda()
61 | detail_enhance = layers.DetailEnhance()
62 | detail_enhance.cuda()
63 |
64 |
65 | # Normalization constants that map RGB values to the [-1, 1] range
66 | mean = [0.5, 0.5, 0.5]
67 | std = [0.5, 0.5, 0.5]
68 | normalize = transforms.Normalize(mean=mean,
69 | std=std)
70 | transform = transforms.Compose([transforms.ToTensor(), normalize])
71 |
72 | negmean = [-1 for x in mean]
73 | restd = [2, 2, 2]
74 | revNormalize = transforms.Normalize(mean=negmean, std=restd)
75 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()])
76 |
77 |
78 | def ToImage(frame0, frame1):
79 |
80 | with torch.no_grad():
81 |
82 | img0 = frame0.cuda()
83 | img1 = frame1.cuda()
84 |
85 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
86 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
87 | ref_imgt, _ = structure_gen((img0_e, img1_e))
88 | imgt = detail_enhance((img0, img1, ref_imgt))
89 | # imgt = detail_enhance((img0, img1, imgt))
90 | imgt = torch.clamp(imgt, max=1., min=-1.)
91 |
92 | return imgt
93 | def IndexHelper(i, digit):
94 | index = str(i)
95 | for j in range(digit-len(str(i))):
96 | index = '0'+index
97 | return index
98 |
99 | def VideoToSequence(path, time):
100 | video = cv2.VideoCapture(path)
101 | dir_path = 'frames_tmp'
102 | os.system("rm -rf %s" % dir_path)
103 | os.mkdir(dir_path)
104 | fps = int(video.get(cv2.CAP_PROP_FPS))
105 | length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
106 | print('making ' + str(length) + ' frame sequence in ' + dir_path)
107 | i = -1
108 | while (True):
109 | (grabbed, frame) = video.read()
110 | if not grabbed:
111 | break
112 | i = i + 1
113 | index = IndexHelper(i*time, len(str(time*length)))
114 | cv2.imwrite(dir_path + '/' + index + '.png', frame)
115 | # print(index)
116 | return [dir_path, length, fps]
117 |
118 | def main():
119 | # initial
120 |     iter = math.log(args.t_interp, 2)
121 | if iter%1:
122 |         print('the interpolation factor (--t_interp) must be a power of 2!')
123 | return
124 | iter = int(iter)
125 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
126 | dict1 = torch.load(args.checkpoint)
127 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
128 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
129 |
130 | bdcn.eval()
131 | structure_gen.eval()
132 | detail_enhance.eval()
133 |
134 | IE = 0
135 | PSNR = 0
136 | count = 0
137 | [dir_path, frame_count, fps] = VideoToSequence(args.video_path, args.t_interp)
138 |
139 | for i in range(iter):
140 | print('processing iter' + str(i+1) + ', ' + str((i+1)*frame_count) + ' frames in total')
141 | filenames = os.listdir(dir_path)
142 | filenames.sort()
143 | for i in range(0, len(filenames) - 1):
144 | arguments_strFirst = os.path.join(dir_path, filenames[i])
145 | arguments_strSecond = os.path.join(dir_path, filenames[i + 1])
146 | index1 = int(re.sub("\D", "", filenames[i]))
147 | index2 = int(re.sub("\D", "", filenames[i + 1]))
148 | index = int((index1 + index2) / 2)
149 | arguments_strOut = os.path.join(dir_path,
150 | IndexHelper(index, len(str(args.t_interp * frame_count))) + ".png")
151 |
152 | # print(arguments_strFirst)
153 | # print(arguments_strSecond)
154 | # print(arguments_strOut)
155 |
156 | X0 = transform(_pil_loader(arguments_strFirst)).unsqueeze(0)
157 | X1 = transform(_pil_loader(arguments_strSecond)).unsqueeze(0)
158 |
159 | assert (X0.size(2) == X1.size(2))
160 | assert (X0.size(3) == X1.size(3))
161 |
162 | intWidth = X0.size(3)
163 | intHeight = X0.size(2)
164 | channel = X0.size(1)
165 | if not channel == 3:
166 | print('Not RGB image')
167 | continue
168 | count += 1
169 |
170 | # if intWidth != ((intWidth >> 4) << 4):
171 | # intWidth_pad = (((intWidth >> 4) + 1) << 4) # more than necessary
172 | # intPaddingLeft = int((intWidth_pad - intWidth) / 2)
173 | # intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
174 | # else:
175 | # intWidth_pad = intWidth
176 | # intPaddingLeft = 0
177 | # intPaddingRight = 0
178 | #
179 | # if intHeight != ((intHeight >> 4) << 4):
180 | # intHeight_pad = (((intHeight >> 4) + 1) << 4) # more than necessary
181 | # intPaddingTop = int((intHeight_pad - intHeight) / 2)
182 | # intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
183 | # else:
184 | # intHeight_pad = intHeight
185 | # intPaddingTop = 0
186 | # intPaddingBottom = 0
187 | #
188 | # pader = torch.nn.ReflectionPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])
189 |
190 | # first, second = pader(X0), pader(X1)
191 | first, second = X0, X1
192 | imgt = ToImage(first, second)
193 |
194 | imgt_np = imgt.squeeze(
195 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth]
196 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
197 | cv2.imwrite(arguments_strOut, imgt_png)
198 |
199 | # rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png')))
200 | # gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt)))
201 |
202 | # diff_rgb = rec_rgb - gt_rgb
203 | # avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2))
204 |
205 | # mse = np.mean((diff_rgb) ** 2)
206 |
207 | # PIXEL_MAX = 255.0
208 | # psnr = compare_psnr(gt_rgb, rec_rgb, 255)
209 | # print(folder, psnr)
210 |
211 | # IE += avg_interp_error_abs
212 | # PSNR += psnr
213 |
214 | # print(triple_path, ': IE/PSNR:', avg_interp_error_abs, psnr)
215 |
216 | # IE = IE / count
217 | # PSNR = PSNR / count
218 | # print('Average IE/PSNR:', IE, PSNR)
219 | if args.fps != -1:
220 | output_fps = args.fps
221 | else:
222 | output_fps = fps if args.slow_motion else args.t_interp*fps
223 | os.system("ffmpeg -framerate " + str(output_fps) + " -pattern_type glob -i '" + dir_path + "/*.png' -pix_fmt yuv420p output.mp4")
224 | os.system("rm -rf %s" % dir_path)
225 |
226 |
227 | main()
228 |
229 |
230 |
--------------------------------------------------------------------------------
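An illustrative command for sequence_run.py (paths are placeholders; --t_interp must be a power of two, as checked in main()):

    python sequence_run.py --checkpoint checkpoints/FeFlow.ckpt --video_path input.mp4 --t_interp 4

Frames are dumped to frames_tmp/, each interpolation pass doubles the frame count, and ffmpeg re-encodes the result to output.mp4 at t_interp times the original fps unless --fps or --slow_motion overrides it.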
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__init__.py
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/pure_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/pure_network.cpython-37.pyc
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import random
6 |
7 |
8 | def _make_dataset(dir):
9 | """
10 | Creates a 2D list of all the frames in N clips containing
11 | M frames each.
12 |
13 | 2D List Structure:
14 | [[frame00, frame01,...frameM] <-- clip0
15 | [frame00, frame01,...frameM] <-- clip1
16 | :
17 | [frame00, frame01,...frameM]] <-- clipN
18 |
19 | Parameters
20 | ----------
21 | dir : string
22 | root directory containing clips.
23 |
24 | Returns
25 | -------
26 | list
27 | 2D list described above.
28 | """
29 |
30 |
31 | framesPath = []
32 | # Find and loop over all the clips in root `dir`.
33 | for index, folder in enumerate(os.listdir(dir)):
34 | clipsFolderPath = os.path.join(dir, folder)
35 | # Skip items which are not folders.
36 | if not (os.path.isdir(clipsFolderPath)):
37 | continue
38 | framesPath.append([])
39 | # Find and loop over all the frames inside the clip.
40 | for image in sorted(os.listdir(clipsFolderPath)):
41 | # Add path to list.
42 | framesPath[index].append(os.path.join(clipsFolderPath, image))
43 | return framesPath
44 |
45 | def _make_video_dataset(dir):
46 | """
47 | Creates a 1D list of all the frames.
48 |
49 | 1D List Structure:
50 | [frame0, frame1,...frameN]
51 |
52 | Parameters
53 | ----------
54 | dir : string
55 | root directory containing frames.
56 |
57 | Returns
58 | -------
59 | list
60 | 1D list described above.
61 | """
62 |
63 |
64 | framesPath = []
65 | # Find and loop over all the frames in root `dir`.
66 | for image in sorted(os.listdir(dir)):
67 | # Add path to list.
68 | framesPath.append(os.path.join(dir, image))
69 | return framesPath
70 |
71 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0):
72 | """
73 | Opens image at `path` using pil and applies data augmentation.
74 |
75 | Parameters
76 | ----------
77 | path : string
78 | path of the image.
79 | cropArea : tuple, optional
80 | coordinates for cropping image. Default: None
81 | resizeDim : tuple, optional
82 | dimensions for resizing image. Default: None
83 | frameFlip : int, optional
84 | Non zero to flip image horizontally. Default: 0
85 |
86 | Returns
87 | -------
88 | list
89 | 2D list described above.
90 | """
91 |
92 |
93 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
94 | with open(path, 'rb') as f:
95 | img = Image.open(f)
96 | # Resize image if specified.
97 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img
98 | # Crop image if crop area specified.
99 |         cropped_img = resized_img.crop(cropArea) if (cropArea != None) else resized_img
100 | # Flip image horizontally if specified.
101 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img
102 | return flipped_img.convert('RGB')
103 |
104 |
105 | class FeFlow_vimeo90k(data.Dataset):
106 |
107 | def __init__(self, root, transform=None, dim=(448, 256), randomCropSize=(208, 208), train=True, test=False):
108 |
109 |
110 | # Populate the list with image paths for all the
111 | # frame in `root`.
112 |         # Read both split lists so test/validation paths exist even when train=False.
113 |         train_f = open(os.path.join(root, 'tri_trainlist.txt'), 'r')
114 |         test_f = open(os.path.join(root, 'tri_testlist.txt'), 'r')
115 |         if train:
116 |             self.train_path = train_f.read().split('\n')
117 |             self.train_path = [item for item in self.train_path if item != '']
118 |         test_val_path = test_f.read().split('\n')
119 |         self.test_path = [item for index, item in enumerate(test_val_path) if item != '']
120 |         self.validation_path = [item for index, item in enumerate(test_val_path) if index % 100 == 0 and item != '']
121 |
122 | self.randomCropSize = randomCropSize
123 | self.cropX0 = dim[0] - randomCropSize[0]
124 | self.cropY0 = dim[1] - randomCropSize[1]
125 | self.root = root
126 | self.transform = transform
127 | self.train = train
128 | self.test = test
129 |
130 | def __getitem__(self, index):
131 | """
132 | Returns the sample corresponding to `index` from dataset.
133 |
134 |         The sample consists of two reference frames - I0 and I1 -
135 |         and the middle frame of the Vimeo-90K triplet, returned as
136 |         the list [I0, intermediate_frame, I1].
137 | 
138 |         Parameters
139 |         ----------
140 |             index : int
141 |                 Index
142 | 
143 |         Returns
144 |         -------
145 |             list
146 |                 [I0, intermediate_frame, I1], where each frame has the
147 |                 transform applied if one was specified. During training
148 |                 the triplet order may be reversed and all three frames
149 |                 may be horizontally flipped as data augmentation; for
150 |                 test/validation a fixed crop and order are used.
151 | """
152 |
153 | sample = []
154 |
155 | if (self.train):
156 | ### Data Augmentation ###
157 |
158 | cropX = random.randint(0, self.cropX0)
159 | cropY = random.randint(0, self.cropY0)
160 | cropArea = (cropX, cropY, cropX + self.randomCropSize[0], cropY + self.randomCropSize[1])
161 | # Random reverse frame
162 | if (random.randint(0, 1)):
163 | frameRange = [0, 1, 2]
164 | else:
165 | frameRange = [2, 1, 0]
166 | # Random flip frame
167 | randomFrameFlip = random.randint(0, 1)
168 | else:
169 | # Fixed settings to return same samples every epoch.
170 | # For validation/test sets.
171 | cropArea = (0, 0, self.randomCropSize[0], self.randomCropSize[1])
172 | frameRange = [0, 1, 2]
173 | randomFrameFlip = 0
174 |
175 | # Loop over for all frames corresponding to the `index`.
176 | for frameIndex in frameRange:
177 | # Open image using pil and augment the image.
178 | if self.train:
179 | image = _pil_loader(
180 | self.root + '/sequences/' + self.train_path[index] + '/im{}.png'.format(frameIndex + 1),
181 | cropArea=cropArea, frameFlip=randomFrameFlip)
182 | elif self.test:
183 | image = _pil_loader(
184 | self.root + '/sequences/' + self.test_path[index] + '/im{}.png'.format(frameIndex + 1),
185 | cropArea=cropArea, frameFlip=randomFrameFlip)
186 | else:
187 | image = _pil_loader(
188 | self.root + '/sequences/' + self.validation_path[index] + '/im{}.png'.format(frameIndex + 1),
189 | cropArea=cropArea, frameFlip=randomFrameFlip)
190 | # Apply transformation if specified.
191 | if self.transform is not None:
192 | image = self.transform(image)
193 | sample.append(image)
194 |
195 | return sample
196 |
197 | def __len__(self):
198 | """
199 | Returns the size of dataset. Invoked as len(datasetObj).
200 |
201 | Returns
202 | -------
203 | int
204 | number of samples.
205 | """
206 | if self.train:
207 | return len(self.train_path)
208 | elif self.test:
209 | return len(self.test_path)
210 | else:
211 | return len(self.validation_path)
212 |
213 | def __repr__(self):
214 | """
215 | Returns printable representation of the dataset object.
216 |
217 | Returns
218 | -------
219 | string
220 | info.
221 | """
222 |
223 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
224 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
225 | fmt_str += ' Root Location: {}\n'.format(self.root)
226 | tmp = ' Transforms (if any): '
227 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
228 | return fmt_str
229 |
230 |
231 | class UCI101Test(data.Dataset):
232 | """
233 | A dataloader for loading N samples arranged in this way:
234 |
235 | |-- clip0
236 | |-- frame00
237 | |-- frame01
238 | |-- frame02
239 | |-- clip1
240 | |-- frame00
241 | |-- frame01
242 | |-- frame02
243 | :
244 | :
245 | |-- clipN
246 | |-- frame00
247 | |-- frame01
248 | |-- frame02
249 |
250 | ...
251 |
252 | Attributes
253 | ----------
254 | framesPath : list
255 | List of frames' path in the dataset.
256 |
257 | Methods
258 | -------
259 | __getitem__(index)
260 | Returns the sample corresponding to `index` from dataset.
261 | __len__()
262 | Returns the size of dataset. Invoked as len(datasetObj).
263 | __repr__()
264 | Returns printable representation of the dataset object.
265 | """
266 |
267 | def __init__(self, root, transform=None):
268 | """
269 | Parameters
270 | ----------
271 | root : string
272 | Root directory path.
273 | transform : callable, optional
274 | A function/transform that takes in
275 | a sample and returns a transformed version.
276 | E.g, ``transforms.RandomCrop`` for images.
277 | """
278 |
279 | # Populate the list with image paths for all the
280 | # frame in `root`.
281 | framesPath = _make_dataset(root)
282 | # Raise error if no images found in root.
283 | if len(framesPath) == 0:
284 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n"))
285 |
286 | self.root = root
287 | self.framesPath = framesPath
288 | self.transform = transform
289 |
290 | def __getitem__(self, index):
291 | """
292 | Returns the sample corresponding to `index` from dataset.
293 |
294 | The sample consists of two reference frames - I0 and I1 -
295 |         and an intermediate frame between I0 and I1.
296 |
297 | Parameters
298 | ----------
299 | index : int
300 | Index
301 |
302 | Returns
303 | -------
304 | tuple
305 | (sample, returnIndex) where sample is
306 | [I0, intermediate_frame, I1] and returnIndex is
307 | the position of `intermediate_frame`.
308 | The returnIndex is always 3 and is being returned
309 | to maintain compatibility with the `FeFlow`
310 | dataloader where 3 corresponds to the middle frame.
311 | """
312 |
313 | sample = []
314 | # Loop over for all frames corresponding to the `index`.
315 | for framePath in self.framesPath[index]:
316 | # Open image using pil.
317 | image = _pil_loader(framePath)
318 | # Apply transformation if specified.
319 | if self.transform is not None:
320 | image = self.transform(image)
321 | sample.append(image)
322 | return sample, 3
323 |
324 | def __len__(self):
325 | """
326 | Returns the size of dataset. Invoked as len(datasetObj).
327 |
328 | Returns
329 | -------
330 | int
331 | number of samples.
332 | """
333 |
334 | return len(self.framesPath)
335 |
336 | def __repr__(self):
337 | """
338 | Returns printable representation of the dataset object.
339 |
340 | Returns
341 | -------
342 | string
343 | info.
344 | """
345 |
346 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
347 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
348 | fmt_str += ' Root Location: {}\n'.format(self.root)
349 | tmp = ' Transforms (if any): '
350 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
351 | return fmt_str
352 |
353 |
354 | class FeFlow(data.Dataset):
355 | """
356 | A dataloader for loading N samples arranged in this way:
357 |
358 | |-- clip0
359 | |-- frame00
360 | |-- frame01
361 | :
362 | |-- frame11
363 | |-- frame12
364 | |-- clip1
365 | |-- frame00
366 | |-- frame01
367 | :
368 | |-- frame11
369 | |-- frame12
370 | :
371 | :
372 | |-- clipN
373 | |-- frame00
374 | |-- frame01
375 | :
376 | |-- frame11
377 | |-- frame12
378 |
379 | ...
380 |
381 | Attributes
382 | ----------
383 | framesPath : list
384 | List of frames' path in the dataset.
385 |
386 | Methods
387 | -------
388 | __getitem__(index)
389 | Returns the sample corresponding to `index` from dataset.
390 | __len__()
391 | Returns the size of dataset. Invoked as len(datasetObj).
392 | __repr__()
393 | Returns printable representation of the dataset object.
394 | """
395 |
396 |
397 | def __init__(self, root, transform=None, dim=(448, 256), randomCropSize=(208, 208), train=True):
398 | """
399 | Parameters
400 | ----------
401 | root : string
402 | Root directory path.
403 | transform : callable, optional
404 | A function/transform that takes in
405 | a sample and returns a transformed version.
406 | E.g, ``transforms.RandomCrop`` for images.
407 |             dim : tuple, optional
408 |                 Dimensions of images in dataset. Default: (448, 256)
409 |             randomCropSize : tuple, optional
410 |                 Dimensions of random crop to be applied. Default: (208, 208)
411 | train : boolean, optional
412 | Specifies if the dataset is for training or testing/validation.
413 | `True` returns samples with data augmentation like random
414 | flipping, random cropping, etc. while `False` returns the
415 | samples without randomization. Default: True
416 | """
417 |
418 |
419 | # Populate the list with image paths for all the
420 | # frame in `root`.
421 | framesPath = _make_dataset(root)
422 | # Raise error if no images found in root.
423 | if len(framesPath) == 0:
424 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"))
425 |
426 | self.randomCropSize = randomCropSize
427 | self.cropX0 = dim[0] - randomCropSize[0]
428 | self.cropY0 = dim[1] - randomCropSize[1]
429 | self.root = root
430 | self.transform = transform
431 | self.train = train
432 |
433 | self.framesPath = framesPath
434 |
435 | def __getitem__(self, index):
436 | """
437 | Returns the sample corresponding to `index` from dataset.
438 |
439 | The sample consists of two reference frames - I0 and I1 -
440 | and a random frame chosen from the 7 intermediate frames
441 |         available between I0 and I1 along with its relative index.
442 |
443 | Parameters
444 | ----------
445 | index : int
446 | Index
447 |
448 | Returns
449 | -------
450 | tuple
451 | (sample, returnIndex) where sample is
452 | [I0, intermediate_frame, I1] and returnIndex is
453 | the position of `random_intermediate_frame`.
454 | e.g.- `returnIndex` of frame next to I0 would be 0 and
455 | frame before I1 would be 6.
456 | """
457 |
458 |
459 | sample = []
460 |
461 | if (self.train):
462 | ### Data Augmentation ###
463 | # To select random 9 frames from 12 frames in a clip
464 |             # This loader uses 3-frame clips; always start at the first frame.
465 |             firstFrame = 0
466 |             # Apply random crop on the 3 input frames
467 | cropY = random.randint(0, self.cropY0)
468 | cropArea = (cropX, cropY, cropX + self.randomCropSize[0], cropY + self.randomCropSize[1])
469 | # Random reverse frame
470 | #frameRange = range(firstFrame, firstFrame + 9) if (random.randint(0, 1)) else range(firstFrame + 8, firstFrame - 1, -1)
471 | # IFrameIndex = random.randint(firstFrame + 1, firstFrame + 7)
472 | IFrameIndex = 1
473 | if (random.randint(0, 1)):
474 | frameRange = [firstFrame, IFrameIndex, firstFrame + 2]
475 | returnIndex = 1
476 | else:
477 | frameRange = [firstFrame + 2, IFrameIndex, firstFrame]
478 | returnIndex = 1
479 | # Random flip frame
480 | randomFrameFlip = random.randint(0, 1)
481 | else:
482 | # Fixed settings to return same samples every epoch.
483 | # For validation/test sets.
484 | firstFrame = 0
485 | cropArea = (0, 0, self.randomCropSize[0], self.randomCropSize[1])
486 | IFrameIndex = 1
487 | returnIndex = 1
488 | frameRange = [0, IFrameIndex, 2]
489 | randomFrameFlip = 0
490 |
491 | # Loop over for all frames corresponding to the `index`.
492 | for frameIndex in frameRange:
493 | # Open image using pil and augment the image.
494 | image = _pil_loader(self.framesPath[index][frameIndex], cropArea=cropArea, frameFlip=randomFrameFlip)
495 | # Apply transformation if specified.
496 | if self.transform is not None:
497 | image = self.transform(image)
498 | sample.append(image)
499 |
500 | return sample, returnIndex
501 |
502 |
503 | def __len__(self):
504 | """
505 | Returns the size of dataset. Invoked as len(datasetObj).
506 |
507 | Returns
508 | -------
509 | int
510 | number of samples.
511 | """
512 |
513 |
514 | return len(self.framesPath)
515 |
516 | def __repr__(self):
517 | """
518 | Returns printable representation of the dataset object.
519 |
520 | Returns
521 | -------
522 | string
523 | info.
524 | """
525 |
526 |
527 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
528 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
529 | fmt_str += ' Root Location: {}\n'.format(self.root)
530 | tmp = ' Transforms (if any): '
531 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
532 | return fmt_str
533 |
534 |
--------------------------------------------------------------------------------
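A minimal sketch of wrapping FeFlow_vimeo90k in a PyTorch DataLoader (the dataset root is a placeholder; it must contain sequences/ plus tri_trainlist.txt and tri_testlist.txt, and the normalisation mirrors the run scripts above):

    import torch
    import torchvision.transforms as transforms
    from src.dataloader import FeFlow_vimeo90k

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = FeFlow_vimeo90k(root='/path/to/vimeo_triplet', transform=transform, train=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4)

    # each sample is the list [I0, It, I1]; default collation batches the three positions
    for frame0, frameT, frame1 in trainloader:
        print(frame0.shape, frameT.shape, frame1.shape)
        break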
/src/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Converts a Video to SeDraw version
3 | """
4 | from time import time
5 | import click
6 | import cv2
7 | import torch
8 | from PIL import Image
9 | import numpy as np
10 | import src.model as model
11 | import src.layers as layers
12 | from torchvision import transforms
13 | from torch.functional import F
14 |
15 |
16 | torch.set_grad_enabled(False)
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 | trans_forward = transforms.ToTensor()
20 | trans_backward = transforms.ToPILImage()
21 | if device != "cpu":
22 | mean = [0.429, 0.431, 0.397]
23 | mea0 = [-m for m in mean]
24 | std = [1] * 3
25 | trans_forward = transforms.Compose([trans_forward, transforms.Normalize(mean=mean, std=std)])
26 | trans_backward = transforms.Compose([transforms.Normalize(mean=mea0, std=std), trans_backward])
27 |
28 | network = layers.Network()
29 | network = torch.nn.DataParallel(network)
30 |
31 |
32 |
33 |
34 | def load_models(checkpoint):
35 | states = torch.load(checkpoint, map_location='cpu')
36 | network.module.load_state_dict(states['state_dictNET'])
37 |
38 |
39 | def interpolate_batch(frames, factor):
40 | frame0 = torch.stack(frames[:-1])
41 | frame1 = torch.stack(frames[1:])
42 |
43 | img0 = frame0.to(device)
44 | img1 = frame1.to(device)
45 |
46 | frame_buffer = []
47 | for i in range(factor - 1):
48 | # start = time()
49 |         frame_index = torch.ones(len(frames) - 1).type(torch.long) * i
50 | output = network((img0, img1, frame_index, None))
51 | ft_p = output[3]
52 |
53 | frame_buffer.append(ft_p)
54 | # print('time:', time() - start)
55 |
56 | return frame_buffer
57 |
58 |
59 | def load_batch(video_in, batch_size, batch, w, h):
60 | if len(batch) > 0:
61 | batch = [batch[-1]]
62 |
63 | for i in range(batch_size):
64 | ok, frame = video_in.read()
65 | if not ok:
66 | break
67 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
68 | frame = Image.fromarray(frame)
69 | frame = frame.resize((w, h), Image.ANTIALIAS)
70 | frame = frame.convert('RGB')
71 | frame = trans_forward(frame)
72 | batch.append(frame)
73 |
74 | return batch
75 |
76 |
77 | def denorm_frame(frame, w0, h0):
78 | frame = frame.cpu()
79 | frame = trans_backward(frame)
80 | frame = frame.resize((w0, h0), Image.BILINEAR)
81 | frame = frame.convert('RGB')
82 | return np.array(frame)[:, :, ::-1].copy()
83 |
84 |
85 | def convert_video(source, dest, factor, batch_size=10, output_format='mp4v', output_fps=30):
86 | vin = cv2.VideoCapture(source)
87 | count = vin.get(cv2.CAP_PROP_FRAME_COUNT)
88 | w0, h0 = int(vin.get(cv2.CAP_PROP_FRAME_WIDTH)), int(vin.get(cv2.CAP_PROP_FRAME_HEIGHT))
89 |
90 | codec = cv2.VideoWriter_fourcc(*output_format)
91 | vout = cv2.VideoWriter(dest, codec, float(output_fps), (w0, h0))
92 |
93 | w, h = (w0 // 32) * 32, (h0 // 32) * 32
94 | network.module.setup('custom', h, w)
95 |
96 | network.module.setup_t(factor)
97 | network.cuda()
98 |
99 | done = 0
100 | batch = []
101 | while True:
102 | batch = load_batch(vin, batch_size, batch, w, h)
103 | if len(batch) == 1:
104 | break
105 | done += len(batch) - 1
106 |
107 | intermediate_frames = interpolate_batch(batch, factor)
108 | intermediate_frames = list(zip(*intermediate_frames))
109 |
110 | for fid, iframe in enumerate(intermediate_frames):
111 | vout.write(denorm_frame(batch[fid], w0, h0))
112 | for frm in iframe:
113 | vout.write(denorm_frame(frm, w0, h0))
114 |
115 | try:
116 | yield len(batch), done, count
117 | except StopIteration:
118 | break
119 |
120 | vout.write(denorm_frame(batch[0], w0, h0))
121 |
122 | vin.release()
123 | vout.release()
124 |
125 |
126 | @click.command(help='Evaluate the model by converting a low-FPS video to a high-FPS one')
127 | @click.argument('input')
128 | @click.option('--checkpoint', help='Path to model checkpoint')
129 | @click.option('--output', help='Path to output file to save')
130 | @click.option('--batch', default=4, help='Number of frames to process in single forward pass')
131 | @click.option('--scale', default=4, help='Scale Factor of FPS')
132 | @click.option('--fps', default=30, help='FPS of output video')
133 | def main(input, checkpoint, output, batch, scale, fps):
134 | avg = lambda x, n, x0: (x * n/(n+1) + x0 / (n+1), n+1)
135 | load_models(checkpoint)
136 | t0 = time()
137 | n0 = 0
138 | fpx = 0
139 | for dl, fd, fc in convert_video(input, output, int(scale), int(batch), output_fps=int(fps)):
140 | fpx, n0 = avg(fpx, n0, dl / (time() - t0))
141 | prg = int(100*fd/fc)
142 | eta = (fc - fd) / fpx
143 | print('\rDone: {:03d}% FPS: {:05.2f} ETA: {:.2f}s'.format(prg, fpx, eta) + ' '*5, end='')
144 | t0 = time()
145 |
146 |
147 | if __name__ == '__main__':
148 | main()
149 |
150 |
151 |
--------------------------------------------------------------------------------
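A minimal sketch of driving eval.py programmatically instead of through its click interface; the checkpoint and video paths below are placeholders, and a CUDA device is assumed since the script calls network.cuda():

    from src.eval import load_models, convert_video

    load_models('./checkpoints/SeDraw.ckpt')   # placeholder checkpoint path
    for batch_len, done, total in convert_video('input.mp4', 'output.mp4',
                                                factor=4, batch_size=8, output_fps=120):
        print('processed %d / %d source frames' % (done, total))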
/src/layers.py:
--------------------------------------------------------------------------------
1 | import src.model as model
2 | import src.loss as loss
3 | import torch.nn as nn
4 | import torch
5 | from math import log10
6 | from models.ResBlock import GlobalGenerator
7 |
8 |
9 | class StructureGen(nn.Module):
10 |
11 | def __init__(self, feature_level=3):
12 | super(StructureGen, self).__init__()
13 |
14 | self.feature_level = feature_level
15 | channel = 2 ** (6 + self.feature_level)
16 | # self.structure_extractor = model.StructureExtractor()
17 | self.extract_features = model.ExtractFeatures()
18 | self.dcn = model.DeformableConv(channel, dg=16)
19 | self.generator = GlobalGenerator(channel, 4, n_downsampling=self.feature_level)
20 |
21 | # Loss calculate
22 | self.L1_lossFn = nn.L1Loss()
23 | self.sL1_lossFn = nn.SmoothL1Loss()
24 | self.cL1_lossFn = loss.CharbonnierLoss()
25 | self.MSE_LossFn = nn.MSELoss()
26 |
27 | def forward(self, input):
28 | img0_e, img1_e, IFrame_e = input
29 |
30 | ft_img0 = list(self.extract_features(img0_e))[self.feature_level]
31 | ft_img1 = list(self.extract_features(img1_e))[self.feature_level]
32 |
33 | pre_gen_ft_imgt, out_x, out_y = self.dcn(ft_img0, ft_img1)
34 |
35 | ref_imgt_e = self.generator(pre_gen_ft_imgt)
36 |
37 | # divide edge
38 | IFrame, edge_IFrame = IFrame_e[:, :3], IFrame_e[:, 3:]
39 | ref_imgt, edge_ref_imgt = ref_imgt_e[:, :3], ref_imgt_e[:, 3:]
40 |
41 | # extract for loss
42 | ft_IFrame = list(self.extract_features(IFrame_e))[self.feature_level]
43 | ft_ref_imgt = list(self.extract_features(ref_imgt_e))[self.feature_level]
44 | # st_IFrame = self.structure_extractor(IFrame)
45 | # st_ref_imgt = self.structure_extractor(ref_imgt)
46 |
47 | # Loss calculate
48 |
49 | feature_mix_loss = 500 * self.cL1_lossFn(pre_gen_ft_imgt, ft_IFrame)
50 |
51 | tri_loss = 20 * (self.cL1_lossFn(out_x, ft_IFrame) + self.L1_lossFn(out_y, ft_IFrame))
52 |
53 | # feature_gen_loss = 10 * self.MSE_LossFn(ft_ref_imgt, ft_IFrame)
54 |
55 | # structure_loss = 20 * (self.L1_lossFn(st_ref_imgt[0], st_IFrame[0]) + self.L1_lossFn(st_ref_imgt[1], st_IFrame[1]) +
56 | # self.L1_lossFn(st_ref_imgt[2], st_IFrame[2]) + self.L1_lossFn(st_ref_imgt[3], st_IFrame[3]))
57 |
58 | edge_loss = 5 * self.MSE_LossFn(edge_ref_imgt, edge_IFrame)
59 |
60 | gen_loss = 128 * self.cL1_lossFn(ref_imgt, IFrame)
61 |
62 | loss = gen_loss + feature_mix_loss + edge_loss + tri_loss# + structure_loss + feature_gen_loss
63 |
64 | MSE_val = self.MSE_LossFn(ref_imgt, IFrame)
65 |
66 | print('Loss:', loss.item(), 'feature_mix_loss:', feature_mix_loss.item(),
67 | 'gen_loss:', gen_loss.item(),
68 | # 'feature_gen_loss:', feature_gen_loss.item(),
69 | # 'structure_loss:', structure_loss.item(),
70 | 'edge_loss:', edge_loss.item(),
71 | 'tri_loss:', tri_loss.item())
72 |
73 | return loss, MSE_val, ref_imgt
74 |
75 |
76 |
77 | class DetailEnhance(nn.Module):
78 |
79 |
80 | def __init__(self):
81 |
82 | super(DetailEnhance, self).__init__()
83 |
84 | self.feature_level = 3
85 |
86 | self.extract_features = model.ValidationFeatures()
87 | self.extract_aligned_features = model.ExtractAlignedFeatures(n_res=5) # 4 5
88 | self.pcd_align = model.PCD_Align(groups=8) # 4 8
89 | self.tsa_fusion = model.TSA_Fusion(nframes=3, center=1)
90 |
91 | self.reconstruct = model.Reconstruct(n_res=20) # 5 40
92 |
93 | # Loss calculate
94 | self.L1_lossFn = nn.L1Loss()
95 | self.sL1_lossFn = nn.SmoothL1Loss()
96 | self.cL1_lossFn = loss.CharbonnierLoss()
97 | self.MSE_LossFn = nn.MSELoss()
98 |
99 | def forward(self, input):
100 | """
101 |         Forward pass of the network.
102 |
103 |         :param input: a tuple of inputs that will be unpacked
104 |         :return: the intermediate interpolated frame
105 | """
106 | img0, img1, IFrame, ref_imgt = input
107 |
108 | ref_align_ft = self.extract_aligned_features(ref_imgt)
109 | align_ft_0 = self.extract_aligned_features(img0)
110 | align_ft_1 = self.extract_aligned_features(img1)
111 |
112 | align_ft = [self.pcd_align(align_ft_0, ref_align_ft),
113 | self.pcd_align(ref_align_ft, ref_align_ft),
114 | self.pcd_align(align_ft_1, ref_align_ft)]
115 | align_ft = torch.stack(align_ft, dim=1)
116 |
117 | tsa_ft = self.tsa_fusion(align_ft)
118 |
119 | imgt = self.reconstruct(tsa_ft, ref_imgt)
120 |
121 |
122 | # extract for loss
123 | ft_IFrame = list(self.extract_features(IFrame))[self.feature_level]
124 | ft_imgt = list(self.extract_features(imgt))[self.feature_level]
125 |
126 |
127 | """
128 | ----------------------------------------------------------------------
129 | ======================================================================
130 | """
131 |
132 | # Loss calculate
133 | # feature_recn_loss = 10 * self.MSE_LossFn(ft_imgt, ft_IFrame)
134 |
135 | recn_loss = 128 * self.cL1_lossFn(imgt, IFrame)
136 | ie = 128 * self.L1_lossFn(imgt, IFrame)
137 |
138 | loss = recn_loss #+ feature_recn_loss
139 |
140 | MSE_val = self.MSE_LossFn(imgt, IFrame)
141 | psnr = (10 * log10(4 / MSE_val.item()))
142 |
143 | print('Loss:', loss.item(), 'psnr:', psnr,
144 | 'recn_loss:', recn_loss.item(),
145 | # 'feature_recn_loss:', feature_recn_loss.item()
146 | )
147 |
148 | return loss, MSE_val, ie, imgt
149 |
150 | class DetailEnhance_last(nn.Module):
151 |
152 |
153 | def __init__(self):
154 |
155 | super(DetailEnhance_last, self).__init__()
156 |
157 | self.feature_level = 3
158 |
159 | self.extract_features = model.ValidationFeatures()
160 | self.extract_aligned_features = model.ExtractAlignedFeatures(n_res=5) # 4 5
161 | self.pcd_align = model.PCD_Align(groups=8) # 4 8
162 | self.tsa_fusion = model.TSA_Fusion(nframes=3, center=1)
163 |
164 | self.reconstruct = model.Reconstruct(n_res=20) # 5 40
165 |
166 | # Loss calculate
167 | self.L1_lossFn = nn.L1Loss()
168 | self.sL1_lossFn = nn.SmoothL1Loss()
169 | self.cL1_lossFn = loss.CharbonnierLoss()
170 | self.MSE_LossFn = nn.MSELoss()
171 |
172 | def forward(self, input):
173 | """
174 |         Forward pass of the network.
175 |
176 |         :param input: a tuple of inputs that will be unpacked
177 |         :return: the intermediate interpolated frame
178 | """
179 | img0, img1, IFrame, ref_imgt = input
180 |
181 | ref_align_ft = self.extract_aligned_features(ref_imgt)
182 | align_ft_0 = self.extract_aligned_features(img0)
183 | align_ft_1 = self.extract_aligned_features(img1)
184 |
185 | align_ft = [self.pcd_align(align_ft_0, ref_align_ft),
186 | self.pcd_align(ref_align_ft, ref_align_ft),
187 | self.pcd_align(align_ft_1, ref_align_ft)]
188 | align_ft = torch.stack(align_ft, dim=1)
189 |
190 | tsa_ft = self.tsa_fusion(align_ft)
191 |
192 | imgt = self.reconstruct(tsa_ft, ref_imgt)
193 |
194 |
195 | # extract for loss
196 | # ft_IFrame = list(self.extract_features(IFrame))[self.feature_level]
197 | # ft_imgt = list(self.extract_features(imgt))[self.feature_level]
198 |
199 |
200 | """
201 | ----------------------------------------------------------------------
202 | ======================================================================
203 | """
204 |
205 | # Loss calculate
206 | # feature_recn_loss = 10 * self.MSE_LossFn(ft_imgt, ft_IFrame)
207 |
208 | recn_loss = 128 * self.L1_lossFn(imgt, IFrame)
209 |
210 | loss = recn_loss# + feature_recn_loss
211 |
212 | MSE_val = self.MSE_LossFn(imgt, IFrame)
213 | psnr = (10 * log10(4 / MSE_val.item()))
214 |
215 | print('Loss:', loss.item(), 'psnr:', psnr,
216 | 'recn_loss:', recn_loss.item())
217 | # 'feature_recn_loss:', feature_recn_loss.item())
218 |
219 | return loss, MSE_val, imgt
220 |
--------------------------------------------------------------------------------
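A note on the PSNR computed in layers.py: the frames are normalized to [-1, 1], so the peak signal spans a range of 2 and MAX squared is 4, which gives PSNR = 10 * log10(4 / MSE). A quick sanity check of that arithmetic:

    import math

    mse = 1e-3
    psnr = 10 * math.log10(4 / mse)   # MAX**2 == 4 because pixel values live in [-1, 1]
    print(round(psnr, 2))             # 36.02 dB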
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CharbonnierLoss(nn.Module):
6 | """Charbonnier Loss (L1)"""
7 |
8 | def __init__(self, eps=1e-6):
9 | super(CharbonnierLoss, self).__init__()
10 | self.eps = eps
11 |
12 | def forward(self, x, y):
13 | diff = x - y
14 | loss = torch.mean(torch.sqrt(diff * diff + self.eps))
15 | return loss
16 |
--------------------------------------------------------------------------------
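The Charbonnier loss above is a smooth approximation of the L1 loss, L(x, y) = mean(sqrt((x - y)^2 + eps)); it behaves like L1 for large errors but stays differentiable at zero. A tiny sketch comparing it to the plain mean absolute error on random tensors:

    import torch
    from src.loss import CharbonnierLoss

    criterion = CharbonnierLoss(eps=1e-6)
    pred = torch.randn(2, 3, 64, 64)
    target = torch.randn(2, 3, 64, 64)
    print(criterion(pred, target).item())               # Charbonnier loss
    print(torch.mean(torch.abs(pred - target)).item())  # nearly identical L1 value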
/src/pure_network.py:
--------------------------------------------------------------------------------
1 | import src.model as model
2 | import torch.nn as nn
3 | import torch
4 | from models.ResBlock import GlobalGenerator
5 |
6 |
7 | class StructureGen(nn.Module):
8 |
9 | def __init__(self, feature_level=3):
10 | super(StructureGen, self).__init__()
11 |
12 | self.feature_level = feature_level
13 | channel = 2 ** (6 + self.feature_level)
14 | self.extract_features = model.ExtractFeatures()
15 | self.dcn = model.DeformableConv(channel, dg=16)
16 | self.generator = GlobalGenerator(channel, 4, n_downsampling=self.feature_level)
17 |
18 | def forward(self, input):
19 | img0_e, img1_e= input
20 |
21 | ft_img0 = list(self.extract_features(img0_e))[self.feature_level]
22 | ft_img1 = list(self.extract_features(img1_e))[self.feature_level]
23 |
24 | pre_gen_ft_imgt, out_x, out_y = self.dcn(ft_img0, ft_img1)
25 |
26 | ref_imgt_e = self.generator(pre_gen_ft_imgt)
27 |
28 | # divide edge
29 | ref_imgt, edge_ref_imgt = ref_imgt_e[:, :3], ref_imgt_e[:, 3:]
30 |
31 | return ref_imgt, edge_ref_imgt
32 |
33 |
34 |
35 | class DetailEnhance(nn.Module):
36 |
37 |
38 | def __init__(self):
39 |
40 | super(DetailEnhance, self).__init__()
41 |
42 | self.feature_level = 3
43 |
44 | self.extract_features = model.ValidationFeatures()
45 | self.extract_aligned_features = model.ExtractAlignedFeatures(n_res=5) # 4 5
46 | self.pcd_align = model.PCD_Align(groups=8) # 4 8
47 | self.tsa_fusion = model.TSA_Fusion(nframes=3, center=1)
48 |
49 | self.reconstruct = model.Reconstruct(n_res=20) # 5 40
50 |
51 | def forward(self, input):
52 | """
53 |         Forward pass of the network.
54 |
55 |         :param input: a tuple of inputs that will be unpacked
56 |         :return: the intermediate interpolated frame
57 | """
58 | img0, img1, ref_imgt = input
59 |
60 | ref_align_ft = self.extract_aligned_features(ref_imgt)
61 | align_ft_0 = self.extract_aligned_features(img0)
62 | align_ft_1 = self.extract_aligned_features(img1)
63 |
64 | align_ft = [self.pcd_align(align_ft_0, ref_align_ft),
65 | self.pcd_align(ref_align_ft, ref_align_ft),
66 | self.pcd_align(align_ft_1, ref_align_ft)]
67 | align_ft = torch.stack(align_ft, dim=1)
68 |
69 | tsa_ft = self.tsa_fusion(align_ft)
70 |
71 | imgt = self.reconstruct(tsa_ft, ref_imgt)
72 |
73 | return imgt
74 |
--------------------------------------------------------------------------------
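A rough inference sketch of the two-stage pipeline above, mirroring how video_process.py uses it: StructureGen takes 4-channel inputs (RGB plus one edge channel, e.g. from BDCN) and produces a coarse middle frame that DetailEnhance then refines. The tensor shapes are illustrative, the weights are untrained here, and the deformable convolutions rely on the compiled DCNv2 extension, so a CUDA device is assumed:

    import torch
    import src.pure_network as layers

    structure_gen = layers.StructureGen(feature_level=3).cuda().eval()
    detail_enhance = layers.DetailEnhance().cuda().eval()

    img0 = torch.randn(1, 3, 256, 448).cuda()    # previous frame
    img1 = torch.randn(1, 3, 256, 448).cuda()    # next frame
    edge0 = torch.randn(1, 1, 256, 448).cuda()   # edge map of img0
    edge1 = torch.randn(1, 1, 256, 448).cuda()   # edge map of img1

    with torch.no_grad():
        ref_imgt, _ = structure_gen((torch.cat([img0, edge0], dim=1),
                                     torch.cat([img1, edge1], dim=1)))
        imgt = detail_enhance((img0, img1, ref_imgt))
    print(imgt.shape)   # coarse-to-fine interpolated middle frame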
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | from torch.nn.functional import interpolate
2 |
3 |
4 | def feature_transform(img, n_upsample):
5 | """
6 |     Transform a feature-map-like tensor so it can be visualized.
7 |     :param img: tensor of shape [B, C, H, W]
8 |     :return: visualization image with values in [0, 1]
9 | """
10 | img = img[0, 1:2].repeat(3, 1, 1)
11 | img = interpolate(((img - img.min()) / (img.max() - img.min())).unsqueeze(0), scale_factor=n_upsample,
12 | align_corners=False, mode='bilinear')
13 |
14 | return img
15 |
16 | def edge_transform(img):
17 | """
18 |     Transform an edge-map-like tensor so it can be visualized.
19 |     :param img: tensor of shape [B, C, H, W]
20 |     :return: visualization image with values in [0, 1]
21 | """
22 | img = img[0].repeat(3, 1, 1)
23 | img = ((img - img.min()) / (img.max() - img.min())).unsqueeze(0)
24 |
25 | return img
26 |
--------------------------------------------------------------------------------
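A short sketch of how these helpers can be used, for example to log intermediate feature maps and edge maps during training; the shapes below are only illustrative:

    import torch
    from utils.visualize import feature_transform, edge_transform

    feat = torch.randn(1, 128, 32, 56)                 # a [B, C, H, W] feature map
    feat_vis = feature_transform(feat, n_upsample=8)   # -> [1, 3, 256, 448], values in [0, 1]

    edge = torch.rand(1, 1, 256, 448)                  # a single-channel edge map
    edge_vis = edge_transform(edge)                    # -> [1, 3, 256, 448], values in [0, 1]
    print(feat_vis.shape, edge_vis.shape)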
/video_process.py:
--------------------------------------------------------------------------------
1 | # SeDraw
2 | import argparse
3 | import os
4 | import torch
5 | import cv2
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 | import src.pure_network as layers
9 | from tqdm import tqdm
10 | import numpy as np
11 | import math
12 | import models.bdcn.bdcn as bdcn
13 |
14 | # For parsing commandline arguments
15 | def str2bool(v):
16 | if isinstance(v, bool):
17 | return v
18 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
19 | return True
20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
21 | return False
22 | else:
23 | raise argparse.ArgumentTypeError('Boolean value expected.')
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model')
27 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3')
28 | parser.add_argument('--bdcn_model', type=str, default='./models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth')
29 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.')
30 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint')
31 | parser.add_argument('--video_name', type=str, required=True, help='path to the video.')
32 | parser.add_argument('--batchsize', type=int, default=4)
33 | parser.add_argument('--fix_range', action='store_true', help="double the output FPS so the duration is preserved; without this flag the output is written at a fixed 25 FPS.")
34 | args = parser.parse_args()
35 |
36 |
37 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0):
38 |
39 |
40 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
41 | with open(path, 'rb') as f:
42 | img = Image.open(f)
43 | # Resize image if specified.
44 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img
45 | # Crop image if crop area specified.
46 |             cropped_img = resized_img.crop(cropArea) if (cropArea != None) else resized_img
47 | # Flip image horizontally if specified.
48 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img
49 | return flipped_img.convert('RGB')
50 |
51 |
52 |
53 | bdcn = bdcn.BDCN()
54 | bdcn.cuda()
55 | structure_gen = layers.StructureGen(feature_level=args.feature_level)
56 | structure_gen.cuda()
57 | detail_enhance = layers.DetailEnhance()
58 | detail_enhance.cuda()
59 |
60 |
61 | # Normalization constants: frames are mapped to the [-1, 1] range
62 | mean = [0.5, 0.5, 0.5]
63 | std = [0.5, 0.5, 0.5]
64 | normalize = transforms.Normalize(mean=mean,
65 | std=std)
66 | transform = transforms.Compose([transforms.ToTensor(), normalize])
67 |
68 | negmean = [-1 for x in mean]
69 | restd = [2, 2, 2]
70 | revNormalize = transforms.Normalize(mean=negmean, std=restd)
71 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()])
72 |
73 |
74 | def ToImage(frame0, frame1):
75 |
76 | with torch.no_grad():
77 |
78 | img0 = frame0.cuda()
79 | img1 = frame1.cuda()
80 |
81 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
82 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
83 | ref_imgt, _ = structure_gen((img0_e, img1_e))
84 | imgt = detail_enhance((img0, img1, ref_imgt))
85 | # imgt = detail_enhance((img0, img1, imgt))
86 | imgt = torch.clamp(imgt, max=1., min=-1.)
87 |
88 | return imgt
89 |
90 |
91 | def main():
92 | # initial
93 |
94 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
95 | dict1 = torch.load(args.checkpoint)
96 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
97 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
98 |
99 | bdcn.eval()
100 | structure_gen.eval()
101 | detail_enhance.eval()
102 |
103 | if not os.path.isfile(args.video_name):
104 |         raise FileNotFoundError('video does not exist: %s' % args.video_name)
105 | video = cv2.VideoCapture(args.video_name)
106 | if args.fix_range:
107 | fps = video.get(cv2.CAP_PROP_FPS) * 2
108 | else:
109 | # fps = video.get(cv2.CAP_PROP_FPS)
110 | fps = 25
111 | size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
112 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
113 | # fourcc = int(video.get(cv2.CAP_PROP_FOURCC))
114 | video_writer = cv2.VideoWriter(args.video_name[:-4] + '_Sedraw.mp4',
115 | fourcc,
116 | fps,
117 | size)
118 |
119 | flag = True
120 | frame_group = []
121 | while video.isOpened():
122 | for i in range(args.batchsize):
123 | ret, frame = video.read()
124 | if ret:
125 | frame = torch.FloatTensor(frame[:, :, ::-1].transpose(2, 0, 1).copy()) / 255
126 | frame = normalize(frame).unsqueeze(0)
127 | frame_group += [frame]
128 | else:
129 | break
130 | if len(frame_group) <= 1:
131 | break
132 | first = torch.cat(frame_group[:-1], dim=0)
133 | second = torch.cat(frame_group[1:], dim=0)
134 |
135 | middle_frame = ToImage(first, second)
136 |
137 | if flag:
138 | for i in range(first.shape[0]):
139 | first_np = first[i].cpu().numpy()
140 | first_png = np.uint8(((first_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
141 | middle_frame_np = middle_frame[i].cpu().numpy()
142 | middle_frame_png = np.uint8(((middle_frame_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
143 | video_writer.write(first_png)
144 | video_writer.write(middle_frame_png)
145 | second_np = second[-1].cpu().numpy()
146 | second_png = np.uint8(((second_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
147 | video_writer.write(second_png)
148 | frame_group = [second[-1].unsqueeze(0)]
149 | flag = False
150 | else:
151 | for i in range(second.shape[0]):
152 | middle_frame_np = middle_frame[i].cpu().numpy()
153 | middle_frame_png = np.uint8(((middle_frame_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
154 | second_np = second[i].cpu().numpy()
155 | second_png = np.uint8(((second_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
156 | video_writer.write(middle_frame_png)
157 | video_writer.write(second_png)
158 | frame_group = [second[-1].unsqueeze(0)]
159 |
160 | video_writer.release()
161 |
162 |
163 | if __name__ == '__main__':
164 |     main()
165 |
166 |
--------------------------------------------------------------------------------
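As a usage note, video_process.py is meant to be run directly from the command line. An invocation along the lines of python video_process.py --checkpoint ./checkpoints/SeDraw.ckpt --video_name ./input.mp4 --batchsize 4 --fix_range (paths are placeholders) interpolates a middle frame between every pair of consecutive frames and writes the result next to the input with a _Sedraw.mp4 suffix; with --fix_range the output FPS is doubled so the duration is preserved, otherwise the output is written at 25 FPS.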