├── .gitignore ├── AverageMeter.py ├── LICENSE ├── README.md ├── data ├── Vimeo-90K.py ├── adobe240fps │ ├── test_list.txt │ └── train_list.txt ├── create_dataset.py ├── figures │ ├── SYA_1.png │ ├── SYA_2.png │ ├── check_all.png │ ├── shana.png │ ├── tianqi_all.png │ ├── video.png │ └── youtube.png └── video_transformer.py ├── eval_Vimeo90K.py ├── models ├── DCNv2 │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── dcn_v2.py │ ├── make.sh │ ├── setup.py │ └── src │ │ ├── cpu │ │ ├── dcn_v2_cpu.cpp │ │ └── vision.h │ │ ├── cuda │ │ ├── dcn_v2_cuda.cu │ │ ├── dcn_v2_im2col_cuda.cu │ │ ├── dcn_v2_im2col_cuda.h │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ └── vision.h │ │ ├── dcn_v2.h │ │ └── vision.cpp ├── ResBlock.py ├── Unet.py ├── __init__.py ├── bdcn │ ├── __init__.py │ ├── bdcn.py │ └── vgg16_c.py └── warp.py ├── paper ├── FeatureFlow.pdf └── Supp.pdf ├── pure_run.py ├── requirements.txt ├── run.py ├── sequence_run.py ├── src ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── dataloader.cpython-37.pyc │ ├── layers.cpython-37.pyc │ ├── loss.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── pure_network.cpython-37.pyc ├── dataloader.py ├── eval.py ├── layers.py ├── loss.py ├── model.py └── pure_network.py ├── train.py ├── utils └── visualize.py └── video_process.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/** 2 | log/** 3 | checkpoints/** 4 | models/bdcn/final-model/** 5 | -------------------------------------------------------------------------------- /AverageMeter.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class AverageMeter(object): 4 | """Computes and stores the average and current value""" 5 | def __init__(self): 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = 0 10 | self.avg = 0 11 | self.sum = 0 12 | self.count = 0 13 | 14 | def update(self, val, n=1): 15 | self.val = val 16 | self.sum += val * n 17 | self.count += n 18 | self.avg = self.sum / self.count 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Citrine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # FeatureFlow
2 | 
3 | [Paper](https://github.com/CM-BF/FeatureFlow/blob/master/paper/FeatureFlow.pdf) | [Supp](https://github.com/CM-BF/FeatureFlow/blob/master/paper/Supp.pdf)
4 | 
5 | A state-of-the-art video frame interpolation method using deep semantic flow blending.
6 | 
7 | FeatureFlow: Robust Video Interpolation via Structure-to-Texture Generation (IEEE Conference on Computer Vision and Pattern Recognition 2020)
8 | 
9 | ## To Do List
10 | - [x] Preprint
11 | - [x] Training code
12 | 
13 | ## Table of Contents
14 | 
15 | 1. [Requirements](#requirements)
16 | 1. [Demos](#video-demos)
17 | 1. [Installation](#installation)
18 | 1. [Pre-trained Model](#pre-trained-model)
19 | 1. [Download Results](#download-results)
20 | 1. [Evaluation](#evaluation)
21 | 1. [Test your video](#test-your-video)
22 | 1. [Training](#training)
23 | 1. [Citation](#citation)
24 | 
25 | ## Requirements
26 | 
27 | * Ubuntu
28 | * PyTorch (>=1.1)
29 | * CUDA (>=10.0) & cuDNN (>=7.0)
30 | * mmdet 1.0rc (from https://github.com/open-mmlab/mmdetection.git)
31 | * visdom (optional)
32 | * NVIDIA GPU
33 | 
34 | P.S.: `requirements.txt` is provided for reference only; do not install from it directly, because it also contains another project's dependencies.
35 | 
36 | ## Video demos
37 | 
38 | Click a picture to download the corresponding demo, or click [Here(Google)](https://drive.google.com/open?id=1QUYoplBNjaWXJZPO90NiwQwqQz7yF7TX) or [Here(Baidu)](https://pan.baidu.com/s/1J9seoqgC2p9zZ7pegMlH1A) (key: oav2) to download the **360p demos**.
39 | 
40 | **360p demos** (including comparisons):
41 | 
42 | [![youtube](data/figures/youtube.png)](https://github.com/CM-BF/storage/tree/master/videos/youtube.mp4 "video1")
43 | [![check_all](data/figures/check_all.png)](https://github.com/CM-BF/storage/tree/master/videos/check_all.mp4 "video2")
44 | [![tianqi_all](data/figures/tianqi_all.png)](https://github.com/CM-BF/storage/tree/master/videos/tianqi_all.mp4 "video3")
45 | [![video](data/figures/video.png)](https://github.com/CM-BF/storage/tree/master/videos/video.mp4 "video4")
46 | [![shana](data/figures/shana.png)](https://github.com/CM-BF/storage/tree/master/videos/shana.mp4 "video5")
47 | 
48 | **720p demos**:
49 | 
50 | [![SYA_1](data/figures/SYA_1.png)](https://github.com/CM-BF/storage/tree/master/videos/SYA_1.mp4 "video6")
51 | [![SYA_2](data/figures/SYA_2.png)](https://github.com/CM-BF/storage/tree/master/videos/SYA_2.mp4 "video7")
52 | 
53 | ## Installation
54 | * Clone this repo
55 | * `git clone https://github.com/open-mmlab/mmdetection.git`
56 | * Install mmdetection: please follow the guidance in its GitHub repository
57 | ```bash
58 | $ cd mmdetection
59 | $ pip install -r requirements/build.txt
60 | $ pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
61 | $ pip install -v -e .  # or "python setup.py develop"
62 | $ pip list | grep mmdet
63 | ```
64 | * Download the [test set](http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip)
65 | ```bash
66 | $ unzip vimeo_interp_test.zip
67 | $ cd vimeo_interp_test
68 | $ mkdir sequences
69 | $ cp target/* sequences/ -r
70 | $ cp input/* sequences/ -r
71 | ```
72 | * Download BDCN's pre-trained model bdcn_pretrained_on_bsds500.pth to ./models/bdcn/final-model/
73 | 
74 | P.S.: For your convenience, you can download only bdcn_pretrained_on_bsds500.pth ([Google Drive](https://drive.google.com/file/d/1Zot0P8pSBawnehTE32T-8J0rjYq6CSAT/view?usp=sharing)) or all of the pre-trained BDCN models its authors provide ([Google Drive](https://drive.google.com/file/d/1CmDMypSlLM6EAvOt5yjwUQ7O5w-xCm1n/view)).
For a Baidu Cloud link, refer to BDCN's GitHub repository.
75 | 
76 | ```
77 | $ pip install scikit-image visdom tqdm prefetch-generator
78 | ```
79 | 
80 | ## Pre-trained Model
81 | 
82 | [Google Drive](https://drive.google.com/open?id=1S8C0chFV6Bip6W9lJdZkog0T3xiNxbEx)
83 | 
84 | [Baidu Cloud](https://pan.baidu.com/s/1LxVw-89f3GX5r0mZ6wmsJw): ae4x
85 | 
86 | Place FeFlow.ckpt in ./checkpoints/.
87 | 
88 | ## Download Results
89 | 
90 | [Google Drive](https://drive.google.com/open?id=1OtrExUiyIBJe0D6_ZwDfztqJBqji4lmt)
91 | 
92 | [Baidu Cloud](https://pan.baidu.com/s/1BaJJ82nSKagly6XZ8KNtAw): pc0k
93 | 
94 | ## Evaluation
95 | ```bash
96 | $ CUDA_VISIBLE_DEVICES=0 python eval_Vimeo90K.py --checkpoint ./checkpoints/FeFlow.ckpt --dataset_root ~/datasets/videos/vimeo_interp_test --visdom_env test --vimeo90k --imgpath ./results/
97 | ```
98 | 
99 | ## Test your video
100 | ```bash
101 | $ CUDA_VISIBLE_DEVICES=0 python sequence_run.py --checkpoint checkpoints/FeFlow.ckpt --video_path ./yourvideo.mp4 --t_interp 4 --slow_motion
102 | ```
103 | `--t_interp` sets the frame-rate multiple; only powers of 2 (2, 4, 8, ...) are supported. Use the `--slow_motion` flag to slow the video down while keeping the original fps.
104 | 
105 | The output video will be saved as output.mp4 in your working directory.
106 | 
107 | ## Training
108 | 
109 | Training code **train.py** is available now. I cannot run it for confirmation at the moment because I have left the lab, but it should work with the right argument settings.
110 | 
111 | ```bash
112 | $ CUDA_VISIBLE_DEVICES=0,1 python train.py
113 | ```
114 | 
115 | * Please read the **arguments' help** carefully to fully control the **two-step training**; a rough example of the two-stage schedule is sketched after this list.
116 | * Pay attention to `--GEN_DE`, the flag that switches the model between Stage-I and Stage-II.
117 | * 2 GPUs are necessary for training; otherwise the small batch\_size will cause the training process to crash.
118 | * Deformable convolution is not stable enough, so you may occasionally face a training crash (the random seed is not fixed), but a crash can be detected soon after a run starts by visualizing results with Visdom.
119 | * Visdom visualization code [lines 75, 201-216 and 338-353] is included, which is useful for viewing the training process and checking for crashes.
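A minimal sketch of the two-step schedule, for orientation only. Apart from `--GEN_DE`, the argument names below (`--dataset_root`, `--checkpoint_dir`) are hypothetical placeholders; check `python train.py --help` for the real names, the defaults, and which value of `--GEN_DE` selects which stage before running.

```bash
# Stage 1 (hypothetical arguments; verify the names with `python train.py --help`)
$ CUDA_VISIBLE_DEVICES=0,1 python train.py --GEN_DE --dataset_root ~/datasets/vimeo_triplet --checkpoint_dir ./checkpoints/stage1

# Stage 2: flip --GEN_DE and resume from the Stage-1 checkpoint
$ CUDA_VISIBLE_DEVICES=0,1 python train.py --dataset_root ~/datasets/vimeo_triplet --checkpoint_dir ./checkpoints/stage2
```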
120 | 121 | ## Citation 122 | ``` 123 | @InProceedings{Gui_2020_CVPR, 124 | author = {Gui, Shurui and Wang, Chaoyue and Chen, Qihua and Tao, Dacheng}, 125 | title = {FeatureFlow: Robust Video Interpolation via Structure-to-Texture Generation}, 126 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 127 | month = {June}, 128 | year = {2020} 129 | } 130 | ``` 131 | 132 | ## Contact 133 | [Shurui Gui](mailto:citrinegui@gmail.com); [Chaoyue Wang](mailto:chaoyue.wang@sydney.edu.au) 134 | 135 | ## License 136 | See [MIT License](https://github.com/CM-BF/FeatureFlow/blob/master/LICENSE) 137 | -------------------------------------------------------------------------------- /data/Vimeo-90K.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--dataset', type=str, required=True, help='path to Vimeo-90K dataset') 6 | parser.add_argument('--out', type=str, required=True, help='path to output dataset place') 7 | args = parser.parse_args() 8 | 9 | def main(): 10 | train_out_path = args.out + 'train' 11 | test_out_path = args.out + 'test' 12 | validation_out_path = args.out + 'validation' 13 | os.mkdir(train_out_path) 14 | os.mkdir(test_out_path) 15 | os.mkdir(validation_out_path) 16 | 17 | with open(args.dataset + '/tri_trainlist.txt', 'r') as f: 18 | train_paths = f.read().split('\n') 19 | test_paths = f.read().split('\n') 20 | print() 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /data/adobe240fps/test_list.txt: -------------------------------------------------------------------------------- 1 | 720p_240fps_1.mov 2 | GOPR9635.mp4 3 | GOPR9637a.mp4 4 | IMG_0004a.mov 5 | IMG_0015.mov 6 | IMG_0023.mov 7 | IMG_0179.m4v 8 | IMG_0183.MOV -------------------------------------------------------------------------------- /data/adobe240fps/train_list.txt: -------------------------------------------------------------------------------- 1 | 720p_240fps_2.mov 2 | 720p_240fps_3.mov 3 | 720p_240fps_5.mov 4 | 720p_240fps_6.mov 5 | GOPR9633.mp4 6 | GOPR9634.mp4 7 | GOPR9636.mp4 8 | GOPR9637b.mp4 9 | GOPR9638.mp4 10 | GOPR9639.mp4 11 | GOPR9640.mp4 12 | GOPR9641.mp4 13 | GOPR9642.mp4 14 | GOPR9643.mp4 15 | GOPR9644.mp4 16 | GOPR9645.mp4 17 | GOPR9646.mp4 18 | GOPR9647.mp4 19 | GOPR9648.mp4 20 | GOPR9649.mp4 21 | GOPR9650.mp4 22 | GOPR9651.mp4 23 | GOPR9652.mp4 24 | GOPR9653.mp4 25 | GOPR9654a.mp4 26 | GOPR9654b.mp4 27 | GOPR9655a.mp4 28 | GOPR9655b.mp4 29 | GOPR9656.mp4 30 | GOPR9657.mp4 31 | GOPR9658.MP4 32 | GOPR9659.mp4 33 | GOPR9660.MP4 34 | IMG_0001.mov 35 | IMG_0002.mov 36 | IMG_0003.mov 37 | IMG_0004b.mov 38 | IMG_0005.mov 39 | IMG_0006.mov 40 | IMG_0007.mov 41 | IMG_0008.mov 42 | IMG_0009.mov 43 | IMG_0010.mov 44 | IMG_0011.mov 45 | IMG_0012.mov 46 | IMG_0013.mov 47 | IMG_0014.mov 48 | IMG_0016.mov 49 | IMG_0017.mov 50 | IMG_0018.mov 51 | IMG_0019.mov 52 | IMG_0020.mov 53 | IMG_0021.mov 54 | IMG_0022.mov 55 | IMG_0024.mov 56 | IMG_0025.mov 57 | IMG_0026.mov 58 | IMG_0028.mov 59 | IMG_0029.mov 60 | IMG_0030.mov 61 | IMG_0031.mov 62 | IMG_0032.mov 63 | IMG_0033.mov 64 | IMG_0034.mov 65 | IMG_0034a.mov 66 | IMG_0035.mov 67 | IMG_0036.mov 68 | IMG_0037.mov 69 | IMG_0037a.mov 70 | IMG_0038.mov 71 | IMG_0039.mov 72 | IMG_0040.mov 73 | IMG_0041.mov 74 | IMG_0042.mov 75 | IMG_0043.mov 76 | IMG_0044.mov 77 | IMG_0045.mov 78 | IMG_0046.mov 79 | IMG_0047.mov 80 | 
IMG_0052.mov 81 | IMG_0054a.mov 82 | IMG_0054b.mov 83 | IMG_0055.mov 84 | IMG_0056.mov 85 | IMG_0058.mov 86 | IMG_0150.m4v 87 | IMG_0151.m4v 88 | IMG_0152.m4v 89 | IMG_0153.m4v 90 | IMG_0154.m4v 91 | IMG_0155.m4v 92 | IMG_0156.m4v 93 | IMG_0157.m4v 94 | IMG_0160.m4v 95 | IMG_0161.m4v 96 | IMG_0162.m4v 97 | IMG_0163.m4v 98 | IMG_0164.m4v 99 | IMG_0167.m4v 100 | IMG_0169.m4v 101 | IMG_0170.m4v 102 | IMG_0171.m4v 103 | IMG_0172.m4v 104 | IMG_0173.m4v 105 | IMG_0174.m4v 106 | IMG_0175.m4v 107 | IMG_0176.m4v 108 | IMG_0177.m4v 109 | IMG_0178.m4v 110 | IMG_0180.m4v 111 | IMG_0200.MOV 112 | IMG_0212.MOV -------------------------------------------------------------------------------- /data/create_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | from shutil import rmtree, move 5 | import random 6 | 7 | # For parsing commandline arguments 8 | parser = argparse.ArgumentParser() 9 | # parser.add_argument("--ffmpeg_dir", type=str, required=True, help='path to ffmpeg.exe') 10 | parser.add_argument("--dataset", type=str, default="custom", help='specify if using "adobe240fps" or custom video dataset') 11 | parser.add_argument("--videos_folder", type=str, required=True, help='path to the folder containing videos') 12 | parser.add_argument("--dataset_folder", type=str, required=True, help='path to the output dataset folder') 13 | parser.add_argument("--img_width", type=int, default=640, help="output image width") 14 | parser.add_argument("--img_height", type=int, default=360, help="output image height") 15 | parser.add_argument("--train_test_split", type=tuple, default=(90, 10), help="train test split for custom dataset") 16 | args = parser.parse_args() 17 | 18 | 19 | def extract_frames(videos, inDir, outDir): 20 | """ 21 | Converts all the videos passed in `videos` list to images. 22 | 23 | Parameters 24 | ---------- 25 | videos : list 26 | name of all video files. 27 | inDir : string 28 | path to input directory containing videos in `videos` list. 29 | outDir : string 30 | path to directory to output the extracted images. 31 | 32 | Returns 33 | ------- 34 | None 35 | """ 36 | 37 | 38 | for video in videos: 39 | os.mkdir(os.path.join(outDir, os.path.splitext(video)[0])) 40 | retn = os.system('{} -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.png'.format('ffmpeg', os.path.join(inDir, video), args.img_width, args.img_height, os.path.join(outDir, os.path.splitext(video)[0]))) 41 | if retn: 42 | print("Error converting file:{}. Exiting.".format(video)) 43 | 44 | 45 | def create_clips(root, destination): 46 | """ 47 | Distributes the images extracted by `extract_frames()` in 48 | clips containing 12 frames each. 49 | 50 | Parameters 51 | ---------- 52 | root : string 53 | path containing extracted image folders. 54 | destination : string 55 | path to output clips. 56 | 57 | Returns 58 | ------- 59 | None 60 | """ 61 | 62 | 63 | folderCounter = -1 64 | 65 | files = os.listdir(root) 66 | 67 | # Iterate over each folder containing extracted video frames. 
68 | for file in files: 69 | images = sorted(os.listdir(os.path.join(root, file))) 70 | 71 | for imageCounter, image in enumerate(images): 72 | # Bunch images in groups of 12 frames 73 | if (imageCounter % 12 == 0): 74 | if (imageCounter + 11 >= len(images)): 75 | break 76 | folderCounter += 1 77 | os.mkdir("{}/{}".format(destination, folderCounter)) 78 | move("{}/{}/{}".format(root, file, image), "{}/{}/{}".format(destination, folderCounter, image)) 79 | rmtree(os.path.join(root, file)) 80 | 81 | def main(): 82 | # Create dataset folder if it doesn't exist already. 83 | if not os.path.isdir(args.dataset_folder): 84 | os.mkdir(args.dataset_folder) 85 | 86 | extractPath = os.path.join(args.dataset_folder, "extracted") 87 | trainPath = os.path.join(args.dataset_folder, "train") 88 | testPath = os.path.join(args.dataset_folder, "test") 89 | validationPath = os.path.join(args.dataset_folder, "validation") 90 | os.mkdir(extractPath) 91 | os.mkdir(trainPath) 92 | os.mkdir(testPath) 93 | os.mkdir(validationPath) 94 | 95 | if(args.dataset == "adobe240fps"): 96 | f = open("data/adobe240fps/test_list.txt", "r") 97 | videos = f.read().split('\n') 98 | extract_frames(videos, args.videos_folder, extractPath) 99 | create_clips(extractPath, testPath) 100 | 101 | f = open("data/adobe240fps/train_list.txt", "r") 102 | videos = f.read().split('\n') 103 | extract_frames(videos, args.videos_folder, extractPath) 104 | create_clips(extractPath, trainPath) 105 | 106 | # Select 100 clips at random from test set for validation set. 107 | testClips = os.listdir(testPath) 108 | indices = random.sample(range(len(testClips)), 100) 109 | for index in indices: 110 | move("{}/{}".format(testPath, index), "{}/{}".format(validationPath, index)) 111 | 112 | else: # custom dataset 113 | 114 | # Extract video names 115 | videos = os.listdir(args.videos_folder) 116 | 117 | # Create random train-test split. 118 | testIndices = random.sample(range(len(videos)), int((args.train_test_split[1] * len(videos)) / 100)) 119 | trainIndices = [x for x in range((len(videos))) if x not in testIndices] 120 | 121 | # Create list of video names 122 | testVideoNames = [videos[index] for index in testIndices] 123 | trainVideoNames = [videos[index] for index in trainIndices] 124 | 125 | # Create train-test dataset 126 | extract_frames(testVideoNames, args.videos_folder, extractPath) 127 | create_clips(extractPath, testPath) 128 | extract_frames(trainVideoNames, args.videos_folder, extractPath) 129 | create_clips(extractPath, trainPath) 130 | 131 | # Select clips at random from test set for validation set. 
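        # Cap the validation split at 100 clips, or 20% of the test clips for smaller datasets.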
132 | testClips = os.listdir(testPath) 133 | indices = random.sample(range(len(testClips)), min(100, int(len(testClips) / 5))) 134 | for index in indices: 135 | move("{}/{}".format(testPath, index), "{}/{}".format(validationPath, index)) 136 | 137 | rmtree(extractPath) 138 | 139 | main() 140 | -------------------------------------------------------------------------------- /data/figures/SYA_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/SYA_1.png -------------------------------------------------------------------------------- /data/figures/SYA_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/SYA_2.png -------------------------------------------------------------------------------- /data/figures/check_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/check_all.png -------------------------------------------------------------------------------- /data/figures/shana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/shana.png -------------------------------------------------------------------------------- /data/figures/tianqi_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/tianqi_all.png -------------------------------------------------------------------------------- /data/figures/video.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/video.png -------------------------------------------------------------------------------- /data/figures/youtube.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/data/figures/youtube.png -------------------------------------------------------------------------------- /data/video_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import argparse 5 | import shutil 6 | 7 | import multiprocessing 8 | from tqdm import tqdm 9 | 10 | # command line parser 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--videos_folder', type=str, required=True, help='the path to video dataset folder.') 14 | parser.add_argument('--output_folder', type=str, default='../pre_dataset/', help='the path to output dataset folder.') 15 | parser.add_argument('--lower_rate', type=int, default=5, help='lower the video fps by n times.') 16 | args = parser.parse_args() 17 | 18 | 19 | class DataCreator(object): 20 | 21 | def __init__(self): 22 | 23 | self.videos_folder = args.videos_folder 24 | self.output_folder = args.output_folder 25 | self.lower_rate = args.lower_rate 26 | self.tmp = '../.tmp/' 27 | try: 28 | os.mkdir(self.tmp) 29 | except: 30 | pass 31 | 32 | def _listener(self, pbar, q): 33 | 
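        # Listener runs in its own process: drain the queue until the None sentinel arrives, advancing the progress bar once per finished video.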
for item in iter(q.get, None): 34 | pbar.update(1) 35 | 36 | def _lower_fps(self, p_args): 37 | video_name, q = p_args 38 | # pbar.set_description("Processing %s" % video_name) 39 | 40 | # read a video and create video_writer for lower fps video output 41 | video = cv2.VideoCapture(os.path.join(self.videos_folder, video_name)) 42 | fps = video.get(cv2.CAP_PROP_FPS) 43 | size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))) 44 | # fourcc = cv2.VideoWriter_fourcc(*'XVID') 45 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 46 | video_writer = [cv2.VideoWriter(self.tmp + video_name[:-4] + '_%s' % str(i) + '.mp4', 47 | fourcc, 48 | fps / self.lower_rate, 49 | size) 50 | for i in range(self.lower_rate)] 51 | 52 | count = 0 53 | while video.isOpened(): 54 | ret, frame = video.read() 55 | if ret: 56 | video_writer[count % self.lower_rate].write(frame) 57 | if cv2.waitKey(1) & 0xFF == ord('q'): 58 | raise KeyboardInterrupt 59 | else: 60 | break 61 | count += 1 62 | 63 | for i in range(self.lower_rate): 64 | video_writer[i].release() 65 | 66 | q.put(1) 67 | 68 | 69 | def lower_fps(self): 70 | 71 | videos_name = os.listdir(self.videos_folder) 72 | pbar = tqdm(total=len(videos_name)) 73 | m = multiprocessing.Manager() 74 | q = m.Queue() 75 | 76 | listener = multiprocessing.Process(target=self._listener, args=(pbar, q)) 77 | listener.start() 78 | 79 | p_args = [(video_name, q) for video_name in videos_name] 80 | pool = multiprocessing.Pool() 81 | pool.map(self._lower_fps, p_args) 82 | 83 | pool.close() 84 | pool.join() 85 | q.put(None) 86 | listener.join() 87 | 88 | def output(self): 89 | os.system('mkdir %s' % self.output_folder) 90 | os.system('cp %s %s' % (self.tmp + '*', self.output_folder)) 91 | os.system('rm -rf %s' % self.tmp) 92 | 93 | 94 | if __name__ == '__main__': 95 | data_creator = DataCreator() 96 | data_creator.lower_fps() 97 | data_creator.output() 98 | 99 | -------------------------------------------------------------------------------- /eval_Vimeo90K.py: -------------------------------------------------------------------------------- 1 | # SeDraw 2 | import argparse 3 | from skimage.measure import compare_psnr, compare_ssim 4 | import numpy as np 5 | import os 6 | import torch 7 | import cv2 8 | import torchvision.transforms as transforms 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import time 12 | import src.dataloader as dataloader 13 | import src.layers as layers 14 | from math import log10 15 | import datetime 16 | from tqdm import tqdm 17 | from prefetch_generator import BackgroundGenerator 18 | import visdom 19 | from utils.visualize import feature_transform 20 | import numpy as np 21 | import models.bdcn.bdcn as bdcn 22 | 23 | # For parsing commandline arguments 24 | def str2bool(v): 25 | if isinstance(v, bool): 26 | return v 27 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 28 | return True 29 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 30 | return False 31 | else: 32 | raise argparse.ArgumentTypeError('Boolean value expected.') 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--dataset_root", type=str, required=True, 36 | help='path to dataset folder containing train-test-validation folders') 37 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model') 38 | parser.add_argument("--test_batch_size", type=int, default=1, help='batch size for training. 
Default: 6.') 39 | parser.add_argument('--visdom_env', type=str, default='SeDraw_0', help='Environment for visdom show') 40 | parser.add_argument('--vimeo90k', action='store_true', help='use this flag if using Vimeo-90K dataset') 41 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3') 42 | parser.add_argument('--bdcn_model', default='./models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth') 43 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.') 44 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint') 45 | parser.add_argument('--imgpath', type=str, required=True) 46 | args = parser.parse_args() 47 | 48 | 49 | 50 | # --For visualizing loss and interpolated frames-- 51 | 52 | 53 | # Visdom for real-time visualizing 54 | vis = visdom.Visdom(env=args.visdom_env, port=8098) 55 | 56 | # device 57 | device_count = torch.cuda.device_count() 58 | 59 | # --Initialize network-- 60 | bdcn = bdcn.BDCN() 61 | bdcn.cuda() 62 | structure_gen = layers.StructureGen(feature_level=args.feature_level) 63 | structure_gen.cuda() 64 | detail_enhance = layers.DetailEnhance() 65 | detail_enhance.cuda() 66 | 67 | 68 | # --Load Datasets-- 69 | 70 | 71 | # Channel wise mean calculated on adobe240-fps training dataset 72 | mean = [0.5, 0.5, 0.5] 73 | std = [0.5, 0.5, 0.5] 74 | normalize = transforms.Normalize(mean=mean, 75 | std=std) 76 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 77 | 78 | if args.vimeo90k: 79 | testset = dataloader.SeDraw_vimeo90k(root=args.dataset_root, transform=transform, 80 | randomCropSize=(448, 256), train=False, test=True) 81 | else: 82 | testset = dataloader.SeDraw(root=args.dataset_root + '/validation', transform=transform, 83 | randomCropSize=(448, 256), train=False) 84 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size * device_count, shuffle=False) 85 | 86 | print(testset) 87 | 88 | # --Create transform to display image from tensor-- 89 | 90 | 91 | negmean = [-1 for x in mean] 92 | restd = [2, 2, 2] 93 | revNormalize = transforms.Normalize(mean=negmean, std=restd) 94 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()]) 95 | 96 | 97 | # --Utils-- 98 | 99 | def get_lr(optimizer): 100 | for param_group in optimizer.param_groups: 101 | return param_group['lr'] 102 | 103 | 104 | 105 | # --Validation function-- 106 | # 107 | 108 | 109 | def validate(): 110 | # For details see training. 111 | psnr = 0 112 | ie = 0 113 | tloss = 0 114 | 115 | with torch.no_grad(): 116 | for testIndex, testData in tqdm(enumerate(testloader, 0)): 117 | frame0, frameT, frame1 = testData 118 | 119 | img0 = frame0.cuda() 120 | img1 = frame1.cuda() 121 | IFrame = frameT.cuda() 122 | 123 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1) 124 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1) 125 | IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1) 126 | _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e)) 127 | loss, MSE_val, IE, imgt = detail_enhance((img0, img1, IFrame, ref_imgt)) 128 | imgt = torch.clamp(imgt, max=1., min=-1.) 
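            # Map the [-1, 1] network outputs back to uint8 BGR (HWC) images so OpenCV can write them to disk.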
129 | IFrame_np = IFrame.squeeze(0).cpu().numpy() 130 | imgt_np = imgt.squeeze(0).cpu().numpy() 131 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 132 | IFrame_png = np.uint8(((IFrame_np + 1.0) /2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 133 | imgpath = args.imgpath + '/' + str(testIndex) 134 | if not os.path.isdir(imgpath): 135 | os.system('mkdir -p %s' % imgpath) 136 | cv2.imwrite(imgpath + '/imgt.png', imgt_png) 137 | cv2.imwrite(imgpath + '/IFrame.png', IFrame_png) 138 | 139 | PSNR = compare_psnr(IFrame_np, imgt_np, data_range=2) 140 | print('PSNR:', PSNR) 141 | 142 | loss = torch.mean(loss) 143 | MSE_val = torch.mean(MSE_val) 144 | 145 | if testIndex % 100 == 99: 146 | vImg = torch.cat([revNormalize(frame0[0]).unsqueeze(0), revNormalize(frame1[0]).unsqueeze(0), 147 | revNormalize(imgt.cpu()[0]).unsqueeze(0), revNormalize(frameT[0]).unsqueeze(0), 148 | revNormalize(ref_imgt.cpu()[0]).unsqueeze(0)], 149 | dim=0) 150 | 151 | 152 | vImg = torch.clamp(vImg, max=1., min=0) 153 | vis.images(vImg, win='vImage', env=args.visdom_env, nrow=2, opts={'title': 'visual_image'}) 154 | 155 | # psnr 156 | tloss += loss.item() 157 | 158 | psnr += PSNR 159 | ie += IE 160 | 161 | return (psnr / len(testloader)), (tloss / len(testloader)), MSE_val, (ie / len(testloader)) 162 | 163 | 164 | # --Initialization-- 165 | 166 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model))) 167 | 168 | dict1 = torch.load(args.checkpoint) 169 | structure_gen.load_state_dict(dict1['state_dictGEN']) 170 | detail_enhance.load_state_dict(dict1['state_dictDE']) 171 | 172 | start = time.time() 173 | 174 | bdcn.eval() 175 | structure_gen.eval() 176 | detail_enhance.eval() 177 | psnr, vLoss, MSE_val, ie = validate() 178 | end = time.time() 179 | 180 | print(" Loss: %0.6f TestExecTime: %0.1f ValPSNR: %0.4f ValIE: %0.4f" % ( 181 | vLoss, end - start, psnr, ie)) 182 | 183 | -------------------------------------------------------------------------------- /models/DCNv2/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | *.so 4 | *.o 5 | *pyc 6 | _ext 7 | build 8 | DCNv2.egg-info 9 | dist -------------------------------------------------------------------------------- /models/DCNv2/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Charles Shang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------- /models/DCNv2/README.md: --------------------------------------------------------------------------------
1 | ## Deformable Convolutional Networks V2 with PyTorch 1.0
2 | 
3 | ### Build
4 | ```bash
5 | ./make.sh # build
6 | python test.py # run examples and gradient check
7 | ```
8 | 
9 | ### An Example
10 | - deformable conv
11 | ```python
12 | from dcn_v2 import DCN
13 | input = torch.randn(2, 64, 128, 128).cuda()
14 | # wrap all things (offset and mask) in DCN
15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
16 | output = dcn(input)
17 | print(output.shape)
18 | ```
19 | - deformable roi pooling
20 | ```python
21 | from dcn_v2 import DCNPooling
22 | input = torch.randn(2, 32, 64, 64).cuda()
23 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
24 | x = torch.randint(256, (20, 1)).cuda().float()
25 | y = torch.randint(256, (20, 1)).cuda().float()
26 | w = torch.randint(64, (20, 1)).cuda().float()
27 | h = torch.randint(64, (20, 1)).cuda().float()
28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
29 | 
30 | # modulated deformable pooling (V2)
31 | # wrap all things (offset and mask) in DCNPooling
32 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
33 |                       pooled_size=7,
34 |                       output_dim=32,
35 |                       no_trans=False,
36 |                       group_size=1,
37 |                       trans_std=0.1).cuda()
38 | 
39 | dout = dpooling(input, rois)
40 | ```
41 | ### Note
42 | The master branch is now for PyTorch 1.0 (new ATen API); you can switch back to PyTorch 0.4 with
43 | ```bash
44 | git checkout pytorch_0.4
45 | ```
46 | 
47 | ### Known Issues:
48 | 
49 | - [x] Gradient check w.r.t. offset (solved)
50 | - [ ] Backward is not reentrant (minor)
51 | 
52 | This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
53 | 
54 | I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
55 | However, when I set the offset to 0.5, it passes. I am still wondering what causes this problem. Is it because of some
56 | non-differentiable points?
57 | 
58 | Update: all gradient checks pass with double precision.
59 | 
60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
61 | float, `<1e-15` for double),
62 | so it may not be a serious problem (?)
63 | 
64 | Please post an issue or PR if you have any comments.
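This repository's `dcn_v2.py` also defines a `DCN_sep` variant, in which the offsets and masks are predicted from a separate feature map instead of from the input itself. A minimal usage sketch, assuming the extension has been built with `./make.sh` and a CUDA device is available:

```python
import torch
from dcn_v2 import DCN_sep

# `input` is the feature map being deformed; `fea` drives the predicted offsets and masks.
input = torch.randn(2, 64, 128, 128).cuda()
fea = torch.randn(2, 64, 128, 128).cuda()
dcn_sep = DCN_sep(64, 64, kernel_size=3, stride=1, padding=1, deformable_groups=2).cuda()
output = dcn_sep(input, fea)
print(output.shape)  # torch.Size([2, 64, 128, 128])
```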
65 | -------------------------------------------------------------------------------- /models/DCNv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/models/DCNv2/__init__.py -------------------------------------------------------------------------------- /models/DCNv2/dcn_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import math 4 | import logging 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Function 8 | from torch.nn.modules.utils import _pair 9 | from torch.autograd.function import once_differentiable 10 | 11 | import _ext as _backend 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class _DCNv2(Function): 16 | @staticmethod 17 | def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation, 18 | deformable_groups): 19 | ctx.stride = _pair(stride) 20 | ctx.padding = _pair(padding) 21 | ctx.dilation = _pair(dilation) 22 | ctx.kernel_size = _pair(weight.shape[2:4]) 23 | ctx.deformable_groups = deformable_groups 24 | output = _backend.dcn_v2_forward(input, weight, bias, offset, mask, ctx.kernel_size[0], 25 | ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], 26 | ctx.padding[0], ctx.padding[1], ctx.dilation[0], 27 | ctx.dilation[1], ctx.deformable_groups) 28 | ctx.save_for_backward(input, offset, mask, weight, bias) 29 | return output 30 | 31 | @staticmethod 32 | @once_differentiable 33 | def backward(ctx, grad_output): 34 | input, offset, mask, weight, bias = ctx.saved_tensors 35 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ 36 | _backend.dcn_v2_backward(input, weight, 37 | bias, 38 | offset, mask, 39 | grad_output, 40 | ctx.kernel_size[0], ctx.kernel_size[1], 41 | ctx.stride[0], ctx.stride[1], 42 | ctx.padding[0], ctx.padding[1], 43 | ctx.dilation[0], ctx.dilation[1], 44 | ctx.deformable_groups) 45 | 46 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ 47 | None, None, None, None, 48 | 49 | 50 | dcn_v2_conv = _DCNv2.apply 51 | 52 | 53 | class DCNv2(nn.Module): 54 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, 55 | deformable_groups=1): 56 | super(DCNv2, self).__init__() 57 | self.in_channels = in_channels 58 | self.out_channels = out_channels 59 | self.kernel_size = _pair(kernel_size) 60 | self.stride = _pair(stride) 61 | self.padding = _pair(padding) 62 | self.dilation = _pair(dilation) 63 | self.deformable_groups = deformable_groups 64 | 65 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) 66 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | n = self.in_channels 71 | for k in self.kernel_size: 72 | n *= k 73 | stdv = 1. 
/ math.sqrt(n) 74 | self.weight.data.uniform_(-stdv, stdv) 75 | self.bias.data.zero_() 76 | 77 | def forward(self, input, offset, mask): 78 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ 79 | offset.shape[1] 80 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ 81 | mask.shape[1] 82 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding, 83 | self.dilation, self.deformable_groups) 84 | 85 | 86 | class DCN(DCNv2): 87 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, 88 | deformable_groups=1): 89 | super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, 90 | deformable_groups) 91 | 92 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 93 | self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size, 94 | stride=self.stride, padding=self.padding, bias=True) 95 | self.init_offset() 96 | 97 | def init_offset(self): 98 | self.conv_offset_mask.weight.data.zero_() 99 | self.conv_offset_mask.bias.data.zero_() 100 | 101 | def forward(self, input): 102 | out = self.conv_offset_mask(input) 103 | o1, o2, mask = torch.chunk(out, 3, dim=1) 104 | offset = torch.cat((o1, o2), dim=1) 105 | mask = torch.sigmoid(mask) 106 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding, 107 | self.dilation, self.deformable_groups) 108 | 109 | 110 | class DCN_sep(DCNv2): 111 | '''Use other features to generate offsets and masks''' 112 | 113 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 114 | deformable_groups=1): 115 | super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding, 116 | dilation, deformable_groups) 117 | 118 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 119 | self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size, 120 | stride=self.stride, padding=self.padding, bias=True) 121 | self.init_offset() 122 | 123 | def init_offset(self): 124 | self.conv_offset_mask.weight.data.zero_() 125 | self.conv_offset_mask.bias.data.zero_() 126 | 127 | def forward(self, input, fea): 128 | '''input: input features for deformable conv 129 | fea: other features used for generating offsets and mask''' 130 | out = self.conv_offset_mask(fea) 131 | o1, o2, mask = torch.chunk(out, 3, dim=1) 132 | offset = torch.cat((o1, o2), dim=1) 133 | 134 | offset_mean = torch.mean(torch.abs(offset)) 135 | if offset_mean > 100: 136 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) 137 | 138 | mask = torch.sigmoid(mask) 139 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding, 140 | self.dilation, self.deformable_groups) 141 | 142 | 143 | class _DCNv2Pooling(Function): 144 | @staticmethod 145 | def forward(ctx, input, rois, offset, spatial_scale, pooled_size, output_dim, no_trans, 146 | group_size=1, part_size=None, sample_per_part=4, trans_std=.0): 147 | ctx.spatial_scale = spatial_scale 148 | ctx.no_trans = int(no_trans) 149 | ctx.output_dim = output_dim 150 | ctx.group_size = group_size 151 | ctx.pooled_size = pooled_size 152 | ctx.part_size = pooled_size if part_size is None else part_size 153 | ctx.sample_per_part = sample_per_part 154 | ctx.trans_std = trans_std 155 | 156 | output, output_count = \ 157 | _backend.dcn_v2_psroi_pooling_forward(input, 
rois, offset, 158 | ctx.no_trans, ctx.spatial_scale, 159 | ctx.output_dim, ctx.group_size, 160 | ctx.pooled_size, ctx.part_size, 161 | ctx.sample_per_part, ctx.trans_std) 162 | ctx.save_for_backward(input, rois, offset, output_count) 163 | return output 164 | 165 | @staticmethod 166 | @once_differentiable 167 | def backward(ctx, grad_output): 168 | input, rois, offset, output_count = ctx.saved_tensors 169 | grad_input, grad_offset = \ 170 | _backend.dcn_v2_psroi_pooling_backward(grad_output, 171 | input, 172 | rois, 173 | offset, 174 | output_count, 175 | ctx.no_trans, 176 | ctx.spatial_scale, 177 | ctx.output_dim, 178 | ctx.group_size, 179 | ctx.pooled_size, 180 | ctx.part_size, 181 | ctx.sample_per_part, 182 | ctx.trans_std) 183 | 184 | return grad_input, None, grad_offset, \ 185 | None, None, None, None, None, None, None, None 186 | 187 | 188 | dcn_v2_pooling = _DCNv2Pooling.apply 189 | 190 | 191 | class DCNv2Pooling(nn.Module): 192 | def __init__(self, spatial_scale, pooled_size, output_dim, no_trans, group_size=1, 193 | part_size=None, sample_per_part=4, trans_std=.0): 194 | super(DCNv2Pooling, self).__init__() 195 | self.spatial_scale = spatial_scale 196 | self.pooled_size = pooled_size 197 | self.output_dim = output_dim 198 | self.no_trans = no_trans 199 | self.group_size = group_size 200 | self.part_size = pooled_size if part_size is None else part_size 201 | self.sample_per_part = sample_per_part 202 | self.trans_std = trans_std 203 | 204 | def forward(self, input, rois, offset): 205 | assert input.shape[1] == self.output_dim 206 | if self.no_trans: 207 | offset = input.new() 208 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size, 209 | self.output_dim, self.no_trans, self.group_size, self.part_size, 210 | self.sample_per_part, self.trans_std) 211 | 212 | 213 | class DCNPooling(DCNv2Pooling): 214 | def __init__(self, spatial_scale, pooled_size, output_dim, no_trans, group_size=1, 215 | part_size=None, sample_per_part=4, trans_std=.0, deform_fc_dim=1024): 216 | super(DCNPooling, self).__init__(spatial_scale, pooled_size, output_dim, no_trans, 217 | group_size, part_size, sample_per_part, trans_std) 218 | 219 | self.deform_fc_dim = deform_fc_dim 220 | 221 | if not no_trans: 222 | self.offset_mask_fc = nn.Sequential( 223 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, 224 | self.deform_fc_dim), nn.ReLU(inplace=True), 225 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(inplace=True), 226 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3)) 227 | self.offset_mask_fc[4].weight.data.zero_() 228 | self.offset_mask_fc[4].bias.data.zero_() 229 | 230 | def forward(self, input, rois): 231 | offset = input.new() 232 | 233 | if not self.no_trans: 234 | 235 | # do roi_align first 236 | n = rois.shape[0] 237 | roi = dcn_v2_pooling( 238 | input, 239 | rois, 240 | offset, 241 | self.spatial_scale, 242 | self.pooled_size, 243 | self.output_dim, 244 | True, # no trans 245 | self.group_size, 246 | self.part_size, 247 | self.sample_per_part, 248 | self.trans_std) 249 | 250 | # build mask and offset 251 | offset_mask = self.offset_mask_fc(roi.view(n, -1)) 252 | offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size) 253 | o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) 254 | offset = torch.cat((o1, o2), dim=1) 255 | mask = torch.sigmoid(mask) 256 | 257 | # do pooling with offset and mask 258 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size, 259 | self.output_dim, 
self.no_trans, self.group_size, self.part_size, 260 | self.sample_per_part, self.trans_std) * mask 261 | # only roi_align 262 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size, 263 | self.output_dim, self.no_trans, self.group_size, self.part_size, 264 | self.sample_per_part, self.trans_std) 265 | -------------------------------------------------------------------------------- /models/DCNv2/make.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | # You may need to modify the following paths before compiling. 4 | 5 | # CUDA_HOME=/usr/local/cuda-10.0 \ 6 | # CUDNN_INCLUDE_DIR=/usr/local/cuda-10.0/include \ 7 | # CUDNN_LIB_DIR=/usr/local/cuda-10.0/lib64 \ 8 | python setup.py build develop 9 | -------------------------------------------------------------------------------- /models/DCNv2/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import glob 5 | 6 | import torch 7 | 8 | from torch.utils.cpp_extension import CUDA_HOME 9 | from torch.utils.cpp_extension import CppExtension 10 | from torch.utils.cpp_extension import CUDAExtension 11 | 12 | from setuptools import find_packages 13 | from setuptools import setup 14 | 15 | requirements = ["torch", "torchvision"] 16 | 17 | 18 | def get_extensions(): 19 | this_dir = os.path.dirname(os.path.abspath(__file__)) 20 | extensions_dir = os.path.join(this_dir, "src") 21 | 22 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 23 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 24 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 25 | 26 | sources = main_file + source_cpu 27 | extension = CppExtension 28 | extra_compile_args = {"cxx": []} 29 | define_macros = [] 30 | 31 | if torch.cuda.is_available() and CUDA_HOME is not None: 32 | extension = CUDAExtension 33 | sources += source_cuda 34 | define_macros += [("WITH_CUDA", None)] 35 | extra_compile_args["nvcc"] = [ 36 | "-DCUDA_HAS_FP16=1", 37 | "-D__CUDA_NO_HALF_OPERATORS__", 38 | "-D__CUDA_NO_HALF_CONVERSIONS__", 39 | "-D__CUDA_NO_HALF2_OPERATORS__", 40 | ] 41 | else: 42 | raise NotImplementedError('Cuda is not availabel') 43 | 44 | sources = [os.path.join(extensions_dir, s) for s in sources] 45 | include_dirs = [extensions_dir] 46 | ext_modules = [ 47 | extension( 48 | "_ext", 49 | sources, 50 | include_dirs=include_dirs, 51 | define_macros=define_macros, 52 | extra_compile_args=extra_compile_args, 53 | ) 54 | ] 55 | return ext_modules 56 | 57 | 58 | setup( 59 | name="DCNv2", 60 | version="0.1", 61 | author="charlesshang", 62 | url="https://github.com/charlesshang/DCNv2", 63 | description="deformable convolutional networks", 64 | packages=find_packages(exclude=( 65 | "configs", 66 | "tests", 67 | )), 68 | # install_requires=requirements, 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) -------------------------------------------------------------------------------- /models/DCNv2/src/cpu/dcn_v2_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | at::Tensor 8 | dcn_v2_cpu_forward(const at::Tensor &input, 9 | const at::Tensor &weight, 10 | const at::Tensor &bias, 11 | const at::Tensor &offset, 12 | const at::Tensor &mask, 13 | const int kernel_h, 14 | const int kernel_w, 15 | const int stride_h, 16 | const int stride_w, 17 | 
const int pad_h, 18 | const int pad_w, 19 | const int dilation_h, 20 | const int dilation_w, 21 | const int deformable_group) 22 | { 23 | AT_ERROR("Not implement on cpu"); 24 | } 25 | 26 | std::vector 27 | dcn_v2_cpu_backward(const at::Tensor &input, 28 | const at::Tensor &weight, 29 | const at::Tensor &bias, 30 | const at::Tensor &offset, 31 | const at::Tensor &mask, 32 | const at::Tensor &grad_output, 33 | int kernel_h, int kernel_w, 34 | int stride_h, int stride_w, 35 | int pad_h, int pad_w, 36 | int dilation_h, int dilation_w, 37 | int deformable_group) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | std::tuple 43 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 44 | const at::Tensor &bbox, 45 | const at::Tensor &trans, 46 | const int no_trans, 47 | const float spatial_scale, 48 | const int output_dim, 49 | const int group_size, 50 | const int pooled_size, 51 | const int part_size, 52 | const int sample_per_part, 53 | const float trans_std) 54 | { 55 | AT_ERROR("Not implement on cpu"); 56 | } 57 | 58 | std::tuple 59 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 60 | const at::Tensor &input, 61 | const at::Tensor &bbox, 62 | const at::Tensor &trans, 63 | const at::Tensor &top_count, 64 | const int no_trans, 65 | const float spatial_scale, 66 | const int output_dim, 67 | const int group_size, 68 | const int pooled_size, 69 | const int part_size, 70 | const int sample_per_part, 71 | const float trans_std) 72 | { 73 | AT_ERROR("Not implement on cpu"); 74 | } -------------------------------------------------------------------------------- /models/DCNv2/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cpu_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cpu_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /models/DCNv2/src/cuda/dcn_v2_cuda.cu: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include "cuda/dcn_v2_im2col_cuda.h" 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | extern THCState *state; 12 | 13 | // author: Charles Shang 14 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 15 | 16 | // [batch gemm] 17 | // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu 18 | 19 | __global__ void createBatchGemmBuffer(const float **input_b, float **output_b, 20 | float **columns_b, const float **ones_b, 21 | const float **weight_b, const float **bias_b, 22 | float *input, float *output, 23 | float *columns, float *ones, 24 | float *weight, float *bias, 25 | const int input_stride, const int output_stride, 26 | const int columns_stride, const int ones_stride, 27 | const int num_batches) 28 | { 29 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 30 | if (idx < num_batches) 31 | { 32 | input_b[idx] = input + idx * input_stride; 33 | output_b[idx] = output + idx * output_stride; 34 | columns_b[idx] = columns + idx * columns_stride; 35 | ones_b[idx] = ones + idx * ones_stride; 36 | // share weights and bias within a Mini-Batch 37 | weight_b[idx] = weight; 38 | bias_b[idx] = bias; 39 | } 40 | } 41 | 42 | at::Tensor 43 | dcn_v2_cuda_forward(const at::Tensor &input, 44 | const at::Tensor &weight, 45 | const at::Tensor &bias, 46 | const at::Tensor &offset, 47 | const at::Tensor &mask, 48 | const int kernel_h, 49 | const int kernel_w, 50 | const int stride_h, 51 | const int stride_w, 52 | const int pad_h, 53 | const int pad_w, 54 | const int dilation_h, 55 | const int dilation_w, 56 | const int deformable_group) 57 | { 58 | using scalar_t = float; 59 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); 60 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 61 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); 62 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); 63 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); 64 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); 65 | 66 | const int batch = input.size(0); 67 | const int channels = input.size(1); 68 | const int height = input.size(2); 69 | const int width = input.size(3); 70 | 71 | const int channels_out = weight.size(0); 72 | const int channels_kernel = weight.size(1); 73 | const int kernel_h_ = weight.size(2); 74 | const int kernel_w_ = weight.size(3); 75 | 76 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); 77 | // printf("Channels: %d %d\n", channels, channels_kernel); 78 | // printf("Channels: %d %d\n", channels_out, channels_kernel); 79 | 80 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 81 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 82 | 83 | AT_ASSERTM(channels == channels_kernel, 84 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 85 | 86 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 87 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 88 | 89 | auto ones = at::ones({batch, height_out, width_out}, input.options()); 90 | auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 91 | auto 
output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 92 | 93 | // prepare for batch-wise computing, which is significantly faster than instance-wise computing 94 | // when batch size is large. 95 | // launch batch threads 96 | int matrices_size = batch * sizeof(float *); 97 | auto input_b = static_cast(THCudaMalloc(state, matrices_size)); 98 | auto output_b = static_cast(THCudaMalloc(state, matrices_size)); 99 | auto columns_b = static_cast(THCudaMalloc(state, matrices_size)); 100 | auto ones_b = static_cast(THCudaMalloc(state, matrices_size)); 101 | auto weight_b = static_cast(THCudaMalloc(state, matrices_size)); 102 | auto bias_b = static_cast(THCudaMalloc(state, matrices_size)); 103 | 104 | const int block = 128; 105 | const int grid = (batch + block - 1) / block; 106 | 107 | createBatchGemmBuffer<<>>( 108 | input_b, output_b, 109 | columns_b, ones_b, 110 | weight_b, bias_b, 111 | input.data(), 112 | output.data(), 113 | columns.data(), 114 | ones.data(), 115 | weight.data(), 116 | bias.data(), 117 | channels * width * height, 118 | channels_out * width_out * height_out, 119 | channels * kernel_h * kernel_w * height_out * width_out, 120 | height_out * width_out, 121 | batch); 122 | 123 | long m_ = channels_out; 124 | long n_ = height_out * width_out; 125 | long k_ = 1; 126 | THCudaBlas_SgemmBatched(state, 127 | 't', 128 | 'n', 129 | n_, 130 | m_, 131 | k_, 132 | 1.0f, 133 | ones_b, k_, 134 | bias_b, k_, 135 | 0.0f, 136 | output_b, n_, 137 | batch); 138 | 139 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), 140 | input.data(), 141 | offset.data(), 142 | mask.data(), 143 | batch, channels, height, width, 144 | height_out, width_out, kernel_h, kernel_w, 145 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, 146 | deformable_group, 147 | columns.data()); 148 | 149 | long m = channels_out; 150 | long n = height_out * width_out; 151 | long k = channels * kernel_h * kernel_w; 152 | THCudaBlas_SgemmBatched(state, 153 | 'n', 154 | 'n', 155 | n, 156 | m, 157 | k, 158 | 1.0f, 159 | (const float **)columns_b, n, 160 | weight_b, k, 161 | 1.0f, 162 | output_b, n, 163 | batch); 164 | 165 | THCudaFree(state, input_b); 166 | THCudaFree(state, output_b); 167 | THCudaFree(state, columns_b); 168 | THCudaFree(state, ones_b); 169 | THCudaFree(state, weight_b); 170 | THCudaFree(state, bias_b); 171 | return output; 172 | } 173 | 174 | __global__ void createBatchGemmBufferBackward( 175 | float **grad_output_b, 176 | float **columns_b, 177 | float **ones_b, 178 | float **weight_b, 179 | float **grad_weight_b, 180 | float **grad_bias_b, 181 | float *grad_output, 182 | float *columns, 183 | float *ones, 184 | float *weight, 185 | float *grad_weight, 186 | float *grad_bias, 187 | const int grad_output_stride, 188 | const int columns_stride, 189 | const int ones_stride, 190 | const int num_batches) 191 | { 192 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 193 | if (idx < num_batches) 194 | { 195 | grad_output_b[idx] = grad_output + idx * grad_output_stride; 196 | columns_b[idx] = columns + idx * columns_stride; 197 | ones_b[idx] = ones + idx * ones_stride; 198 | 199 | // share weights and bias within a Mini-Batch 200 | weight_b[idx] = weight; 201 | grad_weight_b[idx] = grad_weight; 202 | grad_bias_b[idx] = grad_bias; 203 | } 204 | } 205 | 206 | std::vector dcn_v2_cuda_backward(const at::Tensor &input, 207 | const at::Tensor &weight, 208 | const at::Tensor &bias, 209 | const at::Tensor &offset, 210 | const at::Tensor &mask, 211 | const at::Tensor 
&grad_output, 212 | int kernel_h, int kernel_w, 213 | int stride_h, int stride_w, 214 | int pad_h, int pad_w, 215 | int dilation_h, int dilation_w, 216 | int deformable_group) 217 | { 218 | 219 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); 220 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); 221 | 222 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 223 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); 224 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); 225 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); 226 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); 227 | 228 | const int batch = input.size(0); 229 | const int channels = input.size(1); 230 | const int height = input.size(2); 231 | const int width = input.size(3); 232 | 233 | const int channels_out = weight.size(0); 234 | const int channels_kernel = weight.size(1); 235 | const int kernel_h_ = weight.size(2); 236 | const int kernel_w_ = weight.size(3); 237 | 238 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, 239 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); 240 | 241 | AT_ASSERTM(channels == channels_kernel, 242 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); 243 | 244 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; 245 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; 246 | 247 | auto ones = at::ones({height_out, width_out}, input.options()); 248 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); 249 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); 250 | 251 | auto grad_input = at::zeros_like(input); 252 | auto grad_weight = at::zeros_like(weight); 253 | auto grad_bias = at::zeros_like(bias); 254 | auto grad_offset = at::zeros_like(offset); 255 | auto grad_mask = at::zeros_like(mask); 256 | 257 | using scalar_t = float; 258 | 259 | for (int b = 0; b < batch; b++) 260 | { 261 | auto input_n = input.select(0, b); 262 | auto offset_n = offset.select(0, b); 263 | auto mask_n = mask.select(0, b); 264 | auto grad_output_n = grad_output.select(0, b); 265 | auto grad_input_n = grad_input.select(0, b); 266 | auto grad_offset_n = grad_offset.select(0, b); 267 | auto grad_mask_n = grad_mask.select(0, b); 268 | 269 | long m = channels * kernel_h * kernel_w; 270 | long n = height_out * width_out; 271 | long k = channels_out; 272 | 273 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, 274 | grad_output_n.data(), n, 275 | weight.data(), m, 0.0f, 276 | columns.data(), n); 277 | 278 | // gradient w.r.t. input coordinate data 279 | modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), 280 | columns.data(), 281 | input_n.data(), 282 | offset_n.data(), 283 | mask_n.data(), 284 | 1, channels, height, width, 285 | height_out, width_out, kernel_h, kernel_w, 286 | pad_h, pad_w, stride_h, stride_w, 287 | dilation_h, dilation_w, deformable_group, 288 | grad_offset_n.data(), 289 | grad_mask_n.data()); 290 | // gradient w.r.t. 
input data 291 | modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), 292 | columns.data(), 293 | offset_n.data(), 294 | mask_n.data(), 295 | 1, channels, height, width, 296 | height_out, width_out, kernel_h, kernel_w, 297 | pad_h, pad_w, stride_h, stride_w, 298 | dilation_h, dilation_w, deformable_group, 299 | grad_input_n.data()); 300 | 301 | // gradient w.r.t. weight, dWeight should accumulate across the batch and group 302 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), 303 | input_n.data(), 304 | offset_n.data(), 305 | mask_n.data(), 306 | 1, channels, height, width, 307 | height_out, width_out, kernel_h, kernel_w, 308 | pad_h, pad_w, stride_h, stride_w, 309 | dilation_h, dilation_w, deformable_group, 310 | columns.data()); 311 | 312 | long m_ = channels_out; 313 | long n_ = channels * kernel_h * kernel_w; 314 | long k_ = height_out * width_out; 315 | 316 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, 317 | columns.data(), k_, 318 | grad_output_n.data(), k_, 1.0f, 319 | grad_weight.data(), n_); 320 | 321 | // gradient w.r.t. bias 322 | // long m_ = channels_out; 323 | // long k__ = height_out * width_out; 324 | THCudaBlas_Sgemv(state, 325 | 't', 326 | k_, m_, 1.0f, 327 | grad_output_n.data(), k_, 328 | ones.data(), 1, 1.0f, 329 | grad_bias.data(), 1); 330 | } 331 | 332 | return { 333 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias 334 | }; 335 | } -------------------------------------------------------------------------------- /models/DCNv2/src/cuda/dcn_v2_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, this 28 | * list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 51 | * 52 | * Copyright (c) 2018 Microsoft 53 | * Licensed under The MIT License [see LICENSE for details] 54 | * \file modulated_deformable_im2col.h 55 | * \brief Function definitions of converting an image to 56 | * column matrix based on kernel, padding, dilation, and offset. 57 | * These functions are mainly used in deformable convolution operators. 58 | * \ref: https://arxiv.org/abs/1811.11168 59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 60 | */ 61 | 62 | /***************** Adapted by Charles Shang *********************/ 63 | 64 | #ifndef DCN_V2_IM2COL_CUDA 65 | #define DCN_V2_IM2COL_CUDA 66 | 67 | #ifdef __cplusplus 68 | extern "C" 69 | { 70 | #endif 71 | 72 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 73 | const float *data_im, const float *data_offset, const float *data_mask, 74 | const int batch_size, const int channels, const int height_im, const int width_im, 75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 77 | const int dilation_h, const int dilation_w, 78 | const int deformable_group, float *data_col); 79 | 80 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 81 | const float *data_col, const float *data_offset, const float *data_mask, 82 | const int batch_size, const int channels, const int height_im, const int width_im, 83 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 85 | const int dilation_h, const int dilation_w, 86 | const int deformable_group, float *grad_im); 87 | 88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 90 | const int batch_size, const int channels, const int height_im, const int width_im, 91 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 93 | const int dilation_h, const int dilation_w, 94 | const int deformable_group, 95 | float *grad_offset, float *grad_mask); 96 | 97 | #ifdef __cplusplus 98 | } 99 | #endif 100 | 101 | #endif -------------------------------------------------------------------------------- /models/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2017 Microsoft 3 | * Licensed under The MIT License [see LICENSE for details] 4 | * \file deformable_psroi_pooling.cu 5 | * \brief 6 | * \author Yi Li, Guodong Zhang, Jifeng Dai 7 | */ 8 | /***************** Adapted by Charles Shang *********************/ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #define CUDA_KERNEL_LOOP(i, n) \ 23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 24 | i < (n); \ 25 | i += blockDim.x * gridDim.x) 26 | 27 | const int CUDA_NUM_THREADS = 1024; 28 | inline int GET_BLOCKS(const int N) 29 | { 30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 31 | } 32 | 33 | template 34 | __device__ T bilinear_interp( 35 | const T *data, 36 | const T x, 37 | const T y, 38 | const int width, 39 | const int height) 40 | { 41 | int x1 = floor(x); 42 | int x2 = ceil(x); 43 | int y1 = floor(y); 44 | int y2 = ceil(y); 45 | T dist_x = static_cast(x - x1); 46 | T dist_y = static_cast(y - y1); 47 | T value11 = data[y1 * width + x1]; 48 | T value12 = data[y2 * width + x1]; 49 | T value21 = data[y1 * width + x2]; 50 | T value22 = data[y2 * width + x2]; 51 | T value = (1 - dist_x) * (1 - dist_y) * value11 + 52 | (1 - dist_x) * dist_y * value12 + 53 | dist_x * (1 - dist_y) * value21 + 54 | dist_x * dist_y * value22; 55 | return value; 56 | } 57 | 58 | template 59 | __global__ void DeformablePSROIPoolForwardKernel( 60 | const int count, 61 | const T *bottom_data, 62 | const T spatial_scale, 63 | const int channels, 64 | const int height, const int width, 65 | const int pooled_height, const int pooled_width, 66 | const T *bottom_rois, const T *bottom_trans, 67 | const int no_trans, 68 | const T trans_std, 69 | const int sample_per_part, 70 | const int output_dim, 71 | const int group_size, 72 | const int part_size, 73 | const int num_classes, 74 | const int channels_each_class, 75 | T *top_data, 76 | T *top_count) 77 | { 78 | CUDA_KERNEL_LOOP(index, count) 79 | { 80 | // The output is in order (n, ctop, ph, pw) 81 | int pw = index % pooled_width; 82 | int ph = (index / pooled_width) % pooled_height; 83 | int ctop = (index / pooled_width / pooled_height) % output_dim; 84 | int n = index / pooled_width / pooled_height / output_dim; 85 | 86 | // [start, end) interval for spatial sampling 87 | const T *offset_bottom_rois = bottom_rois + n * 5; 88 | int roi_batch_ind = offset_bottom_rois[0]; 89 | T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 90 | T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 91 | T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 92 | T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; 93 | 94 | // Force too small ROIs to be 1x1 95 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 96 | T roi_height = max(roi_end_h - roi_start_h, 0.1); 97 | 98 | // Compute w and h at bottom 99 | T bin_size_h = roi_height / static_cast(pooled_height); 100 | T bin_size_w = roi_width / static_cast(pooled_width); 101 | 102 | T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); 103 | T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); 104 | 105 | int part_h = floor(static_cast(ph) / pooled_height * part_size); 106 | int part_w = floor(static_cast(pw) / pooled_width * part_size); 107 | int class_id = ctop / channels_each_class; 108 | T trans_x = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 109 | T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 110 | 111 | T wstart = static_cast(pw) * bin_size_w + roi_start_w; 112 | wstart += trans_x * roi_width; 113 | T hstart = static_cast(ph) * bin_size_h + roi_start_h; 114 | hstart += trans_y * roi_height; 115 | 116 | T sum = 0; 117 | int count = 0; 118 | int gw = floor(static_cast(pw) * group_size / pooled_width); 119 | int gh = floor(static_cast(ph) * group_size / pooled_height); 120 | gw = min(max(gw, 0), group_size - 1); 121 | gh = min(max(gh, 0), group_size - 1); 122 | 123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; 124 | for (int ih = 0; ih < sample_per_part; ih++) 125 | { 126 | for (int iw = 0; iw < sample_per_part; iw++) 127 | { 128 | T w = wstart + iw * sub_bin_size_w; 129 | T h = hstart + ih * sub_bin_size_h; 130 | // bilinear interpolation 131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 132 | { 133 | continue; 134 | } 135 | w = min(max(w, 0.), width - 1.); 136 | h = min(max(h, 0.), height - 1.); 137 | int c = (ctop * group_size + gh) * group_size + gw; 138 | T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); 139 | sum += val; 140 | count++; 141 | } 142 | } 143 | top_data[index] = count == 0 ? static_cast(0) : sum / count; 144 | top_count[index] = count; 145 | } 146 | } 147 | 148 | template 149 | __global__ void DeformablePSROIPoolBackwardAccKernel( 150 | const int count, 151 | const T *top_diff, 152 | const T *top_count, 153 | const int num_rois, 154 | const T spatial_scale, 155 | const int channels, 156 | const int height, const int width, 157 | const int pooled_height, const int pooled_width, 158 | const int output_dim, 159 | T *bottom_data_diff, T *bottom_trans_diff, 160 | const T *bottom_data, 161 | const T *bottom_rois, 162 | const T *bottom_trans, 163 | const int no_trans, 164 | const T trans_std, 165 | const int sample_per_part, 166 | const int group_size, 167 | const int part_size, 168 | const int num_classes, 169 | const int channels_each_class) 170 | { 171 | CUDA_KERNEL_LOOP(index, count) 172 | { 173 | // The output is in order (n, ctop, ph, pw) 174 | int pw = index % pooled_width; 175 | int ph = (index / pooled_width) % pooled_height; 176 | int ctop = (index / pooled_width / pooled_height) % output_dim; 177 | int n = index / pooled_width / pooled_height / output_dim; 178 | 179 | // [start, end) interval for spatial sampling 180 | const T *offset_bottom_rois = bottom_rois + n * 5; 181 | int roi_batch_ind = offset_bottom_rois[0]; 182 | T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; 183 | T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; 184 | T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; 185 | T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5; 186 | 187 | // Force too small ROIs to be 1x1 188 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 189 | T roi_height = max(roi_end_h - roi_start_h, 0.1); 190 | 191 | // Compute w and h at bottom 192 | T bin_size_h = roi_height / static_cast(pooled_height); 193 | T bin_size_w = roi_width / static_cast(pooled_width); 194 | 195 | T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); 196 | T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); 197 | 198 | int part_h = floor(static_cast(ph) / pooled_height * part_size); 199 | int part_w = floor(static_cast(pw) / pooled_width * part_size); 200 | int class_id = ctop / channels_each_class; 201 | T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; 202 | T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; 203 | 204 | T wstart = static_cast(pw) * bin_size_w + roi_start_w; 205 | wstart += trans_x * roi_width; 206 | T hstart = static_cast(ph) * bin_size_h + roi_start_h; 207 | hstart += trans_y * roi_height; 208 | 209 | if (top_count[index] <= 0) 210 | { 211 | continue; 212 | } 213 | T diff_val = top_diff[index] / top_count[index]; 214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; 215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; 216 | int gw = floor(static_cast(pw) * group_size / pooled_width); 217 | int gh = floor(static_cast(ph) * group_size / pooled_height); 218 | gw = min(max(gw, 0), group_size - 1); 219 | gh = min(max(gh, 0), group_size - 1); 220 | 221 | for (int ih = 0; ih < sample_per_part; ih++) 222 | { 223 | for (int iw = 0; iw < sample_per_part; iw++) 224 | { 225 | T w = wstart + iw * sub_bin_size_w; 226 | T h = hstart + ih * sub_bin_size_h; 227 | // bilinear interpolation 228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) 229 | { 230 | continue; 231 | } 232 | w = min(max(w, 0.), width - 1.); 233 | h = min(max(h, 0.), height - 1.); 234 | int c = (ctop * group_size + gh) * group_size + gw; 235 | // backward on feature 236 | int x0 = floor(w); 237 | int x1 = ceil(w); 238 | int y0 = floor(h); 239 | int y1 = ceil(h); 240 | T dist_x = w - x0, dist_y = h - y0; 241 | T q00 = (1 - dist_x) * (1 - dist_y); 242 | T q01 = (1 - dist_x) * dist_y; 243 | T q10 = dist_x * (1 - dist_y); 244 | T q11 = dist_x * dist_y; 245 | int bottom_index_base = c * height * width; 246 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); 247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); 248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); 249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); 250 | 251 | if (no_trans) 252 | { 253 | continue; 254 | } 255 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; 256 | T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; 257 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; 258 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; 259 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; 260 | diff_x *= roi_width; 261 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std 
* diff_val; 262 | diff_y *= roi_height; 263 | 264 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); 265 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); 266 | } 267 | } 268 | } 269 | } 270 | 271 | std::tuple 272 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 273 | const at::Tensor &bbox, 274 | const at::Tensor &trans, 275 | const int no_trans, 276 | const float spatial_scale, 277 | const int output_dim, 278 | const int group_size, 279 | const int pooled_size, 280 | const int part_size, 281 | const int sample_per_part, 282 | const float trans_std) 283 | { 284 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 285 | AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor"); 286 | AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); 287 | 288 | const int batch = input.size(0); 289 | const int channels = input.size(1); 290 | const int height = input.size(2); 291 | const int width = input.size(3); 292 | const int channels_trans = no_trans ? 2 : trans.size(1); 293 | const int num_bbox = bbox.size(0); 294 | 295 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); 296 | auto pooled_height = pooled_size; 297 | auto pooled_width = pooled_size; 298 | 299 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); 300 | long out_size = num_bbox * output_dim * pooled_height * pooled_width; 301 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); 302 | 303 | const int num_classes = no_trans ? 1 : channels_trans / 2; 304 | const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; 305 | 306 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 307 | 308 | if (out.numel() == 0) 309 | { 310 | THCudaCheck(cudaGetLastError()); 311 | return std::make_tuple(out, top_count); 312 | } 313 | 314 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); 315 | dim3 block(512); 316 | 317 | AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { 318 | DeformablePSROIPoolForwardKernel<<>>( 319 | out_size, 320 | input.contiguous().data(), 321 | spatial_scale, 322 | channels, 323 | height, width, 324 | pooled_height, 325 | pooled_width, 326 | bbox.contiguous().data(), 327 | trans.contiguous().data(), 328 | no_trans, 329 | trans_std, 330 | sample_per_part, 331 | output_dim, 332 | group_size, 333 | part_size, 334 | num_classes, 335 | channels_each_class, 336 | out.data(), 337 | top_count.data()); 338 | }); 339 | THCudaCheck(cudaGetLastError()); 340 | return std::make_tuple(out, top_count); 341 | } 342 | 343 | std::tuple 344 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 345 | const at::Tensor &input, 346 | const at::Tensor &bbox, 347 | const at::Tensor &trans, 348 | const at::Tensor &top_count, 349 | const int no_trans, 350 | const float spatial_scale, 351 | const int output_dim, 352 | const int group_size, 353 | const int pooled_size, 354 | const int part_size, 355 | const int sample_per_part, 356 | const float trans_std) 357 | { 358 | AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor"); 359 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 360 | AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor"); 361 | AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); 362 | AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor"); 363 | 364 | const int batch = input.size(0); 365 | const int channels = input.size(1); 366 | const int height = input.size(2); 367 | const int width = input.size(3); 368 | const int channels_trans = no_trans ? 2 : trans.size(1); 369 | const int num_bbox = bbox.size(0); 370 | 371 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); 372 | auto pooled_height = pooled_size; 373 | auto pooled_width = pooled_size; 374 | long out_size = num_bbox * output_dim * pooled_height * pooled_width; 375 | const int num_classes = no_trans ? 1 : channels_trans / 2; 376 | const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; 377 | 378 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); 379 | auto trans_grad = at::zeros_like(trans); 380 | 381 | if (input_grad.numel() == 0) 382 | { 383 | THCudaCheck(cudaGetLastError()); 384 | return std::make_tuple(input_grad, trans_grad); 385 | } 386 | 387 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); 388 | dim3 block(512); 389 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 390 | 391 | AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { 392 | DeformablePSROIPoolBackwardAccKernel<<>>( 393 | out_size, 394 | out_grad.contiguous().data(), 395 | top_count.contiguous().data(), 396 | num_bbox, 397 | spatial_scale, 398 | channels, 399 | height, 400 | width, 401 | pooled_height, 402 | pooled_width, 403 | output_dim, 404 | input_grad.contiguous().data(), 405 | trans_grad.contiguous().data(), 406 | input.contiguous().data(), 407 | bbox.contiguous().data(), 408 | trans.contiguous().data(), 409 | no_trans, 410 | trans_std, 411 | sample_per_part, 412 | group_size, 413 | part_size, 414 | num_classes, 415 | channels_each_class); 416 | }); 417 | THCudaCheck(cudaGetLastError()); 418 | return std::make_tuple(input_grad, trans_grad); 419 | } -------------------------------------------------------------------------------- /models/DCNv2/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | dcn_v2_cuda_forward(const at::Tensor &input, 6 | const at::Tensor &weight, 7 | const at::Tensor &bias, 8 | const at::Tensor &offset, 9 | const at::Tensor &mask, 10 | const int kernel_h, 11 | const int kernel_w, 12 | const int stride_h, 13 | const int stride_w, 14 | const int pad_h, 15 | const int pad_w, 16 | const int dilation_h, 17 | const int dilation_w, 18 | const int deformable_group); 19 | 20 | std::vector 21 | dcn_v2_cuda_backward(const at::Tensor &input, 22 | const at::Tensor &weight, 23 | const at::Tensor &bias, 24 | const at::Tensor &offset, 25 | const at::Tensor &mask, 26 | const at::Tensor &grad_output, 27 | int kernel_h, int kernel_w, 28 | int stride_h, int stride_w, 29 | int pad_h, int pad_w, 30 | int dilation_h, int dilation_w, 31 | int deformable_group); 32 | 33 | 34 | std::tuple 35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, 36 | const at::Tensor &bbox, 37 | const at::Tensor &trans, 38 | const int no_trans, 39 | const float spatial_scale, 40 | const int output_dim, 41 | const int group_size, 42 | const int pooled_size, 43 | const int part_size, 44 | const int sample_per_part, 45 | const float trans_std); 46 | 47 | std::tuple 48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, 49 | const at::Tensor &input, 50 | const at::Tensor &bbox, 51 | const at::Tensor &trans, 52 | const at::Tensor &top_count, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); -------------------------------------------------------------------------------- /models/DCNv2/src/dcn_v2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | at::Tensor 10 | dcn_v2_forward(const at::Tensor &input, 11 | const at::Tensor &weight, 12 | const at::Tensor &bias, 13 
| const at::Tensor &offset, 14 | const at::Tensor &mask, 15 | const int kernel_h, 16 | const int kernel_w, 17 | const int stride_h, 18 | const int stride_w, 19 | const int pad_h, 20 | const int pad_w, 21 | const int dilation_h, 22 | const int dilation_w, 23 | const int deformable_group) 24 | { 25 | if (input.type().is_cuda()) 26 | { 27 | #ifdef WITH_CUDA 28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask, 29 | kernel_h, kernel_w, 30 | stride_h, stride_w, 31 | pad_h, pad_w, 32 | dilation_h, dilation_w, 33 | deformable_group); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | dcn_v2_backward(const at::Tensor &input, 43 | const at::Tensor &weight, 44 | const at::Tensor &bias, 45 | const at::Tensor &offset, 46 | const at::Tensor &mask, 47 | const at::Tensor &grad_output, 48 | int kernel_h, int kernel_w, 49 | int stride_h, int stride_w, 50 | int pad_h, int pad_w, 51 | int dilation_h, int dilation_w, 52 | int deformable_group) 53 | { 54 | if (input.type().is_cuda()) 55 | { 56 | #ifdef WITH_CUDA 57 | return dcn_v2_cuda_backward(input, 58 | weight, 59 | bias, 60 | offset, 61 | mask, 62 | grad_output, 63 | kernel_h, kernel_w, 64 | stride_h, stride_w, 65 | pad_h, pad_w, 66 | dilation_h, dilation_w, 67 | deformable_group); 68 | #else 69 | AT_ERROR("Not compiled with GPU support"); 70 | #endif 71 | } 72 | AT_ERROR("Not implemented on the CPU"); 73 | } 74 | 75 | std::tuple 76 | dcn_v2_psroi_pooling_forward(const at::Tensor &input, 77 | const at::Tensor &bbox, 78 | const at::Tensor &trans, 79 | const int no_trans, 80 | const float spatial_scale, 81 | const int output_dim, 82 | const int group_size, 83 | const int pooled_size, 84 | const int part_size, 85 | const int sample_per_part, 86 | const float trans_std) 87 | { 88 | if (input.type().is_cuda()) 89 | { 90 | #ifdef WITH_CUDA 91 | return dcn_v2_psroi_pooling_cuda_forward(input, 92 | bbox, 93 | trans, 94 | no_trans, 95 | spatial_scale, 96 | output_dim, 97 | group_size, 98 | pooled_size, 99 | part_size, 100 | sample_per_part, 101 | trans_std); 102 | #else 103 | AT_ERROR("Not compiled with GPU support"); 104 | #endif 105 | } 106 | AT_ERROR("Not implemented on the CPU"); 107 | } 108 | 109 | std::tuple 110 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, 111 | const at::Tensor &input, 112 | const at::Tensor &bbox, 113 | const at::Tensor &trans, 114 | const at::Tensor &top_count, 115 | const int no_trans, 116 | const float spatial_scale, 117 | const int output_dim, 118 | const int group_size, 119 | const int pooled_size, 120 | const int part_size, 121 | const int sample_per_part, 122 | const float trans_std) 123 | { 124 | if (input.type().is_cuda()) 125 | { 126 | #ifdef WITH_CUDA 127 | return dcn_v2_psroi_pooling_cuda_backward(out_grad, 128 | input, 129 | bbox, 130 | trans, 131 | top_count, 132 | no_trans, 133 | spatial_scale, 134 | output_dim, 135 | group_size, 136 | pooled_size, 137 | part_size, 138 | sample_per_part, 139 | trans_std); 140 | #else 141 | AT_ERROR("Not compiled with GPU support"); 142 | #endif 143 | } 144 | AT_ERROR("Not implemented on the CPU"); 145 | } -------------------------------------------------------------------------------- /models/DCNv2/src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "dcn_v2.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); 6 | 
m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); 7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); 8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); 9 | } 10 | -------------------------------------------------------------------------------- /models/ResBlock.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GlobalGenerator(nn.Module): 5 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 6 | padding_type='reflect'): 7 | assert (n_blocks >= 0) 8 | super(GlobalGenerator, self).__init__() 9 | activation = nn.ReLU(True) 10 | 11 | model = [] 12 | 13 | ### resnet blocks 14 | mult = 2 ** n_downsampling 15 | for i in range(n_blocks): 16 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation)] 17 | 18 | ### upsample 19 | for i in range(n_downsampling): 20 | mult = 2 ** (n_downsampling - i) 21 | # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, 22 | # output_padding=1), 23 | # norm_layer(int(ngf * mult / 2)), activation] 24 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, 25 | output_padding=1), activation] 26 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 27 | self.model = nn.Sequential(*model) 28 | 29 | def forward(self, input): 30 | return self.model(input) 31 | 32 | 33 | class ResnetBlock(nn.Module): 34 | def __init__(self, inchannel, padding_type='reflect', activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False): 35 | super(ResnetBlock, self).__init__() 36 | self.conv_block = self.build_conv_block(inchannel, padding_type, activation, use_dropout) 37 | 38 | def build_conv_block(self, inchannel, padding_type, activation, use_dropout): 39 | conv_block = [] 40 | p = 0 41 | if padding_type == 'reflect': 42 | conv_block += [nn.ReflectionPad2d(1)] 43 | elif padding_type == 'replicate': 44 | conv_block += [nn.ReplicationPad2d(1)] 45 | elif padding_type == 'zero': 46 | p = 1 47 | else: 48 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 49 | 50 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), 51 | activation] 52 | if use_dropout: 53 | conv_block += [nn.Dropout(0.5)] 54 | 55 | p = 0 56 | if padding_type == 'reflect': 57 | conv_block += [nn.ReflectionPad2d(1)] 58 | elif padding_type == 'replicate': 59 | conv_block += [nn.ReplicationPad2d(1)] 60 | elif padding_type == 'zero': 61 | p = 1 62 | else: 63 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 64 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), activation] 65 | 66 | return nn.Sequential(*conv_block) 67 | 68 | def forward(self, x): 69 | out = x + self.conv_block(x) 70 | return out 71 | 72 | 73 | class ResnetBlock_bn(nn.Module): 74 | def __init__(self, inchannel, padding_type='reflect', norm_layer=nn.BatchNorm2d, activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False): 75 | super(ResnetBlock_bn, self).__init__() 76 | self.conv_block = self.build_conv_block(inchannel, padding_type, norm_layer, activation, use_dropout) 77 | 78 | def build_conv_block(self, inchannel, padding_type, norm_layer, activation, use_dropout): 79 | conv_block = [] 80 | p = 0 81 | if 
padding_type == 'reflect': 82 | conv_block += [nn.ReflectionPad2d(1)] 83 | elif padding_type == 'replicate': 84 | conv_block += [nn.ReplicationPad2d(1)] 85 | elif padding_type == 'zero': 86 | p = 1 87 | else: 88 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 89 | 90 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), 91 | norm_layer(inchannel), 92 | activation] 93 | if use_dropout: 94 | conv_block += [nn.Dropout(0.5)] 95 | 96 | p = 0 97 | if padding_type == 'reflect': 98 | conv_block += [nn.ReflectionPad2d(1)] 99 | elif padding_type == 'replicate': 100 | conv_block += [nn.ReplicationPad2d(1)] 101 | elif padding_type == 'zero': 102 | p = 1 103 | else: 104 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 105 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), 106 | norm_layer(inchannel)] 107 | 108 | return nn.Sequential(*conv_block) 109 | 110 | def forward(self, x): 111 | out = x + self.conv_block(x) 112 | return out 113 | 114 | 115 | class SemiResnetBlock(nn.Module): 116 | def __init__(self, inchannel, outchannel, padding_type='reflect', activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False, end=False): 117 | super(SemiResnetBlock, self).__init__() 118 | self.conv_block1 = self.build_conv_block1(inchannel, padding_type, activation, use_dropout) 119 | self.conv_block2 = self.build_conv_block2(inchannel, outchannel, padding_type, activation, 120 | use_dropout, end) 121 | 122 | def build_conv_block1(self, inchannel, padding_type, activation, use_dropout): 123 | conv_block = [] 124 | p = 0 125 | if padding_type == 'reflect': 126 | conv_block += [nn.ReflectionPad2d(1)] 127 | elif padding_type == 'replicate': 128 | conv_block += [nn.ReplicationPad2d(1)] 129 | elif padding_type == 'zero': 130 | p = 1 131 | else: 132 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 133 | 134 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), 135 | activation] 136 | if use_dropout: 137 | conv_block += [nn.Dropout(0.5)] 138 | 139 | return nn.Sequential(*conv_block) 140 | 141 | def build_conv_block2(self, inchannel, outchannel, padding_type, activation, use_dropout, end): 142 | conv_block = [] 143 | p = 0 144 | if padding_type == 'reflect': 145 | conv_block += [nn.ReflectionPad2d(1)] 146 | elif padding_type == 'replicate': 147 | conv_block += [nn.ReplicationPad2d(1)] 148 | elif padding_type == 'zero': 149 | p = 1 150 | else: 151 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 152 | if end: 153 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p)] 154 | else: 155 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p), 156 | activation] 157 | return nn.Sequential(*conv_block) 158 | 159 | def forward(self, x): 160 | x = self.conv_block1(x) + x 161 | out = self.conv_block2(x) 162 | return out 163 | 164 | class SemiResnetBlock_bn(nn.Module): 165 | def __init__(self, inchannel, outchannel, padding_type='reflect', norm_layer=nn.BatchNorm2d, activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), use_dropout=False, end=False): 166 | super(SemiResnetBlock_bn, self).__init__() 167 | self.conv_block1 = self.build_conv_block1(inchannel, padding_type, norm_layer, activation, use_dropout) 168 | self.conv_block2 = self.build_conv_block2(inchannel, outchannel, padding_type, norm_layer, activation, 169 | use_dropout, end) 170 | 171 | def build_conv_block1(self, inchannel, 
padding_type, norm_layer, activation, use_dropout): 172 | conv_block = [] 173 | p = 0 174 | if padding_type == 'reflect': 175 | conv_block += [nn.ReflectionPad2d(1)] 176 | elif padding_type == 'replicate': 177 | conv_block += [nn.ReplicationPad2d(1)] 178 | elif padding_type == 'zero': 179 | p = 1 180 | else: 181 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 182 | 183 | conv_block += [nn.Conv2d(inchannel, inchannel, kernel_size=3, padding=p), 184 | norm_layer(inchannel), 185 | activation] 186 | if use_dropout: 187 | conv_block += [nn.Dropout(0.5)] 188 | 189 | return nn.Sequential(*conv_block) 190 | 191 | def build_conv_block2(self, inchannel, outchannel, padding_type, norm_layer, activation, use_dropout, end): 192 | conv_block = [] 193 | p = 0 194 | if padding_type == 'reflect': 195 | conv_block += [nn.ReflectionPad2d(1)] 196 | elif padding_type == 'replicate': 197 | conv_block += [nn.ReplicationPad2d(1)] 198 | elif padding_type == 'zero': 199 | p = 1 200 | else: 201 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 202 | if end: 203 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p)] 204 | else: 205 | conv_block += [nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=p), 206 | norm_layer(outchannel), activation] 207 | return nn.Sequential(*conv_block) 208 | 209 | def forward(self, x): 210 | x = self.conv_block1(x) + x 211 | out = self.conv_block2(x) 212 | return out 213 | -------------------------------------------------------------------------------- /models/Unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class down(nn.Module): 7 | """ 8 | A class for creating neural network blocks containing layers: 9 | 10 | Average Pooling --> Convlution + Leaky ReLU --> Convolution + Leaky ReLU 11 | 12 | This is used in the UNet Class to create a UNet like NN architecture. 13 | 14 | ... 15 | 16 | Methods 17 | ------- 18 | forward(x) 19 | Returns output tensor after passing input `x` to the neural network 20 | block. 21 | """ 22 | 23 | 24 | def __init__(self, inChannels, outChannels, filterSize): 25 | """ 26 | Parameters 27 | ---------- 28 | inChannels : int 29 | number of input channels for the first convolutional layer. 30 | outChannels : int 31 | number of output channels for the first convolutional layer. 32 | This is also used as input and output channels for the 33 | second convolutional layer. 34 | filterSize : int 35 | filter size for the convolution filter. input N would create 36 | a N x N filter. 37 | """ 38 | 39 | 40 | super(down, self).__init__() 41 | # Initialize convolutional layers. 42 | self.conv1 = nn.Conv2d(inChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2)) 43 | self.conv2 = nn.Conv2d(outChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2)) 44 | 45 | def forward(self, x): 46 | """ 47 | Returns output tensor after passing input `x` to the neural network 48 | block. 49 | 50 | Parameters 51 | ---------- 52 | x : tensor 53 | input to the NN block. 54 | 55 | Returns 56 | ------- 57 | tensor 58 | output of the NN block. 59 | """ 60 | 61 | 62 | # Average pooling with kernel size 2 (2 x 2). 
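# (each `down` block thus halves H and W, while conv1 maps inChannels to outChannels and conv2 keeps the channel count)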
63 | x = F.avg_pool2d(x, 2) 64 | # Convolution + Leaky ReLU 65 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) 66 | # Convolution + Leaky ReLU 67 | x = F.leaky_relu(self.conv2(x), negative_slope = 0.1) 68 | return x 69 | 70 | class up(nn.Module): 71 | """ 72 | A class for creating neural network blocks containing layers: 73 | 74 | Bilinear interpolation --> Convlution + Leaky ReLU --> Convolution + Leaky ReLU 75 | 76 | This is used in the UNet Class to create a UNet like NN architecture. 77 | 78 | ... 79 | 80 | Methods 81 | ------- 82 | forward(x, skpCn) 83 | Returns output tensor after passing input `x` to the neural network 84 | block. 85 | """ 86 | 87 | 88 | def __init__(self, inChannels, outChannels): 89 | """ 90 | Parameters 91 | ---------- 92 | inChannels : int 93 | number of input channels for the first convolutional layer. 94 | outChannels : int 95 | number of output channels for the first convolutional layer. 96 | This is also used for setting input and output channels for 97 | the second convolutional layer. 98 | """ 99 | 100 | 101 | super(up, self).__init__() 102 | # Initialize convolutional layers. 103 | self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1) 104 | # (2 * outChannels) is used for accommodating skip connection. 105 | self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1) 106 | 107 | def forward(self, x, skpCn): 108 | """ 109 | Returns output tensor after passing input `x` to the neural network 110 | block. 111 | 112 | Parameters 113 | ---------- 114 | x : tensor 115 | input to the NN block. 116 | skpCn : tensor 117 | skip connection input to the NN block. 118 | 119 | Returns 120 | ------- 121 | tensor 122 | output of the NN block. 123 | """ 124 | 125 | # Bilinear interpolation with scaling 2. 126 | x = F.interpolate(x, scale_factor=2, mode='bilinear') 127 | # Convolution + Leaky ReLU 128 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) 129 | # Convolution + Leaky ReLU on (`x`, `skpCn`) 130 | x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1) 131 | return x 132 | 133 | 134 | 135 | class UNet(nn.Module): 136 | """ 137 | A class for creating UNet like architecture as specified by the 138 | Super SloMo paper. 139 | 140 | ... 141 | 142 | Methods 143 | ------- 144 | forward(x) 145 | Returns output tensor after passing input `x` to the neural network 146 | block. 147 | """ 148 | 149 | 150 | def __init__(self, inChannels, outChannels): 151 | """ 152 | Parameters 153 | ---------- 154 | inChannels : int 155 | number of input channels for the UNet. 156 | outChannels : int 157 | number of output channels for the UNet. 158 | """ 159 | 160 | 161 | super(UNet, self).__init__() 162 | # Initialize neural network blocks. 163 | self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3) 164 | self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3) 165 | self.down1 = down(32, 64, 5) 166 | self.down2 = down(64, 128, 3) 167 | self.down3 = down(128, 256, 3) 168 | self.down4 = down(256, 512, 3) 169 | self.down5 = down(512, 512, 3) 170 | self.up1 = up(512, 512) 171 | self.up2 = up(512, 256) 172 | self.up3 = up(256, 128) 173 | self.up4 = up(128, 64) 174 | self.up5 = up(64, 32) 175 | self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1) 176 | 177 | def forward(self, x): 178 | """ 179 | Returns output tensor after passing input `x` to the neural network. 180 | 181 | Parameters 182 | ---------- 183 | x : tensor 184 | input to the UNet. 
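Height and width should be multiples of 32 so that the five pooling/upsampling stages line up with the skip connections.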
185 | 186 | Returns 187 | ------- 188 | tensor 189 | output of the UNet. 190 | """ 191 | 192 | 193 | x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) 194 | s1 = F.leaky_relu(self.conv2(x), negative_slope = 0.1) 195 | s2 = self.down1(s1) 196 | s3 = self.down2(s2) 197 | s4 = self.down3(s3) 198 | s5 = self.down4(s4) 199 | x = self.down5(s5) 200 | x = self.up1(x, s5) 201 | x = self.up2(x, s4) 202 | x = self.up3(x, s3) 203 | x = self.up4(x, s2) 204 | x = self.up5(x, s1) 205 | x = F.leaky_relu(self.conv3(x), negative_slope = 0.1) 206 | return x -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/models/__init__.py -------------------------------------------------------------------------------- /models/bdcn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/models/bdcn/__init__.py -------------------------------------------------------------------------------- /models/bdcn/bdcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import models.bdcn.vgg16_c as vgg16_c 6 | 7 | def crop(data1, data2, crop_h, crop_w): 8 | _, _, h1, w1 = data1.size() 9 | _, _, h2, w2 = data2.size() 10 | assert(h2 <= h1 and w2 <= w1) 11 | data = data1[:, :, crop_h:crop_h+h2, crop_w:crop_w+w2] 12 | return data 13 | 14 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 15 | """Make a 2D bilinear kernel suitable for upsampling""" 16 | factor = (kernel_size + 1) // 2 17 | if kernel_size % 2 == 1: 18 | center = factor - 1 19 | else: 20 | center = factor - 0.5 21 | og = np.ogrid[:kernel_size, :kernel_size] 22 | filt = (1 - abs(og[0] - center) / factor) * \ 23 | (1 - abs(og[1] - center) / factor) 24 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 25 | dtype=np.float64) 26 | weight[range(in_channels), range(out_channels), :, :] = filt 27 | return torch.from_numpy(weight).float() 28 | 29 | class MSBlock(nn.Module): 30 | def __init__(self, c_in, rate=4): 31 | super(MSBlock, self).__init__() 32 | c_out = c_in 33 | self.rate = rate 34 | 35 | self.conv = nn.Conv2d(c_in, 32, 3, stride=1, padding=1) 36 | self.relu = nn.ReLU(inplace=True) 37 | dilation = self.rate*1 if self.rate >= 1 else 1 38 | self.conv1 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) 39 | self.relu1 = nn.ReLU(inplace=True) 40 | dilation = self.rate*2 if self.rate >= 1 else 1 41 | self.conv2 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) 42 | self.relu2 = nn.ReLU(inplace=True) 43 | dilation = self.rate*3 if self.rate >= 1 else 1 44 | self.conv3 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) 45 | self.relu3 = nn.ReLU(inplace=True) 46 | 47 | self._initialize_weights() 48 | 49 | def forward(self, x): 50 | o = self.relu(self.conv(x)) 51 | o1 = self.relu1(self.conv1(o)) 52 | o2 = self.relu2(self.conv2(o)) 53 | o3 = self.relu3(self.conv3(o)) 54 | out = o + o1 + o2 + o3 55 | return out 56 | 57 | def _initialize_weights(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | m.weight.data.normal_(0, 0.01) 61 | if m.bias is not None: 62 | m.bias.data.zero_() 63 | 64 | 65 | class 
BDCN(nn.Module): 66 | def __init__(self, pretrain=None, logger=None, rate=4): 67 | super(BDCN, self).__init__() 68 | self.pretrain = pretrain 69 | t = 1 70 | 71 | self.features = vgg16_c.VGG16_C(pretrain, logger) 72 | self.msblock1_1 = MSBlock(64, rate) 73 | self.msblock1_2 = MSBlock(64, rate) 74 | self.conv1_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 75 | self.conv1_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 76 | self.score_dsn1 = nn.Conv2d(21, 1, (1, 1), stride=1) 77 | self.score_dsn1_1 = nn.Conv2d(21, 1, 1, stride=1) 78 | self.msblock2_1 = MSBlock(128, rate) 79 | self.msblock2_2 = MSBlock(128, rate) 80 | self.conv2_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 81 | self.conv2_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 82 | self.score_dsn2 = nn.Conv2d(21, 1, (1, 1), stride=1) 83 | self.score_dsn2_1 = nn.Conv2d(21, 1, (1, 1), stride=1) 84 | self.msblock3_1 = MSBlock(256, rate) 85 | self.msblock3_2 = MSBlock(256, rate) 86 | self.msblock3_3 = MSBlock(256, rate) 87 | self.conv3_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 88 | self.conv3_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 89 | self.conv3_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 90 | self.score_dsn3 = nn.Conv2d(21, 1, (1, 1), stride=1) 91 | self.score_dsn3_1 = nn.Conv2d(21, 1, (1, 1), stride=1) 92 | self.msblock4_1 = MSBlock(512, rate) 93 | self.msblock4_2 = MSBlock(512, rate) 94 | self.msblock4_3 = MSBlock(512, rate) 95 | self.conv4_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 96 | self.conv4_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 97 | self.conv4_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 98 | self.score_dsn4 = nn.Conv2d(21, 1, (1, 1), stride=1) 99 | self.score_dsn4_1 = nn.Conv2d(21, 1, (1, 1), stride=1) 100 | self.msblock5_1 = MSBlock(512, rate) 101 | self.msblock5_2 = MSBlock(512, rate) 102 | self.msblock5_3 = MSBlock(512, rate) 103 | self.conv5_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 104 | self.conv5_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 105 | self.conv5_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) 106 | self.score_dsn5 = nn.Conv2d(21, 1, (1, 1), stride=1) 107 | self.score_dsn5_1 = nn.Conv2d(21, 1, (1, 1), stride=1) 108 | self.upsample_2 = nn.ConvTranspose2d(1, 1, 4, stride=2, bias=False) 109 | self.upsample_4 = nn.ConvTranspose2d(1, 1, 8, stride=4, bias=False) 110 | self.upsample_8 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False) 111 | self.upsample_8_5 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False) 112 | self.fuse = nn.Conv2d(10, 1, 1, stride=1) 113 | 114 | self._initialize_weights(logger) 115 | 116 | def forward(self, x): 117 | features = self.features(x) 118 | sum1 = self.conv1_1_down(self.msblock1_1(features[0])) + \ 119 | self.conv1_2_down(self.msblock1_2(features[1])) 120 | s1 = self.score_dsn1(sum1) 121 | s11 = self.score_dsn1_1(sum1) 122 | # print(s1.data.shape, s11.data.shape) 123 | sum2 = self.conv2_1_down(self.msblock2_1(features[2])) + \ 124 | self.conv2_2_down(self.msblock2_2(features[3])) 125 | s2 = self.score_dsn2(sum2) 126 | s21 = self.score_dsn2_1(sum2) 127 | s2 = self.upsample_2(s2) 128 | s21 = self.upsample_2(s21) 129 | # print(s2.data.shape, s21.data.shape) 130 | s2 = crop(s2, x, 1, 1) 131 | s21 = crop(s21, x, 1, 1) 132 | sum3 = self.conv3_1_down(self.msblock3_1(features[4])) + \ 133 | self.conv3_2_down(self.msblock3_2(features[5])) + \ 134 | self.conv3_3_down(self.msblock3_3(features[6])) 135 | s3 = self.score_dsn3(sum3) 136 | s3 =self.upsample_4(s3) 137 | # print(s3.data.shape) 138 | s3 = crop(s3, x, 2, 2) 139 | s31 = self.score_dsn3_1(sum3) 
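# (upsample_2/4/8 are stride-2/4/8 ConvTranspose2d layers initialised with fixed bilinear kernels in _initialize_weights; crop() then trims the transposed-convolution border so each side output matches the input resolution)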
140 | s31 =self.upsample_4(s31) 141 | # print(s31.data.shape) 142 | s31 = crop(s31, x, 2, 2) 143 | sum4 = self.conv4_1_down(self.msblock4_1(features[7])) + \ 144 | self.conv4_2_down(self.msblock4_2(features[8])) + \ 145 | self.conv4_3_down(self.msblock4_3(features[9])) 146 | s4 = self.score_dsn4(sum4) 147 | s4 = self.upsample_8(s4) 148 | # print(s4.data.shape) 149 | s4 = crop(s4, x, 4, 4) 150 | s41 = self.score_dsn4_1(sum4) 151 | s41 = self.upsample_8(s41) 152 | # print(s41.data.shape) 153 | s41 = crop(s41, x, 4, 4) 154 | sum5 = self.conv5_1_down(self.msblock5_1(features[10])) + \ 155 | self.conv5_2_down(self.msblock5_2(features[11])) + \ 156 | self.conv5_3_down(self.msblock5_3(features[12])) 157 | s5 = self.score_dsn5(sum5) 158 | s5 = self.upsample_8_5(s5) 159 | # print(s5.data.shape) 160 | s5 = crop(s5, x, 0, 0) 161 | s51 = self.score_dsn5_1(sum5) 162 | s51 = self.upsample_8_5(s51) 163 | # print(s51.data.shape) 164 | s51 = crop(s51, x, 0, 0) 165 | o1, o2, o3, o4, o5 = s1.detach(), s2.detach(), s3.detach(), s4.detach(), s5.detach() 166 | o11, o21, o31, o41, o51 = s11.detach(), s21.detach(), s31.detach(), s41.detach(), s51.detach() 167 | p1_1 = s1 168 | p2_1 = s2 + o1 169 | p3_1 = s3 + o2 + o1 170 | p4_1 = s4 + o3 + o2 + o1 171 | p5_1 = s5 + o4 + o3 + o2 + o1 172 | p1_2 = s11 + o21 + o31 + o41 + o51 173 | p2_2 = s21 + o31 + o41 + o51 174 | p3_2 = s31 + o41 + o51 175 | p4_2 = s41 + o51 176 | p5_2 = s51 177 | 178 | fuse = self.fuse(torch.cat([p1_1, p2_1, p3_1, p4_1, p5_1, p1_2, p2_2, p3_2, p4_2, p5_2], 1)) 179 | 180 | return [p1_1, p2_1, p3_1, p4_1, p5_1, p1_2, p2_2, p3_2, p4_2, p5_2, fuse] 181 | 182 | def _initialize_weights(self, logger=None): 183 | for name, param in self.state_dict().items(): 184 | if self.pretrain and 'features' in name: 185 | continue 186 | # elif 'down' in name: 187 | # param.zero_() 188 | elif 'upsample' in name: 189 | if logger: 190 | logger.info('init upsamle layer %s ' % name) 191 | k = int(name.split('.')[0].split('_')[1]) 192 | param.copy_(get_upsampling_weight(1, 1, k*2)) 193 | elif 'fuse' in name: 194 | if logger: 195 | logger.info('init params %s ' % name) 196 | if 'bias' in name: 197 | param.zero_() 198 | else: 199 | nn.init.constant(param, 0.080) 200 | else: 201 | if logger: 202 | logger.info('init params %s ' % name) 203 | if 'bias' in name: 204 | param.zero_() 205 | else: 206 | param.normal_(0, 0.01) 207 | # print self.conv1_1_down.weight 208 | 209 | if __name__ == '__main__': 210 | model = BDCN('./caffemodel2pytorch/vgg16.pth') 211 | a=torch.rand((2,3,100,100)) 212 | a=torch.autograd.Variable(a) 213 | for x in model(a): 214 | print(x.data.shape) 215 | # for name, param in model.state_dict().items(): 216 | # print name, param 217 | -------------------------------------------------------------------------------- /models/bdcn/vgg16_c.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torch.nn as nn 5 | import math 6 | 7 | class VGG16_C(nn.Module): 8 | """""" 9 | def __init__(self, pretrain=None, logger=None): 10 | super(VGG16_C, self).__init__() 11 | self.conv1_1 = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1) 12 | self.relu1_1 = nn.ReLU(inplace=True) 13 | self.conv1_2 = nn.Conv2d(64, 64, (3, 3), stride=1, padding=1) 14 | self.relu1_2 = nn.ReLU(inplace=True) 15 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 16 | self.conv2_1 = nn.Conv2d(64, 128, (3, 3), stride=1, padding=1) 17 | self.relu2_1 = nn.ReLU(inplace=True) 18 | self.conv2_2 = 
nn.Conv2d(128, 128, (3, 3), stride=1, padding=1) 19 | self.relu2_2 = nn.ReLU(inplace=True) 20 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 21 | self.conv3_1 = nn.Conv2d(128, 256, (3, 3), stride=1, padding=1) 22 | self.relu3_1 = nn.ReLU(inplace=True) 23 | self.conv3_2 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1) 24 | self.relu3_2 = nn.ReLU(inplace=True) 25 | self.conv3_3 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1) 26 | self.relu3_3 = nn.ReLU(inplace=True) 27 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 28 | self.conv4_1 = nn.Conv2d(256, 512, (3, 3), stride=1, padding=1) 29 | self.relu4_1 = nn.ReLU(inplace=True) 30 | self.conv4_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1) 31 | self.relu4_2 = nn.ReLU(inplace=True) 32 | self.conv4_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1) 33 | self.relu4_3 = nn.ReLU(inplace=True) 34 | self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True) 35 | self.conv5_1 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) 36 | self.relu5_1 = nn.ReLU(inplace=True) 37 | self.conv5_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) 38 | self.relu5_2 = nn.ReLU(inplace=True) 39 | self.conv5_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) 40 | self.relu5_3 = nn.ReLU(inplace=True) 41 | if pretrain: 42 | if '.npy' in pretrain: 43 | state_dict = np.load(pretrain).item() 44 | for k in state_dict: 45 | state_dict[k] = torch.from_numpy(state_dict[k]) 46 | else: 47 | state_dict = torch.load(pretrain) 48 | own_state_dict = self.state_dict() 49 | for name, param in own_state_dict.items(): 50 | if name in state_dict: 51 | if logger: 52 | logger.info('copy the weights of %s from pretrained model' % name) 53 | param.copy_(state_dict[name]) 54 | else: 55 | if logger: 56 | logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\ 57 | % name) 58 | if 'bias' in name: 59 | param.zero_() 60 | else: 61 | param.normal_(0, 0.01) 62 | else: 63 | self._initialize_weights(logger) 64 | 65 | def forward(self, x): 66 | conv1_1 = self.relu1_1(self.conv1_1(x)) 67 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1)) 68 | pool1 = self.pool1(conv1_2) 69 | conv2_1 = self.relu2_1(self.conv2_1(pool1)) 70 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1)) 71 | pool2 = self.pool2(conv2_2) 72 | conv3_1 = self.relu3_1(self.conv3_1(pool2)) 73 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) 74 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2)) 75 | pool3 = self.pool3(conv3_3) 76 | conv4_1 = self.relu4_1(self.conv4_1(pool3)) 77 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) 78 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) 79 | pool4 = self.pool4(conv4_3) 80 | # pool4 = conv4_3 81 | conv5_1 = self.relu5_1(self.conv5_1(pool4)) 82 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) 83 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) 84 | 85 | side = [conv1_1, conv1_2, conv2_1, conv2_2, 86 | conv3_1, conv3_2, conv3_3, conv4_1, 87 | conv4_2, conv4_3, conv5_1, conv5_2, conv5_3] 88 | return side 89 | 90 | def _initialize_weights(self, logger=None): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | if logger: 94 | logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\ 95 | % m) 96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 98 | if m.bias is not None: 99 | m.bias.data.zero_() 100 | elif isinstance(m, nn.BatchNorm2d): 101 | m.weight.data.fill_(1) 102 | m.bias.data.zero_() 103 | elif isinstance(m, nn.Linear): 104 | m.weight.data.normal_(0, 0.01) 105 | m.bias.data.zero_() 106 | 107 | if __name__ == '__main__': 108 | model = VGG16_C() 109 | # im = np.zeros((1,3,100,100)) 110 | # out = model(Variable(torch.from_numpy(im))) 111 | 112 | 113 | -------------------------------------------------------------------------------- /models/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class backWarp(nn.Module): 8 | """ 9 | A class for creating a backwarping object. 10 | 11 | This is used for backwarping to an image: 12 | 13 | Given optical flow from frame I0 to I1 --> F_0_1 and frame I1, 14 | it generates I0 <-- backwarp(F_0_1, I1). 15 | 16 | ... 17 | 18 | Methods 19 | ------- 20 | forward(x) 21 | Returns output tensor after passing input `img` and `flow` to the backwarping 22 | block. 23 | """ 24 | 25 | 26 | def __init__(self, H, W): 27 | """ 28 | Parameters 29 | ---------- 30 | W : int 31 | width of the image. 32 | H : int 33 | height of the image. 34 | device : device 35 | computation device (cpu/cuda). 36 | """ 37 | 38 | 39 | super(backWarp, self).__init__() 40 | # create a grid 41 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) 42 | self.W = W 43 | self.H = H 44 | self.gridX = torch.nn.Parameter(torch.tensor(gridX), requires_grad=False) 45 | self.gridY = torch.nn.Parameter(torch.tensor(gridY), requires_grad=False) 46 | 47 | def forward(self, img, flow): 48 | """ 49 | Returns output tensor after passing input `img` and `flow` to the backwarping 50 | block. 51 | I0 = backwarp(I1, F_0_1) 52 | 53 | Parameters 54 | ---------- 55 | img : tensor 56 | frame I1. 57 | flow : tensor 58 | optical flow from I0 and I1: F_0_1. 59 | 60 | Returns 61 | ------- 62 | tensor 63 | frame I0. 64 | """ 65 | 66 | 67 | # Extract horizontal and vertical flows. 68 | u = flow[:, 0, :, :] 69 | v = flow[:, 1, :, :] 70 | x = self.gridX.unsqueeze(0).expand_as(u).float() + u 71 | y = self.gridY.unsqueeze(0).expand_as(v).float() + v 72 | # range -1 to 1 73 | x = 2*(x/self.W - 0.5) 74 | y = 2*(y/self.H - 0.5) 75 | # stacking X and Y 76 | grid = torch.stack((x,y), dim=3) 77 | # Sample pixels using bilinear interpolation. 78 | imgOut = torch.nn.functional.grid_sample(img, grid, padding_mode='border') 79 | return imgOut 80 | 81 | 82 | # Creating an array of `t` values for the 7 intermediate frames between 83 | # reference frames I0 and I1. 84 | class Coeff(nn.Module): 85 | 86 | def __init__(self): 87 | super(Coeff, self).__init__() 88 | self.t = torch.nn.Parameter(torch.FloatTensor(np.linspace(0.125, 0.875, 7)), requires_grad=False) 89 | 90 | def getFlowCoeff (self, indices): 91 | """ 92 | Gets flow coefficients used for calculating intermediate optical 93 | flows from optical flows between I0 and I1: F_0_1 and F_1_0. 94 | 95 | F_t_0 = C00 x F_0_1 + C01 x F_1_0 96 | F_t_1 = C10 x F_0_1 + C11 x F_1_0 97 | 98 | where, 99 | C00 = -(1 - t) x t 100 | C01 = t x t 101 | C10 = (1 - t) x (1 - t) 102 | C11 = -t x (1 - t) 103 | 104 | Parameters 105 | ---------- 106 | indices : tensor 107 | indices corresponding to the intermediate frame positions 108 | of all samples in the batch. 109 | device : device 110 | computation device (cpu/cuda). 
111 | 112 | Returns 113 | ------- 114 | tensor 115 | coefficients C00, C01, C10, C11. 116 | """ 117 | 118 | 119 | # Convert indices tensor to numpy array 120 | ind = indices.detach() 121 | C11 = C00 = - (1 - (self.t[ind])) * (self.t[ind]) 122 | C01 = (self.t[ind]) * (self.t[ind]) 123 | C10 = (1 - (self.t[ind])) * (1 - (self.t[ind])) 124 | return C00[None, None, None, :].permute(3, 0, 1, 2), C01[None, None, None, :].permute(3, 0, 1, 2), C10[None, None, None, :].permute(3, 0, 1, 2), C11[None, None, None, :].permute(3, 0, 1, 2) 125 | 126 | def getWarpCoeff (self, indices): 127 | """ 128 | Gets coefficients used for calculating final intermediate 129 | frame `It_gen` from backwarped images using flows F_t_0 and F_t_1. 130 | 131 | It_gen = (C0 x V_t_0 x g_I_0_F_t_0 + C1 x V_t_1 x g_I_1_F_t_1) / (C0 x V_t_0 + C1 x V_t_1) 132 | 133 | where, 134 | C0 = 1 - t 135 | C1 = t 136 | 137 | V_t_0, V_t_1 --> visibility maps 138 | g_I_0_F_t_0, g_I_1_F_t_1 --> backwarped intermediate frames 139 | 140 | Parameters 141 | ---------- 142 | indices : tensor 143 | indices corresponding to the intermediate frame positions 144 | of all samples in the batch. 145 | device : device 146 | computation device (cpu/cuda). 147 | 148 | Returns 149 | ------- 150 | tensor 151 | coefficients C0 and C1. 152 | """ 153 | 154 | 155 | # Convert indices tensor to numpy array 156 | ind = indices.detach() 157 | C0 = 1 - self.t[ind] 158 | C1 = self.t[ind] 159 | return C0[None, None, None, :].permute(3, 0, 1, 2), C1[None, None, None, :].permute(3, 0, 1, 2) 160 | 161 | def set_t(self, factor): 162 | ti = 1 / factor 163 | self.t = torch.nn.Parameter(torch.FloatTensor(np.linspace(ti, 1 - ti, factor - 1)), requires_grad=False) -------------------------------------------------------------------------------- /paper/FeatureFlow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/paper/FeatureFlow.pdf -------------------------------------------------------------------------------- /paper/Supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/paper/Supp.pdf -------------------------------------------------------------------------------- /pure_run.py: -------------------------------------------------------------------------------- 1 | # SeDraw 2 | import argparse 3 | import os 4 | import torch 5 | import cv2 6 | import torchvision.transforms as transforms 7 | from skimage.measure import compare_psnr 8 | from PIL import Image 9 | import src.pure_network as layers 10 | from tqdm import tqdm 11 | import numpy as np 12 | import math 13 | import models.bdcn.bdcn as bdcn 14 | 15 | # For parsing commandline arguments 16 | def str2bool(v): 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 20 | return True 21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError('Boolean value expected.') 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model') 28 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? 
in GEN, Default:3') 29 | parser.add_argument('--bdcn_model', type=str, default='/home/visiting/Projects/citrine/SeDraw/models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth') 30 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.') 31 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint') 32 | parser.add_argument('--imgpath', type=str, required=True) 33 | parser.add_argument('--first', type=str, required=True) 34 | parser.add_argument('--second', type=str, required=True) 35 | parser.add_argument('--gt', type=str, required=True) 36 | args = parser.parse_args() 37 | 38 | 39 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): 40 | 41 | 42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 43 | with open(path, 'rb') as f: 44 | img = Image.open(f) 45 | # Resize image if specified. 46 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img 47 | # Crop image if crop area specified. 48 | cropped_img = img.crop(cropArea) if (cropArea != None) else resized_img 49 | # Flip image horizontally if specified. 50 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img 51 | return flipped_img.convert('RGB') 52 | 53 | 54 | 55 | bdcn = bdcn.BDCN() 56 | bdcn.cuda() 57 | structure_gen = layers.StructureGen(feature_level=args.feature_level) 58 | structure_gen.cuda() 59 | detail_enhance = layers.DetailEnhance() 60 | detail_enhance.cuda() 61 | 62 | 63 | # Channel wise mean calculated on adobe240-fps training dataset 64 | mean = [0.5, 0.5, 0.5] 65 | std = [0.5, 0.5, 0.5] 66 | normalize = transforms.Normalize(mean=mean, 67 | std=std) 68 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 69 | 70 | negmean = [-1 for x in mean] 71 | restd = [2, 2, 2] 72 | revNormalize = transforms.Normalize(mean=negmean, std=restd) 73 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()]) 74 | 75 | 76 | def ToImage(frame0, frame1): 77 | 78 | with torch.no_grad(): 79 | 80 | img0 = frame0.cuda() 81 | img1 = frame1.cuda() 82 | 83 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1) 84 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1) 85 | ref_imgt, _ = structure_gen((img0_e, img1_e)) 86 | imgt = detail_enhance((img0, img1, ref_imgt)) 87 | # imgt = detail_enhance((img0, img1, imgt)) 88 | imgt = torch.clamp(imgt, max=1., min=-1.) 
89 | 90 | return imgt 91 | 92 | 93 | def main(): 94 | # initial 95 | 96 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model))) 97 | dict1 = torch.load(args.checkpoint) 98 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False) 99 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False) 100 | 101 | bdcn.eval() 102 | structure_gen.eval() 103 | detail_enhance.eval() 104 | 105 | IE = 0 106 | PSNR = 0 107 | count = 0 108 | for folder in tqdm(os.listdir(args.imgpath)): 109 | triple_path = os.path.join(args.imgpath, folder) 110 | if not (os.path.isdir(triple_path)): 111 | continue 112 | X0 = transform(_pil_loader('%s/%s' % (triple_path, args.first))).unsqueeze(0) 113 | X1 = transform(_pil_loader('%s/%s' % (triple_path, args.second))).unsqueeze(0) 114 | 115 | assert (X0.size(2) == X1.size(2)) 116 | assert (X0.size(3) == X1.size(3)) 117 | 118 | intWidth = X0.size(3) 119 | intHeight = X0.size(2) 120 | channel = X0.size(1) 121 | if not channel == 3: 122 | print('Not RGB image') 123 | continue 124 | count += 1 125 | 126 | # if intWidth != ((intWidth >> 4) << 4): 127 | # intWidth_pad = (((intWidth >> 4) + 1) << 4) # more than necessary 128 | # intPaddingLeft = int((intWidth_pad - intWidth) / 2) 129 | # intPaddingRight = intWidth_pad - intWidth - intPaddingLeft 130 | # else: 131 | # intWidth_pad = intWidth 132 | # intPaddingLeft = 0 133 | # intPaddingRight = 0 134 | # 135 | # if intHeight != ((intHeight >> 4) << 4): 136 | # intHeight_pad = (((intHeight >> 4) + 1) << 4) # more than necessary 137 | # intPaddingTop = int((intHeight_pad - intHeight) / 2) 138 | # intPaddingBottom = intHeight_pad - intHeight - intPaddingTop 139 | # else: 140 | # intHeight_pad = intHeight 141 | # intPaddingTop = 0 142 | # intPaddingBottom = 0 143 | # 144 | # pader = torch.nn.ReflectionPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom]) 145 | 146 | # first, second = pader(X0), pader(X1) 147 | first, second = X0, X1 148 | imgt = ToImage(first, second) 149 | 150 | imgt_np = imgt.squeeze(0).cpu().numpy()#[:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 151 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 152 | if not os.path.isdir(triple_path): 153 | os.system('mkdir -p %s' % triple_path) 154 | cv2.imwrite(triple_path + '/SeDraw.png', imgt_png) 155 | 156 | rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png'))) 157 | gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt))) 158 | 159 | diff_rgb = rec_rgb - gt_rgb 160 | avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2)) 161 | 162 | mse = np.mean((diff_rgb) ** 2) 163 | 164 | PIXEL_MAX = 255.0 165 | psnr = compare_psnr(gt_rgb, rec_rgb, 255) 166 | print(folder, psnr) 167 | 168 | IE += avg_interp_error_abs 169 | PSNR += psnr 170 | 171 | # print(triple_path, ': IE/PSNR:', avg_interp_error_abs, psnr) 172 | 173 | IE = IE / count 174 | PSNR = PSNR / count 175 | print('Average IE/PSNR:', IE, PSNR) 176 | 177 | main() 178 | 179 | 180 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.2.1 2 | apipkg==1.5 3 | asn1crypto==1.3.0 4 | attrs==19.3.0 5 | backcall==0.1.0 6 | bleach==3.1.4 7 | blis==0.4.1 8 | boto==2.49.0 9 | boto3==1.12.12 10 | botocore==1.15.12 11 | bz2file==0.98 12 | captum==0.2.0 13 | catalogue==1.0.0 14 | certifi==2020.4.5.1 15 | cffi==1.14.0 16 | chardet==3.0.4 17 | Click==7.0 18 | 
correlation-cuda==0.0.0 19 | cryptography==2.8 20 | cupy==7.2.0 21 | cycler==0.10.0 22 | cymem==2.0.3 23 | Cython==0.29.14 24 | decorator==4.4.2 25 | defusedxml==0.6.0 26 | depthflowprojection-cuda==0.0.0 27 | docutils==0.15.2 28 | en-core-web-sm==2.2.5 29 | entrypoints==0.3 30 | execnet==1.7.1 31 | fastrlock==0.4 32 | filelock==3.0.12 33 | filterinterpolation-cuda==0.0.0 34 | flowprojection-cuda==0.0.0 35 | gensim==3.8.0 36 | idna==2.9 37 | imageio==2.6.1 38 | imageio-ffmpeg==0.4.1 39 | importlib-metadata==1.5.0 40 | interpolation-cuda==0.0.0 41 | interpolationch-cuda==0.0.0 42 | ipykernel==5.1.4 43 | ipython==7.13.0 44 | ipython-genutils==0.2.0 45 | ipywidgets==7.5.1 46 | jedi==0.16.0 47 | jieba==0.42.1 48 | Jinja2==2.11.1 49 | jmespath==0.9.4 50 | joblib==0.14.1 51 | jsonpatch==1.25 52 | jsonpointer==2.0 53 | jsonschema==3.2.0 54 | jupyter==1.0.0 55 | jupyter-client==6.0.0 56 | jupyter-console==6.1.0 57 | jupyter-core==4.6.3 58 | kiwisolver==1.1.0 59 | lxml==4.5.0 60 | MarkupSafe==1.1.1 61 | matplotlib==3.1.3 62 | mindepthflowprojection-cuda==0.0.0 63 | mistune==0.8.4 64 | mkl-fft==1.0.15 65 | mkl-random==1.1.0 66 | mkl-service==2.3.0 67 | mmcv==0.2.16 68 | -e git+https://github.com/open-mmlab/mmdetection.git@93bed07bb65b24ae040d797a7696ecb994acf9f5#egg=mmdet 69 | more-itertools==8.2.0 70 | moviepy==1.0.1 71 | murmurhash==1.0.2 72 | nbconvert==5.6.1 73 | nbformat==5.0.4 74 | networkx==2.4 75 | nltk==3.4.5 76 | notebook==6.0.3 77 | numpy==1.18.1 78 | olefile==0.46 79 | opencv-python==4.2.0.32 80 | packaging==20.3 81 | pandocfilters==1.4.2 82 | parso==0.6.2 83 | pexpect==4.8.0 84 | pickleshare==0.7.5 85 | Pillow==6.2.2 86 | plac==1.1.3 87 | pluggy==0.13.1 88 | prefetch-generator==1.0.1 89 | preshed==3.0.2 90 | proglog==0.1.9 91 | prometheus-client==0.7.1 92 | prompt-toolkit==3.0.4 93 | ptyprocess==0.6.0 94 | py==1.8.1 95 | pycocotools==2.0 96 | pycparser==2.20 97 | Pygments==2.6.1 98 | pyOpenSSL==19.1.0 99 | pyparsing==2.4.6 100 | pyrsistent==0.15.7 101 | PySocks==1.7.1 102 | pytest==5.4.1 103 | pytest-forked==1.1.3 104 | pytest-xdist==1.31.0 105 | python-dateutil==2.8.1 106 | PyWavelets==1.1.1 107 | PyYAML==5.3 108 | pyzmq==18.1.1 109 | qtconsole==4.7.1 110 | QtPy==1.9.0 111 | regex==2020.1.8 112 | requests==2.23.0 113 | s3transfer==0.3.3 114 | sacremoses==0.0.38 115 | scikit-image==0.16.2 116 | scipy==1.4.1 117 | Send2Trash==1.5.0 118 | sentencepiece==0.1.85 119 | separableconv-cuda==0.0.0 120 | separableconvflow-cuda==0.0.0 121 | simplejson==3.17.0 122 | six==1.14.0 123 | smart-open==1.9.0 124 | spacy==2.2.4 125 | srsly==1.0.2 126 | terminado==0.8.3 127 | terminaltables==3.1.0 128 | testpath==0.4.4 129 | thinc==7.4.0 130 | tokenizers==0.5.2 131 | torch==1.2.0 132 | torchfile==0.1.0 133 | torchtext==0.5.0 134 | torchvision==0.4.0a0+6b959ee 135 | tornado==6.0.4 136 | tqdm==4.42.1 137 | traitlets==4.3.3 138 | -e git+https://github.com/huggingface/transformers@1789c7daf1b8013006b0aef6cb1b8f80573031c5#egg=transformers 139 | urllib3==1.25.8 140 | visdom==0.1.8.9 141 | wasabi==0.6.0 142 | wcwidth==0.1.8 143 | webencodings==0.5.1 144 | websocket-client==0.57.0 145 | widgetsnbextension==3.5.1 146 | zipp==3.1.0 147 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # SeDraw 2 | import argparse 3 | import os 4 | import torch 5 | import cv2 6 | import torchvision.transforms as transforms 7 | from skimage.measure import compare_psnr 8 | from PIL import Image 9 | import 
src.pure_network as layers 10 | from tqdm import tqdm 11 | import numpy as np 12 | import math 13 | import models.bdcn.bdcn as bdcn 14 | 15 | # For parsing commandline arguments 16 | def str2bool(v): 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 20 | return True 21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError('Boolean value expected.') 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model') 28 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3') 29 | parser.add_argument('--bdcn_model', type=str, default='/home/visiting/Projects/citrine/SeDraw/models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth') 30 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.') 31 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint') 32 | parser.add_argument('--imgpath', type=str, required=True) 33 | parser.add_argument('--first', type=str, required=True) 34 | parser.add_argument('--second', type=str, required=True) 35 | parser.add_argument('--gt', type=str, required=True) 36 | args = parser.parse_args() 37 | 38 | 39 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): 40 | 41 | 42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 43 | with open(path, 'rb') as f: 44 | img = Image.open(f) 45 | # Resize image if specified. 46 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img 47 | # Crop image if crop area specified. 48 | cropped_img = img.crop(cropArea) if (cropArea != None) else resized_img 49 | # Flip image horizontally if specified. 50 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img 51 | return flipped_img.convert('RGB') 52 | 53 | 54 | 55 | bdcn = bdcn.BDCN() 56 | bdcn.cuda() 57 | structure_gen = layers.StructureGen(feature_level=args.feature_level) 58 | structure_gen.cuda() 59 | detail_enhance = layers.DetailEnhance() 60 | detail_enhance.cuda() 61 | 62 | 63 | # Channel wise mean calculated on adobe240-fps training dataset 64 | mean = [0.5, 0.5, 0.5] 65 | std = [0.5, 0.5, 0.5] 66 | normalize = transforms.Normalize(mean=mean, 67 | std=std) 68 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 69 | 70 | negmean = [-1 for x in mean] 71 | restd = [2, 2, 2] 72 | revNormalize = transforms.Normalize(mean=negmean, std=restd) 73 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()]) 74 | 75 | 76 | def ToImage(frame0, frame1): 77 | 78 | with torch.no_grad(): 79 | 80 | img0 = frame0.cuda() 81 | img1 = frame1.cuda() 82 | 83 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1) 84 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1) 85 | ref_imgt, edge_ref_imgt = structure_gen((img0_e, img1_e)) 86 | imgt = detail_enhance((img0, img1, ref_imgt)) 87 | # imgt = detail_enhance((img0, img1, imgt)) 88 | imgt = torch.clamp(imgt, max=1., min=-1.) 
89 | 90 | return imgt, ref_imgt, edge_ref_imgt, img0_e[:, 3:, :, :].repeat(1, 3, 1, 1), img1_e[:, 3:, :, :].repeat(1, 3, 1, 1) 91 | 92 | 93 | def main(): 94 | # initial 95 | 96 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model))) 97 | dict1 = torch.load(args.checkpoint) 98 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False) 99 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False) 100 | 101 | bdcn.eval() 102 | structure_gen.eval() 103 | detail_enhance.eval() 104 | 105 | IE = 0 106 | PSNR = 0 107 | count = 0 108 | for folder in tqdm(os.listdir(args.imgpath)): 109 | triple_path = os.path.join(args.imgpath, folder) 110 | if not (os.path.isdir(triple_path)): 111 | continue 112 | X0 = transform(_pil_loader('%s/%s' % (triple_path, args.first))).unsqueeze(0) 113 | X1 = transform(_pil_loader('%s/%s' % (triple_path, args.second))).unsqueeze(0) 114 | 115 | assert (X0.size(2) == X1.size(2)) 116 | assert (X0.size(3) == X1.size(3)) 117 | 118 | intWidth = X0.size(3) 119 | intHeight = X0.size(2) 120 | channel = X0.size(1) 121 | if not channel == 3: 122 | print('Not RGB image') 123 | continue 124 | count += 1 125 | 126 | # if intWidth != ((intWidth >> 4) << 4): 127 | # intWidth_pad = (((intWidth >> 4) + 1) << 4) # more than necessary 128 | # intPaddingLeft = int((intWidth_pad - intWidth) / 2) 129 | # intPaddingRight = intWidth_pad - intWidth - intPaddingLeft 130 | # else: 131 | # intWidth_pad = intWidth 132 | # intPaddingLeft = 0 133 | # intPaddingRight = 0 134 | # 135 | # if intHeight != ((intHeight >> 4) << 4): 136 | # intHeight_pad = (((intHeight >> 4) + 1) << 4) # more than necessary 137 | # intPaddingTop = int((intHeight_pad - intHeight) / 2) 138 | # intPaddingBottom = intHeight_pad - intHeight - intPaddingTop 139 | # else: 140 | # intHeight_pad = intHeight 141 | # intPaddingTop = 0 142 | # intPaddingBottom = 0 143 | # 144 | # pader = torch.nn.ReflectionPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom]) 145 | 146 | # first, second = pader(X0), pader(X1) 147 | first, second = X0, X1 148 | imgt, ref_imgt, edge_ref_imgt, edge0, edge1 = ToImage(first, second) 149 | 150 | imgt_np = imgt.squeeze(0).cpu().numpy()#[:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 151 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 152 | ref_imgt_np = ref_imgt.squeeze( 153 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 154 | ref_imgt_png = np.uint8(((ref_imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 155 | edge_ref_imgt_np = edge_ref_imgt.squeeze( 156 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 157 | edge_ref_imgt_png = np.uint8(((edge_ref_imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 158 | edge0_np = edge0.squeeze( 159 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 160 | edge0_png = np.uint8(((edge0_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 161 | edge1_np = edge1.squeeze( 162 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 163 | edge1_png = np.uint8(((edge1_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 164 | if not os.path.isdir(triple_path): 165 | os.system('mkdir -p %s' % triple_path) 166 | cv2.imwrite(triple_path + '/SeDraw.png', imgt_png) 167 | cv2.imwrite(triple_path + '/SeDraw_ref.png', ref_imgt_png) 168 | 
cv2.imwrite(triple_path + '/SeDraw_edge_ref.png', edge_ref_imgt_png) 169 | cv2.imwrite(triple_path + '/SeDraw_edge0.png', edge0_png) 170 | cv2.imwrite(triple_path + '/SeDraw_edge1.png', edge1_png) 171 | 172 | rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png'))) 173 | gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt))) 174 | 175 | diff_rgb = rec_rgb - gt_rgb 176 | avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2)) 177 | 178 | mse = np.mean((diff_rgb) ** 2) 179 | 180 | PIXEL_MAX = 255.0 181 | psnr = compare_psnr(gt_rgb, rec_rgb, 255) 182 | print(folder, psnr) 183 | 184 | IE += avg_interp_error_abs 185 | PSNR += psnr 186 | 187 | # print(triple_path, ': IE/PSNR:', avg_interp_error_abs, psnr) 188 | 189 | IE = IE / count 190 | PSNR = PSNR / count 191 | print('Average IE/PSNR:', IE, PSNR) 192 | 193 | main() 194 | 195 | 196 | -------------------------------------------------------------------------------- /sequence_run.py: -------------------------------------------------------------------------------- 1 | # SeDraw 2 | import re 3 | import argparse 4 | import os 5 | import torch 6 | import cv2 7 | import torchvision.transforms as transforms 8 | from skimage.measure import compare_psnr 9 | from PIL import Image 10 | import src.pure_network as layers 11 | from tqdm import tqdm 12 | import numpy as np 13 | import math 14 | import models.bdcn.bdcn as bdcn 15 | 16 | # For parsing commandline arguments 17 | def str2bool(v): 18 | if isinstance(v, bool): 19 | return v 20 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 21 | return True 22 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError('Boolean value expected.') 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model') 29 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? in GEN, Default:3') 30 | parser.add_argument('--bdcn_model', type=str, default='./models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth') 31 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.') 32 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint') 33 | parser.add_argument('--video_path', type=str, required=True) 34 | parser.add_argument('--t_interp', type=int, default=4, help='times of interpolating') 35 | parser.add_argument('--fps', type=int, default=-1, help='specify the fps.') 36 | parser.add_argument('--slow_motion', action='store_true', help='using this flag if you want to slow down the video and maintain fps.') 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): 42 | 43 | 44 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 45 | with open(path, 'rb') as f: 46 | img = Image.open(f) 47 | # Resize image if specified. 48 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img 49 | # Crop image if crop area specified. 50 | cropped_img = img.crop(cropArea) if (cropArea != None) else resized_img 51 | # Flip image horizontally if specified. 
52 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img 53 | return flipped_img.convert('RGB') 54 | 55 | 56 | 57 | bdcn = bdcn.BDCN() 58 | bdcn.cuda() 59 | structure_gen = layers.StructureGen(feature_level=args.feature_level) 60 | structure_gen.cuda() 61 | detail_enhance = layers.DetailEnhance() 62 | detail_enhance.cuda() 63 | 64 | 65 | # Channel wise mean calculated on adobe240-fps training dataset 66 | mean = [0.5, 0.5, 0.5] 67 | std = [0.5, 0.5, 0.5] 68 | normalize = transforms.Normalize(mean=mean, 69 | std=std) 70 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 71 | 72 | negmean = [-1 for x in mean] 73 | restd = [2, 2, 2] 74 | revNormalize = transforms.Normalize(mean=negmean, std=restd) 75 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()]) 76 | 77 | 78 | def ToImage(frame0, frame1): 79 | 80 | with torch.no_grad(): 81 | 82 | img0 = frame0.cuda() 83 | img1 = frame1.cuda() 84 | 85 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1) 86 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1) 87 | ref_imgt, _ = structure_gen((img0_e, img1_e)) 88 | imgt = detail_enhance((img0, img1, ref_imgt)) 89 | # imgt = detail_enhance((img0, img1, imgt)) 90 | imgt = torch.clamp(imgt, max=1., min=-1.) 91 | 92 | return imgt 93 | def IndexHelper(i, digit): 94 | index = str(i) 95 | for j in range(digit-len(str(i))): 96 | index = '0'+index 97 | return index 98 | 99 | def VideoToSequence(path, time): 100 | video = cv2.VideoCapture(path) 101 | dir_path = 'frames_tmp' 102 | os.system("rm -rf %s" % dir_path) 103 | os.mkdir(dir_path) 104 | fps = int(video.get(cv2.CAP_PROP_FPS)) 105 | length = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 106 | print('making ' + str(length) + ' frame sequence in ' + dir_path) 107 | i = -1 108 | while (True): 109 | (grabbed, frame) = video.read() 110 | if not grabbed: 111 | break 112 | i = i + 1 113 | index = IndexHelper(i*time, len(str(time*length))) 114 | cv2.imwrite(dir_path + '/' + index + '.png', frame) 115 | # print(index) 116 | return [dir_path, length, fps] 117 | 118 | def main(): 119 | # initial 120 | iter = math.log(args.t_interp, int(2)) 121 | if iter%1: 122 | print('the times of interpolating must be power of 2!!') 123 | return 124 | iter = int(iter) 125 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model))) 126 | dict1 = torch.load(args.checkpoint) 127 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False) 128 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False) 129 | 130 | bdcn.eval() 131 | structure_gen.eval() 132 | detail_enhance.eval() 133 | 134 | IE = 0 135 | PSNR = 0 136 | count = 0 137 | [dir_path, frame_count, fps] = VideoToSequence(args.video_path, args.t_interp) 138 | 139 | for i in range(iter): 140 | print('processing iter' + str(i+1) + ', ' + str((i+1)*frame_count) + ' frames in total') 141 | filenames = os.listdir(dir_path) 142 | filenames.sort() 143 | for i in range(0, len(filenames) - 1): 144 | arguments_strFirst = os.path.join(dir_path, filenames[i]) 145 | arguments_strSecond = os.path.join(dir_path, filenames[i + 1]) 146 | index1 = int(re.sub("\D", "", filenames[i])) 147 | index2 = int(re.sub("\D", "", filenames[i + 1])) 148 | index = int((index1 + index2) / 2) 149 | arguments_strOut = os.path.join(dir_path, 150 | IndexHelper(index, len(str(args.t_interp * frame_count))) + ".png") 151 | 152 | # print(arguments_strFirst) 153 | # print(arguments_strSecond) 154 | # print(arguments_strOut) 155 | 156 | X0 = 
transform(_pil_loader(arguments_strFirst)).unsqueeze(0) 157 | X1 = transform(_pil_loader(arguments_strSecond)).unsqueeze(0) 158 | 159 | assert (X0.size(2) == X1.size(2)) 160 | assert (X0.size(3) == X1.size(3)) 161 | 162 | intWidth = X0.size(3) 163 | intHeight = X0.size(2) 164 | channel = X0.size(1) 165 | if not channel == 3: 166 | print('Not RGB image') 167 | continue 168 | count += 1 169 | 170 | # if intWidth != ((intWidth >> 4) << 4): 171 | # intWidth_pad = (((intWidth >> 4) + 1) << 4) # more than necessary 172 | # intPaddingLeft = int((intWidth_pad - intWidth) / 2) 173 | # intPaddingRight = intWidth_pad - intWidth - intPaddingLeft 174 | # else: 175 | # intWidth_pad = intWidth 176 | # intPaddingLeft = 0 177 | # intPaddingRight = 0 178 | # 179 | # if intHeight != ((intHeight >> 4) << 4): 180 | # intHeight_pad = (((intHeight >> 4) + 1) << 4) # more than necessary 181 | # intPaddingTop = int((intHeight_pad - intHeight) / 2) 182 | # intPaddingBottom = intHeight_pad - intHeight - intPaddingTop 183 | # else: 184 | # intHeight_pad = intHeight 185 | # intPaddingTop = 0 186 | # intPaddingBottom = 0 187 | # 188 | # pader = torch.nn.ReflectionPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom]) 189 | 190 | # first, second = pader(X0), pader(X1) 191 | first, second = X0, X1 192 | imgt = ToImage(first, second) 193 | 194 | imgt_np = imgt.squeeze( 195 | 0).cpu().numpy() # [:, intPaddingTop:intPaddingTop+intHeight, intPaddingLeft: intPaddingLeft+intWidth] 196 | imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 197 | cv2.imwrite(arguments_strOut, imgt_png) 198 | 199 | # rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png'))) 200 | # gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt))) 201 | 202 | # diff_rgb = rec_rgb - gt_rgb 203 | # avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2)) 204 | 205 | # mse = np.mean((diff_rgb) ** 2) 206 | 207 | # PIXEL_MAX = 255.0 208 | # psnr = compare_psnr(gt_rgb, rec_rgb, 255) 209 | # print(folder, psnr) 210 | 211 | # IE += avg_interp_error_abs 212 | # PSNR += psnr 213 | 214 | # print(triple_path, ': IE/PSNR:', avg_interp_error_abs, psnr) 215 | 216 | # IE = IE / count 217 | # PSNR = PSNR / count 218 | # print('Average IE/PSNR:', IE, PSNR) 219 | if args.fps != -1: 220 | output_fps = args.fps 221 | else: 222 | output_fps = fps if args.slow_motion else args.t_interp*fps 223 | os.system("ffmpeg -framerate " + str(output_fps) + " -pattern_type glob -i '" + dir_path + "/*.png' -pix_fmt yuv420p output.mp4") 224 | os.system("rm -rf %s" % dir_path) 225 | 226 | 227 | main() 228 | 229 | 230 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/pure_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CM-BF/FeatureFlow/deebc7e7bc5ea1ebb6dca7a8aa4f289649710d1b/src/__pycache__/pure_network.cpython-37.pyc -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import random 6 | 7 | 8 | def _make_dataset(dir): 9 | """ 10 | Creates a 2D list of all the frames in N clips containing 11 | M frames each. 12 | 13 | 2D List Structure: 14 | [[frame00, frame01,...frameM] <-- clip0 15 | [frame00, frame01,...frameM] <-- clip1 16 | : 17 | [frame00, frame01,...frameM]] <-- clipN 18 | 19 | Parameters 20 | ---------- 21 | dir : string 22 | root directory containing clips. 23 | 24 | Returns 25 | ------- 26 | list 27 | 2D list described above. 28 | """ 29 | 30 | 31 | framesPath = [] 32 | # Find and loop over all the clips in root `dir`. 33 | for index, folder in enumerate(os.listdir(dir)): 34 | clipsFolderPath = os.path.join(dir, folder) 35 | # Skip items which are not folders. 36 | if not (os.path.isdir(clipsFolderPath)): 37 | continue 38 | framesPath.append([]) 39 | # Find and loop over all the frames inside the clip. 40 | for image in sorted(os.listdir(clipsFolderPath)): 41 | # Add path to list. 42 | framesPath[index].append(os.path.join(clipsFolderPath, image)) 43 | return framesPath 44 | 45 | def _make_video_dataset(dir): 46 | """ 47 | Creates a 1D list of all the frames. 48 | 49 | 1D List Structure: 50 | [frame0, frame1,...frameN] 51 | 52 | Parameters 53 | ---------- 54 | dir : string 55 | root directory containing frames. 56 | 57 | Returns 58 | ------- 59 | list 60 | 1D list described above. 61 | """ 62 | 63 | 64 | framesPath = [] 65 | # Find and loop over all the frames in root `dir`. 66 | for image in sorted(os.listdir(dir)): 67 | # Add path to list. 68 | framesPath.append(os.path.join(dir, image)) 69 | return framesPath 70 | 71 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): 72 | """ 73 | Opens image at `path` using pil and applies data augmentation. 
74 | 75 | Parameters 76 | ---------- 77 | path : string 78 | path of the image. 79 | cropArea : tuple, optional 80 | coordinates for cropping image. Default: None 81 | resizeDim : tuple, optional 82 | dimensions for resizing image. Default: None 83 | frameFlip : int, optional 84 | Non zero to flip image horizontally. Default: 0 85 | 86 | Returns 87 | ------- 88 | list 89 | 2D list described above. 90 | """ 91 | 92 | 93 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 94 | with open(path, 'rb') as f: 95 | img = Image.open(f) 96 | # Resize image if specified. 97 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img 98 | # Crop image if crop area specified. 99 | cropped_img = img.crop(cropArea) if (cropArea != None) else resized_img 100 | # Flip image horizontally if specified. 101 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img 102 | return flipped_img.convert('RGB') 103 | 104 | 105 | class FeFlow_vimeo90k(data.Dataset): 106 | 107 | def __init__(self, root, transform=None, dim=(448, 256), randomCropSize=(208, 208), train=True, test=False): 108 | 109 | 110 | # Populate the list with image paths for all the 111 | # frame in `root`. 112 | if train: 113 | train_f = open(os.path.join(root, 'tri_trainlist.txt'), 'r') 114 | test_f = open(os.path.join(root, 'tri_testlist.txt'), 'r') 115 | if train: 116 | self.train_path = train_f.read().split('\n') 117 | self.train_path = [item for item in self.train_path if item != ''] 118 | test_val_path = test_f.read().split('\n') 119 | self.test_path = [item for index, item in enumerate(test_val_path) if item != ''] 120 | self.validation_path = [item for index, item in enumerate(test_val_path) if index % 100 == 0 and item != ''] 121 | 122 | self.randomCropSize = randomCropSize 123 | self.cropX0 = dim[0] - randomCropSize[0] 124 | self.cropY0 = dim[1] - randomCropSize[1] 125 | self.root = root 126 | self.transform = transform 127 | self.train = train 128 | self.test = test 129 | 130 | def __getitem__(self, index): 131 | """ 132 | Returns the sample corresponding to `index` from dataset. 133 | 134 | The sample consists of two reference frames - I0 and I1 - 135 | and a random frame chosen from the 7 intermediate frames 136 | available between I0 and I1 along with it's relative index. 137 | 138 | Parameters 139 | ---------- 140 | index : int 141 | Index 142 | 143 | Returns 144 | ------- 145 | tuple 146 | (sample, returnIndex) where sample is 147 | [I0, intermediate_frame, I1] and returnIndex is 148 | the position of `random_intermediate_frame`. 149 | e.g.- `returnIndex` of frame next to I0 would be 0 and 150 | frame before I1 would be 6. 151 | """ 152 | 153 | sample = [] 154 | 155 | if (self.train): 156 | ### Data Augmentation ### 157 | 158 | cropX = random.randint(0, self.cropX0) 159 | cropY = random.randint(0, self.cropY0) 160 | cropArea = (cropX, cropY, cropX + self.randomCropSize[0], cropY + self.randomCropSize[1]) 161 | # Random reverse frame 162 | if (random.randint(0, 1)): 163 | frameRange = [0, 1, 2] 164 | else: 165 | frameRange = [2, 1, 0] 166 | # Random flip frame 167 | randomFrameFlip = random.randint(0, 1) 168 | else: 169 | # Fixed settings to return same samples every epoch. 170 | # For validation/test sets. 171 | cropArea = (0, 0, self.randomCropSize[0], self.randomCropSize[1]) 172 | frameRange = [0, 1, 2] 173 | randomFrameFlip = 0 174 | 175 | # Loop over for all frames corresponding to the `index`. 
176 | for frameIndex in frameRange: 177 | # Open image using pil and augment the image. 178 | if self.train: 179 | image = _pil_loader( 180 | self.root + '/sequences/' + self.train_path[index] + '/im{}.png'.format(frameIndex + 1), 181 | cropArea=cropArea, frameFlip=randomFrameFlip) 182 | elif self.test: 183 | image = _pil_loader( 184 | self.root + '/sequences/' + self.test_path[index] + '/im{}.png'.format(frameIndex + 1), 185 | cropArea=cropArea, frameFlip=randomFrameFlip) 186 | else: 187 | image = _pil_loader( 188 | self.root + '/sequences/' + self.validation_path[index] + '/im{}.png'.format(frameIndex + 1), 189 | cropArea=cropArea, frameFlip=randomFrameFlip) 190 | # Apply transformation if specified. 191 | if self.transform is not None: 192 | image = self.transform(image) 193 | sample.append(image) 194 | 195 | return sample 196 | 197 | def __len__(self): 198 | """ 199 | Returns the size of dataset. Invoked as len(datasetObj). 200 | 201 | Returns 202 | ------- 203 | int 204 | number of samples. 205 | """ 206 | if self.train: 207 | return len(self.train_path) 208 | elif self.test: 209 | return len(self.test_path) 210 | else: 211 | return len(self.validation_path) 212 | 213 | def __repr__(self): 214 | """ 215 | Returns printable representation of the dataset object. 216 | 217 | Returns 218 | ------- 219 | string 220 | info. 221 | """ 222 | 223 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 224 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 225 | fmt_str += ' Root Location: {}\n'.format(self.root) 226 | tmp = ' Transforms (if any): ' 227 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 228 | return fmt_str 229 | 230 | 231 | class UCI101Test(data.Dataset): 232 | """ 233 | A dataloader for loading N samples arranged in this way: 234 | 235 | |-- clip0 236 | |-- frame00 237 | |-- frame01 238 | |-- frame02 239 | |-- clip1 240 | |-- frame00 241 | |-- frame01 242 | |-- frame02 243 | : 244 | : 245 | |-- clipN 246 | |-- frame00 247 | |-- frame01 248 | |-- frame02 249 | 250 | ... 251 | 252 | Attributes 253 | ---------- 254 | framesPath : list 255 | List of frames' path in the dataset. 256 | 257 | Methods 258 | ------- 259 | __getitem__(index) 260 | Returns the sample corresponding to `index` from dataset. 261 | __len__() 262 | Returns the size of dataset. Invoked as len(datasetObj). 263 | __repr__() 264 | Returns printable representation of the dataset object. 265 | """ 266 | 267 | def __init__(self, root, transform=None): 268 | """ 269 | Parameters 270 | ---------- 271 | root : string 272 | Root directory path. 273 | transform : callable, optional 274 | A function/transform that takes in 275 | a sample and returns a transformed version. 276 | E.g, ``transforms.RandomCrop`` for images. 277 | """ 278 | 279 | # Populate the list with image paths for all the 280 | # frame in `root`. 281 | framesPath = _make_dataset(root) 282 | # Raise error if no images found in root. 283 | if len(framesPath) == 0: 284 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n")) 285 | 286 | self.root = root 287 | self.framesPath = framesPath 288 | self.transform = transform 289 | 290 | def __getitem__(self, index): 291 | """ 292 | Returns the sample corresponding to `index` from dataset. 293 | 294 | The sample consists of two reference frames - I0 and I1 - 295 | and a intermediate frame between I0 and I1. 
296 | 297 | Parameters 298 | ---------- 299 | index : int 300 | Index 301 | 302 | Returns 303 | ------- 304 | tuple 305 | (sample, returnIndex) where sample is 306 | [I0, intermediate_frame, I1] and returnIndex is 307 | the position of `intermediate_frame`. 308 | The returnIndex is always 3 and is being returned 309 | to maintain compatibility with the `FeFlow` 310 | dataloader where 3 corresponds to the middle frame. 311 | """ 312 | 313 | sample = [] 314 | # Loop over for all frames corresponding to the `index`. 315 | for framePath in self.framesPath[index]: 316 | # Open image using pil. 317 | image = _pil_loader(framePath) 318 | # Apply transformation if specified. 319 | if self.transform is not None: 320 | image = self.transform(image) 321 | sample.append(image) 322 | return sample, 3 323 | 324 | def __len__(self): 325 | """ 326 | Returns the size of dataset. Invoked as len(datasetObj). 327 | 328 | Returns 329 | ------- 330 | int 331 | number of samples. 332 | """ 333 | 334 | return len(self.framesPath) 335 | 336 | def __repr__(self): 337 | """ 338 | Returns printable representation of the dataset object. 339 | 340 | Returns 341 | ------- 342 | string 343 | info. 344 | """ 345 | 346 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 347 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 348 | fmt_str += ' Root Location: {}\n'.format(self.root) 349 | tmp = ' Transforms (if any): ' 350 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 351 | return fmt_str 352 | 353 | 354 | class FeFlow(data.Dataset): 355 | """ 356 | A dataloader for loading N samples arranged in this way: 357 | 358 | |-- clip0 359 | |-- frame00 360 | |-- frame01 361 | : 362 | |-- frame11 363 | |-- frame12 364 | |-- clip1 365 | |-- frame00 366 | |-- frame01 367 | : 368 | |-- frame11 369 | |-- frame12 370 | : 371 | : 372 | |-- clipN 373 | |-- frame00 374 | |-- frame01 375 | : 376 | |-- frame11 377 | |-- frame12 378 | 379 | ... 380 | 381 | Attributes 382 | ---------- 383 | framesPath : list 384 | List of frames' path in the dataset. 385 | 386 | Methods 387 | ------- 388 | __getitem__(index) 389 | Returns the sample corresponding to `index` from dataset. 390 | __len__() 391 | Returns the size of dataset. Invoked as len(datasetObj). 392 | __repr__() 393 | Returns printable representation of the dataset object. 394 | """ 395 | 396 | 397 | def __init__(self, root, transform=None, dim=(448, 256), randomCropSize=(208, 208), train=True): 398 | """ 399 | Parameters 400 | ---------- 401 | root : string 402 | Root directory path. 403 | transform : callable, optional 404 | A function/transform that takes in 405 | a sample and returns a transformed version. 406 | E.g, ``transforms.RandomCrop`` for images. 407 | dim : tuple, optional 408 | Dimensions of images in dataset. Default: (640, 360) 409 | randomCropSize : tuple, optional 410 | Dimensions of random crop to be applied. Default: (352, 352) 411 | train : boolean, optional 412 | Specifies if the dataset is for training or testing/validation. 413 | `True` returns samples with data augmentation like random 414 | flipping, random cropping, etc. while `False` returns the 415 | samples without randomization. Default: True 416 | """ 417 | 418 | 419 | # Populate the list with image paths for all the 420 | # frame in `root`. 421 | framesPath = _make_dataset(root) 422 | # Raise error if no images found in root. 
423 | if len(framesPath) == 0: 424 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n")) 425 | 426 | self.randomCropSize = randomCropSize 427 | self.cropX0 = dim[0] - randomCropSize[0] 428 | self.cropY0 = dim[1] - randomCropSize[1] 429 | self.root = root 430 | self.transform = transform 431 | self.train = train 432 | 433 | self.framesPath = framesPath 434 | 435 | def __getitem__(self, index): 436 | """ 437 | Returns the sample corresponding to `index` from dataset. 438 | 439 | The sample consists of two reference frames - I0 and I1 - 440 | and a random frame chosen from the 7 intermediate frames 441 | available between I0 and I1 along with it's relative index. 442 | 443 | Parameters 444 | ---------- 445 | index : int 446 | Index 447 | 448 | Returns 449 | ------- 450 | tuple 451 | (sample, returnIndex) where sample is 452 | [I0, intermediate_frame, I1] and returnIndex is 453 | the position of `random_intermediate_frame`. 454 | e.g.- `returnIndex` of frame next to I0 would be 0 and 455 | frame before I1 would be 6. 456 | """ 457 | 458 | 459 | sample = [] 460 | 461 | if (self.train): 462 | ### Data Augmentation ### 463 | # To select random 9 frames from 12 frames in a clip 464 | firstFrame = 0 465 | # Apply random crop on the 9 input frames 466 | cropX = random.randint(0, self.cropX0) 467 | cropY = random.randint(0, self.cropY0) 468 | cropArea = (cropX, cropY, cropX + self.randomCropSize[0], cropY + self.randomCropSize[1]) 469 | # Random reverse frame 470 | #frameRange = range(firstFrame, firstFrame + 9) if (random.randint(0, 1)) else range(firstFrame + 8, firstFrame - 1, -1) 471 | # IFrameIndex = random.randint(firstFrame + 1, firstFrame + 7) 472 | IFrameIndex = 1 473 | if (random.randint(0, 1)): 474 | frameRange = [firstFrame, IFrameIndex, firstFrame + 2] 475 | returnIndex = 1 476 | else: 477 | frameRange = [firstFrame + 2, IFrameIndex, firstFrame] 478 | returnIndex = 1 479 | # Random flip frame 480 | randomFrameFlip = random.randint(0, 1) 481 | else: 482 | # Fixed settings to return same samples every epoch. 483 | # For validation/test sets. 484 | firstFrame = 0 485 | cropArea = (0, 0, self.randomCropSize[0], self.randomCropSize[1]) 486 | IFrameIndex = 1 487 | returnIndex = 1 488 | frameRange = [0, IFrameIndex, 2] 489 | randomFrameFlip = 0 490 | 491 | # Loop over for all frames corresponding to the `index`. 492 | for frameIndex in frameRange: 493 | # Open image using pil and augment the image. 494 | image = _pil_loader(self.framesPath[index][frameIndex], cropArea=cropArea, frameFlip=randomFrameFlip) 495 | # Apply transformation if specified. 496 | if self.transform is not None: 497 | image = self.transform(image) 498 | sample.append(image) 499 | 500 | return sample, returnIndex 501 | 502 | 503 | def __len__(self): 504 | """ 505 | Returns the size of dataset. Invoked as len(datasetObj). 506 | 507 | Returns 508 | ------- 509 | int 510 | number of samples. 511 | """ 512 | 513 | 514 | return len(self.framesPath) 515 | 516 | def __repr__(self): 517 | """ 518 | Returns printable representation of the dataset object. 519 | 520 | Returns 521 | ------- 522 | string 523 | info. 
524 | """ 525 | 526 | 527 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 528 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 529 | fmt_str += ' Root Location: {}\n'.format(self.root) 530 | tmp = ' Transforms (if any): ' 531 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 532 | return fmt_str 533 | 534 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts a Video to SeDraw version 3 | """ 4 | from time import time 5 | import click 6 | import cv2 7 | import torch 8 | from PIL import Image 9 | import numpy as np 10 | import src.model as model 11 | import src.layers as layers 12 | from torchvision import transforms 13 | from torch.functional import F 14 | 15 | 16 | torch.set_grad_enabled(False) 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | trans_forward = transforms.ToTensor() 20 | trans_backward = transforms.ToPILImage() 21 | if device != "cpu": 22 | mean = [0.429, 0.431, 0.397] 23 | mea0 = [-m for m in mean] 24 | std = [1] * 3 25 | trans_forward = transforms.Compose([trans_forward, transforms.Normalize(mean=mean, std=std)]) 26 | trans_backward = transforms.Compose([transforms.Normalize(mean=mea0, std=std), trans_backward]) 27 | 28 | network = layers.Network() 29 | network = torch.nn.DataParallel(network) 30 | 31 | 32 | 33 | 34 | def load_models(checkpoint): 35 | states = torch.load(checkpoint, map_location='cpu') 36 | network.module.load_state_dict(states['state_dictNET']) 37 | 38 | 39 | def interpolate_batch(frames, factor): 40 | frame0 = torch.stack(frames[:-1]) 41 | frame1 = torch.stack(frames[1:]) 42 | 43 | img0 = frame0.to(device) 44 | img1 = frame1.to(device) 45 | 46 | frame_buffer = [] 47 | for i in range(factor - 1): 48 | # start = time() 49 | frame_index = torch.ones(frames.__len__() - 1).type(torch.long) * i 50 | output = network((img0, img1, frame_index, None)) 51 | ft_p = output[3] 52 | 53 | frame_buffer.append(ft_p) 54 | # print('time:', time() - start) 55 | 56 | return frame_buffer 57 | 58 | 59 | def load_batch(video_in, batch_size, batch, w, h): 60 | if len(batch) > 0: 61 | batch = [batch[-1]] 62 | 63 | for i in range(batch_size): 64 | ok, frame = video_in.read() 65 | if not ok: 66 | break 67 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 68 | frame = Image.fromarray(frame) 69 | frame = frame.resize((w, h), Image.ANTIALIAS) 70 | frame = frame.convert('RGB') 71 | frame = trans_forward(frame) 72 | batch.append(frame) 73 | 74 | return batch 75 | 76 | 77 | def denorm_frame(frame, w0, h0): 78 | frame = frame.cpu() 79 | frame = trans_backward(frame) 80 | frame = frame.resize((w0, h0), Image.BILINEAR) 81 | frame = frame.convert('RGB') 82 | return np.array(frame)[:, :, ::-1].copy() 83 | 84 | 85 | def convert_video(source, dest, factor, batch_size=10, output_format='mp4v', output_fps=30): 86 | vin = cv2.VideoCapture(source) 87 | count = vin.get(cv2.CAP_PROP_FRAME_COUNT) 88 | w0, h0 = int(vin.get(cv2.CAP_PROP_FRAME_WIDTH)), int(vin.get(cv2.CAP_PROP_FRAME_HEIGHT)) 89 | 90 | codec = cv2.VideoWriter_fourcc(*output_format) 91 | vout = cv2.VideoWriter(dest, codec, float(output_fps), (w0, h0)) 92 | 93 | w, h = (w0 // 32) * 32, (h0 // 32) * 32 94 | network.module.setup('custom', h, w) 95 | 96 | network.module.setup_t(factor) 97 | network.cuda() 98 | 99 | done = 0 100 | batch = [] 101 | while True: 102 | batch = load_batch(vin, 
batch_size, batch, w, h) 103 | if len(batch) == 1: 104 | break 105 | done += len(batch) - 1 106 | 107 | intermediate_frames = interpolate_batch(batch, factor) 108 | intermediate_frames = list(zip(*intermediate_frames)) 109 | 110 | for fid, iframe in enumerate(intermediate_frames): 111 | vout.write(denorm_frame(batch[fid], w0, h0)) 112 | for frm in iframe: 113 | vout.write(denorm_frame(frm, w0, h0)) 114 | 115 | try: 116 | yield len(batch), done, count 117 | except StopIteration: 118 | break 119 | 120 | vout.write(denorm_frame(batch[0], w0, h0)) 121 | 122 | vin.release() 123 | vout.release() 124 | 125 | 126 | @click.command('Evaluate Model by converting a low-FPS video to high-fps') 127 | @click.argument('input') 128 | @click.option('--checkpoint', help='Path to model checkpoint') 129 | @click.option('--output', help='Path to output file to save') 130 | @click.option('--batch', default=4, help='Number of frames to process in single forward pass') 131 | @click.option('--scale', default=4, help='Scale Factor of FPS') 132 | @click.option('--fps', default=30, help='FPS of output video') 133 | def main(input, checkpoint, output, batch, scale, fps): 134 | avg = lambda x, n, x0: (x * n/(n+1) + x0 / (n+1), n+1) 135 | load_models(checkpoint) 136 | t0 = time() 137 | n0 = 0 138 | fpx = 0 139 | for dl, fd, fc in convert_video(input, output, int(scale), int(batch), output_fps=int(fps)): 140 | fpx, n0 = avg(fpx, n0, dl / (time() - t0)) 141 | prg = int(100*fd/fc) 142 | eta = (fc - fd) / fpx 143 | print('\rDone: {:03d}% FPS: {:05.2f} ETA: {:.2f}s'.format(prg, fpx, eta) + ' '*5, end='') 144 | t0 = time() 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | 150 | 151 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import src.model as model 2 | import src.loss as loss 3 | import torch.nn as nn 4 | import torch 5 | from math import log10 6 | from models.ResBlock import GlobalGenerator 7 | 8 | 9 | class StructureGen(nn.Module): 10 | 11 | def __init__(self, feature_level=3): 12 | super(StructureGen, self).__init__() 13 | 14 | self.feature_level = feature_level 15 | channel = 2 ** (6 + self.feature_level) 16 | # self.structure_extractor = model.StructureExtractor() 17 | self.extract_features = model.ExtractFeatures() 18 | self.dcn = model.DeformableConv(channel, dg=16) 19 | self.generator = GlobalGenerator(channel, 4, n_downsampling=self.feature_level) 20 | 21 | # Loss calculate 22 | self.L1_lossFn = nn.L1Loss() 23 | self.sL1_lossFn = nn.SmoothL1Loss() 24 | self.cL1_lossFn = loss.CharbonnierLoss() 25 | self.MSE_LossFn = nn.MSELoss() 26 | 27 | def forward(self, input): 28 | img0_e, img1_e, IFrame_e = input 29 | 30 | ft_img0 = list(self.extract_features(img0_e))[self.feature_level] 31 | ft_img1 = list(self.extract_features(img1_e))[self.feature_level] 32 | 33 | pre_gen_ft_imgt, out_x, out_y = self.dcn(ft_img0, ft_img1) 34 | 35 | ref_imgt_e = self.generator(pre_gen_ft_imgt) 36 | 37 | # divide edge 38 | IFrame, edge_IFrame = IFrame_e[:, :3], IFrame_e[:, 3:] 39 | ref_imgt, edge_ref_imgt = ref_imgt_e[:, :3], ref_imgt_e[:, 3:] 40 | 41 | # extract for loss 42 | ft_IFrame = list(self.extract_features(IFrame_e))[self.feature_level] 43 | ft_ref_imgt = list(self.extract_features(ref_imgt_e))[self.feature_level] 44 | # st_IFrame = self.structure_extractor(IFrame) 45 | # st_ref_imgt = self.structure_extractor(ref_imgt) 46 | 47 | # Loss calculate 48 | 49 | feature_mix_loss = 500 * 
self.cL1_lossFn(pre_gen_ft_imgt, ft_IFrame) 50 | 51 | tri_loss = 20 * (self.cL1_lossFn(out_x, ft_IFrame) + self.L1_lossFn(out_y, ft_IFrame)) 52 | 53 | # feature_gen_loss = 10 * self.MSE_LossFn(ft_ref_imgt, ft_IFrame) 54 | 55 | # structure_loss = 20 * (self.L1_lossFn(st_ref_imgt[0], st_IFrame[0]) + self.L1_lossFn(st_ref_imgt[1], st_IFrame[1]) + 56 | # self.L1_lossFn(st_ref_imgt[2], st_IFrame[2]) + self.L1_lossFn(st_ref_imgt[3], st_IFrame[3])) 57 | 58 | edge_loss = 5 * self.MSE_LossFn(edge_ref_imgt, edge_IFrame) 59 | 60 | gen_loss = 128 * self.cL1_lossFn(ref_imgt, IFrame) 61 | 62 | loss = gen_loss + feature_mix_loss + edge_loss + tri_loss# + structure_loss + feature_gen_loss 63 | 64 | MSE_val = self.MSE_LossFn(ref_imgt, IFrame) 65 | 66 | print('Loss:', loss.item(), 'feature_mix_loss:', feature_mix_loss.item(), 67 | 'gen_loss:', gen_loss.item(), 68 | # 'feature_gen_loss:', feature_gen_loss.item(), 69 | # 'structure_loss:', structure_loss.item(), 70 | 'edge_loss:', edge_loss.item(), 71 | 'tri_loss:', tri_loss.item()) 72 | 73 | return loss, MSE_val, ref_imgt 74 | 75 | 76 | 77 | class DetailEnhance(nn.Module): 78 | 79 | 80 | def __init__(self): 81 | 82 | super(DetailEnhance, self).__init__() 83 | 84 | self.feature_level = 3 85 | 86 | self.extract_features = model.ValidationFeatures() 87 | self.extract_aligned_features = model.ExtractAlignedFeatures(n_res=5) # 4 5 88 | self.pcd_align = model.PCD_Align(groups=8) # 4 8 89 | self.tsa_fusion = model.TSA_Fusion(nframes=3, center=1) 90 | 91 | self.reconstruct = model.Reconstruct(n_res=20) # 5 40 92 | 93 | # Loss calculate 94 | self.L1_lossFn = nn.L1Loss() 95 | self.sL1_lossFn = nn.SmoothL1Loss() 96 | self.cL1_lossFn = loss.CharbonnierLoss() 97 | self.MSE_LossFn = nn.MSELoss() 98 | 99 | def forward(self, input): 100 | """ 101 | Network forward tensor flow 102 | 103 | :param input: a tuple of input that will be unfolded 104 | :return: medium interpolation image 105 | """ 106 | img0, img1, IFrame, ref_imgt = input 107 | 108 | ref_align_ft = self.extract_aligned_features(ref_imgt) 109 | align_ft_0 = self.extract_aligned_features(img0) 110 | align_ft_1 = self.extract_aligned_features(img1) 111 | 112 | align_ft = [self.pcd_align(align_ft_0, ref_align_ft), 113 | self.pcd_align(ref_align_ft, ref_align_ft), 114 | self.pcd_align(align_ft_1, ref_align_ft)] 115 | align_ft = torch.stack(align_ft, dim=1) 116 | 117 | tsa_ft = self.tsa_fusion(align_ft) 118 | 119 | imgt = self.reconstruct(tsa_ft, ref_imgt) 120 | 121 | 122 | # extract for loss 123 | ft_IFrame = list(self.extract_features(IFrame))[self.feature_level] 124 | ft_imgt = list(self.extract_features(imgt))[self.feature_level] 125 | 126 | 127 | """ 128 | ---------------------------------------------------------------------- 129 | ====================================================================== 130 | """ 131 | 132 | # Loss calculate 133 | # feature_recn_loss = 10 * self.MSE_LossFn(ft_imgt, ft_IFrame) 134 | 135 | recn_loss = 128 * self.cL1_lossFn(imgt, IFrame) 136 | ie = 128 * self.L1_lossFn(imgt, IFrame) 137 | 138 | loss = recn_loss #+ feature_recn_loss 139 | 140 | MSE_val = self.MSE_LossFn(imgt, IFrame) 141 | psnr = (10 * log10(4 / MSE_val.item())) 142 | 143 | print('Loss:', loss.item(), 'psnr:', psnr, 144 | 'recn_loss:', recn_loss.item(), 145 | # 'feature_recn_loss:', feature_recn_loss.item() 146 | ) 147 | 148 | return loss, MSE_val, ie, imgt 149 | 150 | class DetailEnhance_last(nn.Module): 151 | 152 | 153 | def __init__(self): 154 | 155 | super(DetailEnhance_last, self).__init__() 156 | 157 | 
self.feature_level = 3 158 | 159 | self.extract_features = model.ValidationFeatures() 160 | self.extract_aligned_features = model.ExtractAlignedFeatures(n_res=5) # 4 5 161 | self.pcd_align = model.PCD_Align(groups=8) # 4 8 162 | self.tsa_fusion = model.TSA_Fusion(nframes=3, center=1) 163 | 164 | self.reconstruct = model.Reconstruct(n_res=20) # 5 40 165 | 166 | # Loss calculate 167 | self.L1_lossFn = nn.L1Loss() 168 | self.sL1_lossFn = nn.SmoothL1Loss() 169 | self.cL1_lossFn = loss.CharbonnierLoss() 170 | self.MSE_LossFn = nn.MSELoss() 171 | 172 | def forward(self, input): 173 | """ 174 | Network forward tensor flow 175 | 176 | :param input: a tuple of input that will be unfolded 177 | :return: medium interpolation image 178 | """ 179 | img0, img1, IFrame, ref_imgt = input 180 | 181 | ref_align_ft = self.extract_aligned_features(ref_imgt) 182 | align_ft_0 = self.extract_aligned_features(img0) 183 | align_ft_1 = self.extract_aligned_features(img1) 184 | 185 | align_ft = [self.pcd_align(align_ft_0, ref_align_ft), 186 | self.pcd_align(ref_align_ft, ref_align_ft), 187 | self.pcd_align(align_ft_1, ref_align_ft)] 188 | align_ft = torch.stack(align_ft, dim=1) 189 | 190 | tsa_ft = self.tsa_fusion(align_ft) 191 | 192 | imgt = self.reconstruct(tsa_ft, ref_imgt) 193 | 194 | 195 | # extract for loss 196 | # ft_IFrame = list(self.extract_features(IFrame))[self.feature_level] 197 | # ft_imgt = list(self.extract_features(imgt))[self.feature_level] 198 | 199 | 200 | """ 201 | ---------------------------------------------------------------------- 202 | ====================================================================== 203 | """ 204 | 205 | # Loss calculate 206 | # feature_recn_loss = 10 * self.MSE_LossFn(ft_imgt, ft_IFrame) 207 | 208 | recn_loss = 128 * self.L1_lossFn(imgt, IFrame) 209 | 210 | loss = recn_loss# + feature_recn_loss 211 | 212 | MSE_val = self.MSE_LossFn(imgt, IFrame) 213 | psnr = (10 * log10(4 / MSE_val.item())) 214 | 215 | print('Loss:', loss.item(), 'psnr:', psnr, 216 | 'recn_loss:', recn_loss.item()) 217 | # 'feature_recn_loss:', feature_recn_loss.item()) 218 | 219 | return loss, MSE_val, imgt 220 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.mean(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | -------------------------------------------------------------------------------- /src/pure_network.py: -------------------------------------------------------------------------------- 1 | import src.model as model 2 | import torch.nn as nn 3 | import torch 4 | from models.ResBlock import GlobalGenerator 5 | 6 | 7 | class StructureGen(nn.Module): 8 | 9 | def __init__(self, feature_level=3): 10 | super(StructureGen, self).__init__() 11 | 12 | self.feature_level = feature_level 13 | channel = 2 ** (6 + self.feature_level) 14 | self.extract_features = model.ExtractFeatures() 15 | self.dcn = model.DeformableConv(channel, dg=16) 16 | self.generator = GlobalGenerator(channel, 4, n_downsampling=self.feature_level) 17 | 18 | def forward(self, input): 19 | img0_e, img1_e= input 20 | 21 | ft_img0 = list(self.extract_features(img0_e))[self.feature_level] 22 | ft_img1 = 
list(self.extract_features(img1_e))[self.feature_level] 23 | 24 | pre_gen_ft_imgt, out_x, out_y = self.dcn(ft_img0, ft_img1) 25 | 26 | ref_imgt_e = self.generator(pre_gen_ft_imgt) 27 | 28 | # divide edge 29 | ref_imgt, edge_ref_imgt = ref_imgt_e[:, :3], ref_imgt_e[:, 3:] 30 | 31 | return ref_imgt, edge_ref_imgt 32 | 33 | 34 | 35 | class DetailEnhance(nn.Module): 36 | 37 | 38 | def __init__(self): 39 | 40 | super(DetailEnhance, self).__init__() 41 | 42 | self.feature_level = 3 43 | 44 | self.extract_features = model.ValidationFeatures() 45 | self.extract_aligned_features = model.ExtractAlignedFeatures(n_res=5) # 4 5 46 | self.pcd_align = model.PCD_Align(groups=8) # 4 8 47 | self.tsa_fusion = model.TSA_Fusion(nframes=3, center=1) 48 | 49 | self.reconstruct = model.Reconstruct(n_res=20) # 5 40 50 | 51 | def forward(self, input): 52 | """ 53 | Network forward tensor flow 54 | 55 | :param input: a tuple of input that will be unfolded 56 | :return: medium interpolation image 57 | """ 58 | img0, img1, ref_imgt = input 59 | 60 | ref_align_ft = self.extract_aligned_features(ref_imgt) 61 | align_ft_0 = self.extract_aligned_features(img0) 62 | align_ft_1 = self.extract_aligned_features(img1) 63 | 64 | align_ft = [self.pcd_align(align_ft_0, ref_align_ft), 65 | self.pcd_align(ref_align_ft, ref_align_ft), 66 | self.pcd_align(align_ft_1, ref_align_ft)] 67 | align_ft = torch.stack(align_ft, dim=1) 68 | 69 | tsa_ft = self.tsa_fusion(align_ft) 70 | 71 | imgt = self.reconstruct(tsa_ft, ref_imgt) 72 | 73 | return imgt 74 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import interpolate 2 | 3 | 4 | def feature_transform(img, n_upsample): 5 | """ 6 | transform img like feature to be visualized 7 | :param img: [B, C, H, W] 8 | :return: visualized img range from 0 to 1 9 | """ 10 | img = img[0, 1:2].repeat(3, 1, 1) 11 | img = interpolate(((img - img.min()) / (img.max() - img.min())).unsqueeze(0), scale_factor=n_upsample, 12 | align_corners=False, mode='bilinear') 13 | 14 | return img 15 | 16 | def edge_transform(img): 17 | """ 18 | transform img like feature to be visualized 19 | :param img: [B, C, H, W] 20 | :return: visualized img range from 0 to 1 21 | """ 22 | img = img[0].repeat(3, 1, 1) 23 | img = ((img - img.min()) / (img.max() - img.min())).unsqueeze(0) 24 | 25 | return img 26 | -------------------------------------------------------------------------------- /video_process.py: -------------------------------------------------------------------------------- 1 | # SeDraw 2 | import argparse 3 | import os 4 | import torch 5 | import cv2 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | import src.pure_network as layers 9 | from tqdm import tqdm 10 | import numpy as np 11 | import math 12 | import models.bdcn.bdcn as bdcn 13 | 14 | # For parsing commandline arguments 15 | def str2bool(v): 16 | if isinstance(v, bool): 17 | return v 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model') 27 | parser.add_argument('--feature_level', type=int, default=3, help='Using feature_level=? 
in GEN, Default:3') 28 | parser.add_argument('--bdcn_model', type=str, default='./models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth') 29 | parser.add_argument('--DE_pretrained', action='store_true', help='using this flag if training the model from pretrained parameters.') 30 | parser.add_argument('--DE_ckpt', type=str, help='path to DE checkpoint') 31 | parser.add_argument('--video_name', type=str, required=True, help='the path the video.') 32 | parser.add_argument('--batchsize', type=int, default=4) 33 | parser.add_argument('--fix_range', action='store_true', help="it won't change the fps without this flag.") 34 | args = parser.parse_args() 35 | 36 | 37 | def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): 38 | 39 | 40 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 41 | with open(path, 'rb') as f: 42 | img = Image.open(f) 43 | # Resize image if specified. 44 | resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img 45 | # Crop image if crop area specified. 46 | cropped_img = img.crop(cropArea) if (cropArea != None) else resized_img 47 | # Flip image horizontally if specified. 48 | flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img 49 | return flipped_img.convert('RGB') 50 | 51 | 52 | 53 | bdcn = bdcn.BDCN() 54 | bdcn.cuda() 55 | structure_gen = layers.StructureGen(feature_level=args.feature_level) 56 | structure_gen.cuda() 57 | detail_enhance = layers.DetailEnhance() 58 | detail_enhance.cuda() 59 | 60 | 61 | # Channel wise mean calculated on adobe240-fps training dataset 62 | mean = [0.5, 0.5, 0.5] 63 | std = [0.5, 0.5, 0.5] 64 | normalize = transforms.Normalize(mean=mean, 65 | std=std) 66 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 67 | 68 | negmean = [-1 for x in mean] 69 | restd = [2, 2, 2] 70 | revNormalize = transforms.Normalize(mean=negmean, std=restd) 71 | TP = transforms.Compose([revNormalize, transforms.ToPILImage()]) 72 | 73 | 74 | def ToImage(frame0, frame1): 75 | 76 | with torch.no_grad(): 77 | 78 | img0 = frame0.cuda() 79 | img1 = frame1.cuda() 80 | 81 | img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1) 82 | img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1) 83 | ref_imgt, _ = structure_gen((img0_e, img1_e)) 84 | imgt = detail_enhance((img0, img1, ref_imgt)) 85 | # imgt = detail_enhance((img0, img1, imgt)) 86 | imgt = torch.clamp(imgt, max=1., min=-1.) 
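        # imgt stays in the normalized [-1, 1] range implied by Normalize(mean=0.5, std=0.5);
        # main() below maps it back to 8-bit pixels via (imgt + 1.0) / 2.0 * 255 before writing frames.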
87 | 88 | return imgt 89 | 90 | 91 | def main(): 92 | # initial 93 | 94 | bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model))) 95 | dict1 = torch.load(args.checkpoint) 96 | structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False) 97 | detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False) 98 | 99 | bdcn.eval() 100 | structure_gen.eval() 101 | detail_enhance.eval() 102 | 103 | if not os.path.isfile(args.video_name): 104 | print('video not exist!') 105 | video = cv2.VideoCapture(args.video_name) 106 | if args.fix_range: 107 | fps = video.get(cv2.CAP_PROP_FPS) * 2 108 | else: 109 | # fps = video.get(cv2.CAP_PROP_FPS) 110 | fps = 25 111 | size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))) 112 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 113 | # fourcc = int(video.get(cv2.CAP_PROP_FOURCC)) 114 | video_writer = cv2.VideoWriter(args.video_name[:-4] + '_Sedraw.mp4', 115 | fourcc, 116 | fps, 117 | size) 118 | 119 | flag = True 120 | frame_group = [] 121 | while video.isOpened(): 122 | for i in range(args.batchsize): 123 | ret, frame = video.read() 124 | if ret: 125 | frame = torch.FloatTensor(frame[:, :, ::-1].transpose(2, 0, 1).copy()) / 255 126 | frame = normalize(frame).unsqueeze(0) 127 | frame_group += [frame] 128 | else: 129 | break 130 | if len(frame_group) <= 1: 131 | break 132 | first = torch.cat(frame_group[:-1], dim=0) 133 | second = torch.cat(frame_group[1:], dim=0) 134 | 135 | middle_frame = ToImage(first, second) 136 | 137 | if flag: 138 | for i in range(first.shape[0]): 139 | first_np = first[i].cpu().numpy() 140 | first_png = np.uint8(((first_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 141 | middle_frame_np = middle_frame[i].cpu().numpy() 142 | middle_frame_png = np.uint8(((middle_frame_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 143 | video_writer.write(first_png) 144 | video_writer.write(middle_frame_png) 145 | second_np = second[-1].cpu().numpy() 146 | second_png = np.uint8(((second_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 147 | video_writer.write(second_png) 148 | frame_group = [second[-1].unsqueeze(0)] 149 | flag = False 150 | else: 151 | for i in range(second.shape[0]): 152 | middle_frame_np = middle_frame[i].cpu().numpy() 153 | middle_frame_png = np.uint8(((middle_frame_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 154 | second_np = second[i].cpu().numpy() 155 | second_png = np.uint8(((second_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255) 156 | video_writer.write(middle_frame_png) 157 | video_writer.write(second_png) 158 | frame_group = [second[-1].unsqueeze(0)] 159 | 160 | video_writer.release() 161 | 162 | 163 | main() 164 | 165 | 166 | --------------------------------------------------------------------------------
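The end-to-end use of these modules is easiest to see in `video_process.py` above. The sketch below condenses its `ToImage` pipeline into a standalone function. It is a minimal illustration, not part of the repository: it assumes CUDA is available, that the BDCN and FeatureFlow checkpoints are loaded exactly as in `video_process.main()`, and the helper names (`edge_net`, `to_tensor`, `interpolate_middle`) are illustrative only.

```python
# Minimal sketch of the two-stage FeatureFlow pipeline used by video_process.py.
# Assumptions: CUDA is available and the pre-trained weights are loaded as in video_process.main();
# the function and variable names here are illustrative, not the repository's API.
import torch
import torchvision.transforms as transforms
import models.bdcn.bdcn as bdcn_module
import src.pure_network as layers

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

edge_net = bdcn_module.BDCN().cuda().eval()
structure_gen = layers.StructureGen(feature_level=3).cuda().eval()
detail_enhance = layers.DetailEnhance().cuda().eval()
# ... load 'bdcn_pretrained_on_bsds500.pth' and the checkpoint's
# 'state_dictGEN' / 'state_dictDE' entries here, as done in video_process.main() ...


def to_tensor(bgr_frame):
    """Convert an OpenCV BGR uint8 frame to a normalized [1, 3, H, W] tensor in [-1, 1]."""
    rgb = torch.FloatTensor(bgr_frame[:, :, ::-1].transpose(2, 0, 1).copy()) / 255
    return normalize(rgb).unsqueeze(0)


def interpolate_middle(frame0, frame1):
    """Return the interpolated middle frame for two normalized [B, 3, H, W] inputs."""
    with torch.no_grad():
        img0, img1 = frame0.cuda(), frame1.cuda()
        # Stage 1: append a BDCN edge map as a fourth channel and generate a coarse middle frame.
        img0_e = torch.cat([img0, torch.tanh(edge_net(img0)[0])], dim=1)
        img1_e = torch.cat([img1, torch.tanh(edge_net(img1)[0])], dim=1)
        ref_imgt, _ = structure_gen((img0_e, img1_e))
        # Stage 2: refine texture details around the coarse estimate.
        imgt = detail_enhance((img0, img1, ref_imgt))
        return torch.clamp(imgt, min=-1., max=1.)
```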