├── .gitignore
├── LICENSE
├── README.md
├── codes
│   ├── data
│   │   ├── Vimeo7_dataset.py
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   └── util.py
│   ├── data_scripts
│   │   ├── create_lmdb_mp.py
│   │   ├── generate_LR_Vimeo90K.m
│   │   ├── generate_mod_LR_bic.py
│   │   └── sep_vimeo_list.py
│   ├── models
│   │   ├── VideoSR_base_model.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   ├── modules
│   │   │   ├── DCNv2
│   │   │   │   ├── .gitignore
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── dcn_v2.py
│   │   │   │   ├── make.sh
│   │   │   │   ├── setup.py
│   │   │   │   ├── src
│   │   │   │   │   ├── cpu
│   │   │   │   │   │   ├── dcn_v2_cpu.cpp
│   │   │   │   │   │   └── vision.h
│   │   │   │   │   ├── cuda
│   │   │   │   │   │   ├── dcn_v2_cuda.cu
│   │   │   │   │   │   ├── dcn_v2_im2col_cuda.cu
│   │   │   │   │   │   ├── dcn_v2_im2col_cuda.h
│   │   │   │   │   │   ├── dcn_v2_psroi_pooling_cuda.cu
│   │   │   │   │   │   └── vision.h
│   │   │   │   │   ├── dcn_v2.h
│   │   │   │   │   └── vision.cpp
│   │   │   │   └── test.py
│   │   │   ├── Sakuya_arch.py
│   │   │   ├── __init__.py
│   │   │   ├── convlstm.py
│   │   │   ├── loss.py
│   │   │   └── module_util.py
│   │   └── networks.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── options.py
│   │   └── train
│   │       └── train_zsm.yml
│   ├── test.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── make_video.py
│   │   └── util.py
│   ├── video_to_zsm.py
│   └── zsm_my_video.sh
├── datasets
│   ├── README.md
│   └── meta_info
│       ├── Vimeo7_train_keys.pkl
│       ├── fast_testset.txt
│       ├── medium_testset.txt
│       └── slow_testset.txt
├── dump
│   ├── .gitignore
│   ├── 4539-teaser.gif
│   ├── demo720.gif
│   ├── demo_thumbnail.PNG
│   └── framework.png
├── experiments
│   └── pretrained_models
│       ├── readme.md
│       └── xiang2020zooming.pth
├── requirements.txt
└── test_example
    └── 0625
        ├── im1.png
        ├── im2.png
        ├── im3.png
        ├── im4.png
        ├── im5.png
        ├── im6.png
        └── im7.png
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | test_example/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # dotenv
86 | .env
87 |
88 | # virtualenv
89 | .venv
90 | venv/
91 | ENV/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # other
107 | .vscode
108 | data/
109 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Zooming-Slow-Mo (CVPR-2020)
2 |
3 | By [Xiaoyu Xiang\*](https://engineering.purdue.edu/people/xiaoyu.xiang.1), [Yapeng Tian\*](http://yapengtian.org/), [Yulun Zhang](http://yulunzhang.com/), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Jan P. Allebach+](https://engineering.purdue.edu/~allebach/), [Chenliang Xu+](https://www.cs.rochester.edu/~cxu22/) (\* equal contributions, + equal advising)
4 |
5 | This is the official Pytorch implementation of _Zooming Slow-Mo: Fast and Accurate One-Stage Space-Time Video Super-Resolution_.
6 |
7 | #### [Paper](https://arxiv.org/abs/2002.11616) | [Journal Version](https://arxiv.org/abs/2104.07473) | [Demo (YouTube)](https://youtu.be/8mgD8JxBOus) | [1-min teaser (YouTube)](https://www.youtube.com/watch?v=C1o85AXUNl8) | [1-min teaser (Bilibili)](https://www.bilibili.com/video/BV1GK4y1t7nb/)
8 |
9 |
10 | *(Side-by-side demo GIFs: low-frame-rate, low-resolution input vs. high-resolution slow-motion output.)*
11 |
26 | ## Updates
27 |
28 | - 2020.3.13: Add meta-info of datasets used in this paper
29 | - 2020.3.11: Add new function: video converter
30 | - 2020.3.10: Upload the complete code and pretrained models
31 |
32 | ## Contents
33 |
34 | 0. [Introduction](#introduction)
35 | 1. [Prerequisites](#Prerequisites)
36 | 2. [Get Started](#Get-Started)
37 | - [Installation](#Installation)
38 | - [Training](#Training)
39 | - [Testing](#Testing)
40 | - [Colab Notebook](#Colab-Notebook)
41 | 3. [Citations](#citations)
42 | 4. [Contact](#Contact)
43 | 5. [License](#License)
44 | 6. [Acknowledgments](#Acknowledgments)
45 |
46 | ## Introduction
47 |
48 | The repository contains the entire project (including all the preprocessing) for one-stage space-time video super-resolution with Zooming Slow-Mo.
49 |
50 | Zooming Slow-Mo is a recently proposed joint video frame interpolation (VFI) and video super-resolution (VSR) method, which directly synthesizes a high-resolution (HR) slow-motion video from a low-frame-rate (LFR), low-resolution (LR) video. It is published in [CVPR 2020](http://cvpr2020.thecvf.com/). The most up-to-date paper with supplementary materials can be found on [arXiv](https://arxiv.org/abs/2002.11616).
51 |
52 | In Zooming Slow-Mo, we first temporally interpolate features of the missing LR frames with the proposed feature temporal interpolation network. Then, we propose a deformable ConvLSTM to align and aggregate temporal information simultaneously. Finally, a deep reconstruction network is adopted to predict the HR slow-motion video frames. If our proposed architectures also help your research, please consider citing our paper.
53 |
54 | Zooming Slow-Mo achieves state-of-the-art performance in terms of PSNR and SSIM on the Vid4 and Vimeo test sets.
55 |
56 | 
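To make the dataflow concrete, here is a toy, self-contained PyTorch sketch of that one-stage pipeline (per-frame feature extraction, feature temporal interpolation between neighboring frames, temporal aggregation, and pixel-shuffle reconstruction). Everything in it is an illustrative placeholder: the `ToySpaceTimeSR` name and the plain convolutions stand in for the deformable feature interpolation and the deformable ConvLSTM; the real modules live in `codes/models/modules/` (e.g. `Sakuya_arch.py`, `convlstm.py`).

```python
import torch
import torch.nn as nn


class ToySpaceTimeSR(nn.Module):
    """Illustrative stand-in for the one-stage pipeline (NOT the real Sakuya_arch)."""

    def __init__(self, nf=64, scale=4):
        super().__init__()
        self.extract = nn.Conv2d(3, nf, 3, 1, 1)       # per-frame feature extraction
        self.interp = nn.Conv2d(2 * nf, nf, 3, 1, 1)   # placeholder for feature temporal interpolation
        self.aggregate = nn.Conv2d(nf, nf, 3, 1, 1)    # placeholder for the deformable ConvLSTM
        self.reconstruct = nn.Sequential(              # reconstruction + pixel-shuffle upsampling
            nn.Conv2d(nf, 3 * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale),
        )

    def forward(self, lr_frames):
        # lr_frames: (B, N, 3, h, w) low-frame-rate, low-resolution input
        n = lr_frames.size(1)
        feats = [self.extract(lr_frames[:, i]) for i in range(n)]
        # synthesize a feature between every pair of neighboring input frames
        dense = []
        for i in range(n - 1):
            dense.append(feats[i])
            dense.append(self.interp(torch.cat([feats[i], feats[i + 1]], dim=1)))
        dense.append(feats[-1])
        # aggregate and reconstruct every (original + interpolated) feature into an HR frame
        hr = [self.reconstruct(self.aggregate(f)) for f in dense]
        return torch.stack(hr, dim=1)  # (B, 2N-1, 3, h*scale, w*scale)


if __name__ == "__main__":
    lr = torch.randn(1, 4, 3, 32, 56)        # 4 LR frames at 32x56
    print(ToySpaceTimeSR()(lr).shape)        # torch.Size([1, 7, 3, 128, 224])
```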
57 |
58 | ## Prerequisites
59 |
60 | - Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
61 | - [PyTorch >= 1.1](https://pytorch.org/)
62 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
63 | - [Deformable Convolution v2](https://arxiv.org/abs/1811.11168); we adopt [CharlesShang's implementation](https://github.com/CharlesShang/DCNv2) as a submodule.
64 | - Python packages: `pip install numpy opencv-python lmdb pyyaml pickle5 matplotlib seaborn`
65 |
66 | ## Get Started
67 |
68 | ### Installation
69 |
70 | Install the required packages: `pip install -r requirements.txt`
71 |
72 | Make sure your machine has an NVIDIA GPU, which is required to build and run the DCNv2 module.
73 |
74 | 1. Clone the Zooming Slow-Mo repository. We'll refer to the directory that you cloned Zooming Slow-Mo into as `ZOOMING_ROOT`.
75 |
76 | ```Shell
77 | git clone --recursive https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020.git
78 | ```
79 |
80 | 2. Compile the DCNv2:
81 |
82 | ```Shell
83 | cd $ZOOMING_ROOT/codes/models/modules/DCNv2
84 | bash make.sh # build
85 | python test.py # run examples and gradient check
86 | ```
87 |
88 | Please make sure the test script finishes successfully without any errors before running the following experiments.
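If `bash make.sh` does not work in your shell, building the extension through `setup.py` should be equivalent; in CharlesShang's DCNv2 the script is essentially a thin wrapper around this command (check your copy of `make.sh` to confirm):

```Shell
cd $ZOOMING_ROOT/codes/models/modules/DCNv2
python setup.py build develop   # what make.sh wraps in the upstream DCNv2
```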
89 |
90 | ### Training
91 |
92 | #### Part 1: Data Preparation
93 |
94 | 1. Download the original training + test set of `Vimeo-septuplet` (82 GB).
95 |
96 | ```Shell
97 | wget http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip
98 | apt-get install unzip
99 | unzip vimeo_septuplet.zip
100 | ```
101 |
102 | 2. Split the `Vimeo-septuplet` into a training set and a test set. Make sure you change the dataset path in the script to your download path, and note that you need to run it separately for the training set and the test set:
103 |
104 | ```Shell
105 | python $ZOOMING_ROOT/codes/data_scripts/sep_vimeo_list.py
106 | ```
107 |
108 | This will create `train` and `test` folders inside **`vimeo_septuplet/sequences`**. The folder structure is as follows:
109 |
110 | ```
111 | vimeo_septuplet
112 | ├── sequences
113 | │   ├── 00001
114 | │   │   ├── 0266
115 | │   │   │   ├── im1.png
116 | │   │   │   ├── ...
117 | │   │   │   └── im7.png
118 | │   │   └── 0268...
119 | │   └── 00002...
120 | ├── readme.txt
121 | ├── sep_trainlist.txt
122 | └── sep_testlist.txt
123 | ```
124 |
125 | 3. Generate low-resolution (LR) images. You can do this with either MATLAB or Python (remember to configure the input and output paths; a small driver sketch for the Python route follows the commands below):
126 |
127 | ```Matlab
128 | % In the MATLAB command window
129 | run $ZOOMING_ROOT/codes/data_scripts/generate_LR_Vimeo90K.m
130 | ```
131 |
132 | ```Shell
133 | python $ZOOMING_ROOT/codes/data_scripts/generate_mod_LR_bic.py
134 | ```
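The Python script's `__main__` block calls `generate_mod_LR_bic(4, 'inPath', 'outPath')` with placeholder paths, and the function only processes the `.png` files directly inside `sourcedir`. Below is a minimal driver sketch for looping over the Vimeo clip folders; it assumes you run it from `$ZOOMING_ROOT/codes`, and the paths are examples to adapt to your own layout (whichever location the LR frames end up in, point `create_lmdb_mp.py` at it in the next step):

```python
# Illustrative driver (not part of the repo); paths below are examples only.
import glob
import os

from data_scripts.generate_mod_LR_bic import generate_mod_LR_bic

src_root = '/data/datasets/SR/vimeo_septuplet/sequences/train'         # HR frames after the split
dst_root = '/data/datasets/SR/vimeo_septuplet/sequences_mod_LR/train'  # where HR/LR/Bic outputs go

# generate_mod_LR_bic() only reads .png files directly inside `sourcedir`,
# so call it once per clip folder and mirror the clip structure under dst_root.
for clip in sorted(glob.glob(os.path.join(src_root, '*', '*'))):
    rel = os.path.relpath(clip, src_root)    # e.g. 00001/0266
    out = os.path.join(dst_root, rel)
    os.makedirs(out, exist_ok=True)          # the script itself uses os.mkdir only
    generate_mod_LR_bic(4, clip, out)        # writes out/HR/x4, out/LR/x4, out/Bic/x4
```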
135 |
136 | 4. Create the LMDB files for faster I/O. Note that you need to configure the input and output paths in the following script (a quick sanity check for the generated LMDB is shown after the folder structure below):
137 |
138 | ```Shell
139 | python $ZOOMING_ROOT/codes/data_scripts/create_lmdb_mp.py
140 | ```
141 |
142 | The structure of the generated LMDB folder is as follows:
143 |
144 | ```
145 | Vimeo7_train.lmdb
146 | ├── data.mdb
147 | ├── lock.mdb
148 | ├── meta_info.txt
149 | ```
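Inside the LMDB, each frame is stored as a raw uint8 BGR image under a key of the form `<folder>_<subfolder>_<frame index>` (e.g. `00001_0001_4`), which is how `codes/data/util.py` reads it back during training. A minimal sanity check, assuming the default GT LMDB path and the 256x448 GT resolution used in `create_lmdb_mp.py` (adapt both to your setup):

```python
# Quick LMDB sanity check (path, key and resolution are examples from create_lmdb_mp.py).
import cv2
import lmdb
import numpy as np

env = lmdb.open('/data/datasets/SR/vimeo_septuplet/vimeo7_train_GT.lmdb',
                readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    buf = txn.get('00001_0001_4'.encode('ascii'))   # <folder>_<subfolder>_<frame index>
img = np.frombuffer(buf, dtype=np.uint8).reshape(256, 448, 3)  # GT frames are stored as 3x256x448
cv2.imwrite('lmdb_check.png', img)                  # should look like the original im4.png
```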
150 |
151 | #### Part 2: Train
152 |
153 | **Note:** In this part, we assume you are in the directory **`$ZOOMING_ROOT/codes/`**
154 |
155 | 1. Configure your training settings that can be found at [options/train](./codes/options/train). Our training settings in the paper can be found at [train_zsm.yml](https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/blob/master/codes/options/train/train_zsm.yml). We'll take this setting as an example to illustrate the following steps.
156 |
157 | 2. Train the Zooming Slow-Mo model.
158 |
159 | ```Shell
160 | python train.py -opt options/train/train_zsm.yml
161 | ```
162 |
163 | After training, your model `xxxx_G.pth`, its training states, and a corresponding log file `train_LunaTokis_scratch_b16p32f5b40n7l1_600k_Vimeo_xxxx.log` are placed in `$ZOOMING_ROOT/experiments/LunaTokis_scratch_b16p32f5b40n7l1_600k_Vimeo/`.
164 |
165 | ### Testing
166 |
167 | We provide the test code for both standard test sets (Vid4, SPMC, etc.) and custom video frames.
168 |
169 | #### Pretrained Models
170 |
171 | Our pretrained model can be downloaded via [GitHub](https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/blob/master/experiments/pretrained_models/xiang2020zooming.pth) or [Google Drive](https://drive.google.com/open?id=1xeOoZclGeSI1urY6mVCcApfCqOPgxMBK).
172 |
173 | #### From Video
174 |
175 | If you have installed ffmpeg, you can convert any video to a high-resolution and high frame-rate video using [video_to_zsm.py](./codes/video_to_zsm.py). The corresponding commands are:
176 |
177 | ```Shell
178 | cd $ZOOMING_ROOT/codes
179 | python video_to_zsm.py --video PATH/TO/VIDEO.mp4 --model PATH/TO/PRETRAINED/MODEL.pth --output PATH/TO/OUTPUT.mp4
180 | ```
181 |
182 | We also provide the above commands as a shell script, so you can directly run:
183 |
184 | ```Shell
185 | bash zsm_my_video.sh
186 | ```
187 |
188 | #### From Extracted Frames
189 |
190 | As a quick start, we also provide some example images in the [test_example](./test_example) folder. You can test the model with the following commands:
191 |
192 | ```Shell
193 | cd $ZOOMING_ROOT/codes
194 | python test.py
195 | ```
196 |
197 | - You can put your own test folders in [test_example](./test_example) too (following the layout shown below), or just change the input path, the number of frames, etc. in [test.py](codes/test.py).
198 |
199 | - Your custom test results will be saved to a folder here: `$ZOOMING_ROOT/results/your_data_name/`.
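Each test folder is simply a directory of consecutively numbered frames; the bundled example has the following layout, and your own folders can follow the same pattern:

```
test_example
└── 0625
    ├── im1.png
    ├── ...
    └── im7.png
```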
200 |
201 | #### Evaluate on Standard Test Sets
202 |
203 | The [test.py](codes/test.py) script also provides modes for evaluation on the following test sets: `Vid4`, `SPMC`, etc. We evaluate PSNR and SSIM on the Y channel in YCbCr color space. The commands are the same as the ones above; all you need to do is change the `data_mode` and the corresponding path of the standard test set.
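For reference, here is a minimal sketch of Y-channel PSNR between two BGR `uint8` frames, using the same BT.601 luma coefficients as `bgr2ycbcr(..., only_y=True)` in `codes/data/util.py`. It is illustrative only; the exact evaluation (border cropping, SSIM, etc.) follows [test.py](codes/test.py):

```python
# Illustrative Y-channel PSNR for two uint8 BGR images (e.g. loaded with cv2.imread).
import numpy as np

def bgr_to_y(img):
    """BT.601 luma in [16, 235], matching bgr2ycbcr(..., only_y=True) on uint8 input."""
    return np.dot(img.astype(np.float64), [24.966, 128.553, 65.481]) / 255.0 + 16.0

def psnr_y(img1, img2):
    mse = np.mean((bgr_to_y(img1) - bgr_to_y(img2)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))
```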
204 |
205 | ### Colab Notebook
206 |
207 | PyTorch Colab notebook (provided by [@HanClinto](https://github.com/HanClinto)): [HighResSlowMo.ipynb](https://gist.github.com/HanClinto/49219942f76d5f20990b6d048dbacbaf)
208 |
209 | ## Citations
210 |
211 | If you find the code helpful in your research or work, please cite the following papers.
212 |
213 | ```BibTex
214 | @misc{xiang2021zooming,
215 | title={Zooming SlowMo: An Efficient One-Stage Framework for Space-Time Video Super-Resolution},
216 | author={Xiang, Xiaoyu and Tian, Yapeng and Zhang, Yulun and Fu, Yun and Allebach, Jan P and Xu, Chenliang},
217 | archivePrefix={arXiv},
218 | eprint={2104.07473},
219 | year={2021},
220 | primaryClass={cs.CV}
221 | }
222 |
223 | @InProceedings{xiang2020zooming,
224 | author = {Xiang, Xiaoyu and Tian, Yapeng and Zhang, Yulun and Fu, Yun and Allebach, Jan P. and Xu, Chenliang},
225 | title = {Zooming Slow-Mo: Fast and Accurate One-Stage Space-Time Video Super-Resolution},
226 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
227 | pages={3370--3379},
228 | month = {June},
229 | year = {2020}
230 | }
231 |
232 | @InProceedings{tian2018tdan,
233 | author = {Tian, Yapeng and Zhang, Yulun and Fu, Yun and Xu, Chenliang},
234 | title={TDAN: Temporally Deformable Alignment Network for Video Super-Resolution},
235 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
236 | pages={3360--3369},
237 | month = {June},
238 | year = {2020}
239 | }
240 |
241 | @InProceedings{wang2019edvr,
242 | author = {Wang, Xintao and Chan, Kelvin C.K. and Yu, Ke and Dong, Chao and Loy, Chen Change},
243 | title = {EDVR: Video restoration with enhanced deformable convolutional networks},
244 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)},
245 | month = {June},
246 | year = {2019},
247 | }
248 | ```
249 |
250 | ## Contact
251 |
252 | [Xiaoyu Xiang](https://engineering.purdue.edu/people/xiaoyu.xiang.1) and [Yapeng Tian](http://yapengtian.org/).
253 |
254 | You can also leave your questions as issues in the repository. We will be glad to answer them.
255 |
256 | ## License
257 |
258 | This project is released under the [GNU General Public License v3.0](https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/blob/master/LICENSE).
259 |
260 | ## Acknowledgments
261 |
262 | Our code is inspired by [TDAN-VSR](https://github.com/YapengTian/TDAN-VSR) and [EDVR](https://github.com/xinntao/EDVR).
263 |
--------------------------------------------------------------------------------
/codes/data/Vimeo7_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Vimeo7 dataset
3 | support reading images from lmdb, image folder and memcached
4 | '''
5 | import os.path as osp
6 | import random
7 | import pickle
8 | import logging
9 | import numpy as np
10 | import cv2
11 | import lmdb
12 | import torch
13 | import torch.utils.data as data
14 | import data.util as util
15 | try:
16 | import mc # import memcached
17 | except ImportError:
18 | pass
19 |
20 | logger = logging.getLogger('base')
21 |
22 |
23 | class Vimeo7Dataset(data.Dataset):
24 | '''
25 | Reading the training Vimeo dataset
26 | key example: train/00001/0001/im1.png
27 | GT: Ground-Truth;
28 | LQ: Low-Quality, e.g., low-resolution frames
29 | support reading N HR frames, N = 3, 5, 7
30 | '''
31 |
32 | def __init__(self, opt):
33 | super(Vimeo7Dataset, self).__init__()
34 | self.opt = opt
35 | # temporal augmentation
36 | self.interval_list = opt['interval_list']
37 | self.random_reverse = opt['random_reverse']
38 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
39 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
40 | self.half_N_frames = opt['N_frames'] // 2
41 | self.LR_N_frames = 1 + self.half_N_frames
42 | assert self.LR_N_frames > 1, 'Error: Not enough LR frames to interpolate'
43 | # determine the LQ frame list
44 | '''
45 | N | frames
46 | 1 | error
47 | 3 | 0,2
48 | 5 | 0,2,4
49 | 7 | 0,2,4,6
50 | '''
51 | self.LR_index_list = []
52 | for i in range(self.LR_N_frames):
53 | self.LR_index_list.append(i*2)
54 |
55 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
56 | self.data_type = self.opt['data_type']
57 | # low resolution inputs
58 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True
59 | # directly load image keys
60 | if opt['cache_keys']:
61 | logger.info('Using cache keys: {}'.format(opt['cache_keys']))
62 | cache_keys = opt['cache_keys']
63 | else:
64 | cache_keys = 'Vimeo7_train_keys.pkl'
65 | logger.info('Using cache keys - {}.'.format(cache_keys))
66 | self.paths_GT = pickle.load(open('./data/{}'.format(cache_keys), 'rb'))
67 |
68 | assert self.paths_GT, 'Error: GT path is empty.'
69 |
70 | if self.data_type == 'lmdb':
71 | self.GT_env, self.LQ_env = None, None
72 | elif self.data_type == 'mc': # memcached
73 | self.mclient = None
74 | elif self.data_type == 'img':
75 | pass
76 | else:
77 | raise ValueError('Wrong data type: {}'.format(self.data_type))
78 |
79 | def _init_lmdb(self):
80 | # https://github.com/chainer/chainermn/issues/129
81 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
82 | meminit=False)
83 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
84 | meminit=False)
85 |
86 | def _ensure_memcached(self):
87 | if self.mclient is None:
88 | # specify the config files
89 | server_list_config_file = None
90 | client_config_file = None
91 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
92 | client_config_file)
93 |
94 | def _read_img_mc(self, path):
95 | ''' Return BGR, HWC, [0, 255], uint8'''
96 | value = mc.pyvector()
97 | self.mclient.Get(path, value)
98 | value_buf = mc.ConvertBuffer(value)
99 | img_array = np.frombuffer(value_buf, np.uint8)
100 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
101 | return img
102 |
103 | def _read_img_mc_BGR(self, path, name_a, name_b):
104 | ''' Read BGR channels separately and then combine for 1M limits in cluster'''
105 | img_B = self._read_img_mc(
106 | osp.join(path + '_B', name_a, name_b + '.png'))
107 | img_G = self._read_img_mc(
108 | osp.join(path + '_G', name_a, name_b + '.png'))
109 | img_R = self._read_img_mc(
110 | osp.join(path + '_R', name_a, name_b + '.png'))
111 | img = cv2.merge((img_B, img_G, img_R))
112 | return img
113 |
114 | def __getitem__(self, index):
115 | if self.data_type == 'mc':
116 | self._ensure_memcached()
117 | elif self.data_type == 'lmdb':
118 | if (self.GT_env is None) or (self.LQ_env is None):
119 | self._init_lmdb()
120 |
121 | scale = self.opt['scale']
122 | # print(scale)
123 | N_frames = self.opt['N_frames']
124 | GT_size = self.opt['GT_size']
125 | key = self.paths_GT['keys'][index]
126 | name_a, name_b = key.split('_')
127 |
128 | center_frame_idx = random.randint(2, 6) # 2<= index <=6
129 |
130 | # determine the neighbor frames
131 | interval = random.choice(self.interval_list)
132 | if self.opt['border_mode']:
133 | direction = 1 # 1: forward; 0: backward
134 | if self.random_reverse and random.random() < 0.5:
135 | direction = random.choice([0, 1])
136 | if center_frame_idx + interval * (N_frames - 1) > 7:
137 | direction = 0
138 | elif center_frame_idx - interval * (N_frames - 1) < 1:
139 | direction = 1
140 | # get the neighbor list
141 | if direction == 1:
142 | neighbor_list = list(
143 | range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
144 | else:
145 | neighbor_list = list(
146 | range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
147 | else:
148 | # ensure not exceeding the borders
149 | while (center_frame_idx + self.half_N_frames * interval >
150 | 7) or (center_frame_idx - self.half_N_frames * interval < 1):
151 | center_frame_idx = random.randint(2, 6)
152 | # get the neighbor list
153 | neighbor_list = list(
154 | range(center_frame_idx - self.half_N_frames * interval,
155 | center_frame_idx + self.half_N_frames * interval + 1, interval))
156 | if self.random_reverse and random.random() < 0.5:
157 | neighbor_list.reverse()
158 |
159 | self.LQ_frames_list = []
160 | for i in self.LR_index_list:
161 | self.LQ_frames_list.append(neighbor_list[i])
162 |
163 | assert len(
164 | neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
165 | len(neighbor_list))
166 |
167 | # get the GT image (as the center frame)
168 | img_GT_l = []
169 | for v in neighbor_list:
170 | if self.data_type == 'mc':
171 | img_GT = self._read_img_mc_BGR(
172 | self.GT_root, name_a, '{}/im{}'.format(name_b, v))  # _read_img_mc_BGR takes 3 args and appends '.png' itself; folding the frame into the key is an assumed layout
173 | img_GT = img_GT.astype(np.float32) / 255.
174 | elif self.data_type == 'lmdb':
175 | img_GT = util.read_img(
176 | self.GT_env, key + '_{}'.format(v), (3, 256, 448))
177 | else:
178 | img_GT = util.read_img(None, osp.join(
179 | self.GT_root, name_a, name_b, 'im{}.png'.format(v)))
180 | img_GT_l.append(img_GT)
181 |
182 | # get LQ images
183 | LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
184 | img_LQ_l = []
185 | for v in self.LQ_frames_list:
186 | if self.data_type == 'mc':
187 | img_LQ = self._read_img_mc(
188 | osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))  # no leading '/', which would make osp.join discard the preceding components
189 | img_LQ = img_LQ.astype(np.float32) / 255.
190 | elif self.data_type == 'lmdb':
191 | img_LQ = util.read_img(
192 | self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
193 | else:
194 | img_LQ = util.read_img(None,
195 | osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
196 | img_LQ_l.append(img_LQ)
197 |
198 | if self.opt['phase'] == 'train':
199 | C, H, W = LQ_size_tuple # LQ size
200 | # randomly crop
201 | if self.LR_input:
202 | LQ_size = GT_size // scale
203 | rnd_h = random.randint(0, max(0, H - LQ_size))
204 | rnd_w = random.randint(0, max(0, W - LQ_size))
205 | img_LQ_l = [v[rnd_h:rnd_h + LQ_size,
206 | rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
207 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
208 | img_GT_l = [v[rnd_h_HR:rnd_h_HR + GT_size,
209 | rnd_w_HR:rnd_w_HR + GT_size, :] for v in img_GT_l]
210 | else:
211 | rnd_h = random.randint(0, max(0, H - GT_size))
212 | rnd_w = random.randint(0, max(0, W - GT_size))
213 | img_LQ_l = [v[rnd_h:rnd_h + GT_size,
214 | rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
215 | img_GT_l = [v[rnd_h:rnd_h + GT_size,
216 | rnd_w:rnd_w + GT_size, :] for v in img_GT_l]
217 |
218 | # augmentation - flip, rotate
219 | img_LQ_l = img_LQ_l + img_GT_l
220 | rlt = util.augment(
221 | img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
222 | img_LQ_l = rlt[0:-N_frames]
223 | img_GT_l = rlt[-N_frames:]
224 |
225 | # stack LQ images to NHWC, N is the frame number
226 | img_LQs = np.stack(img_LQ_l, axis=0)
227 | img_GTs = np.stack(img_GT_l, axis=0)
228 | # BGR to RGB, HWC to CHW, numpy to tensor
229 | img_GTs = img_GTs[:, :, :, [2, 1, 0]]
230 | img_LQs = img_LQs[:, :, :, [2, 1, 0]]
231 | img_GTs = torch.from_numpy(np.ascontiguousarray(
232 | np.transpose(img_GTs, (0, 3, 1, 2)))).float()
233 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
234 | (0, 3, 1, 2)))).float()
235 | return {'LQs': img_LQs, 'GT': img_GTs, 'key': key}
236 |
237 | def __len__(self):
238 | return len(self.paths_GT['keys'])
239 |
--------------------------------------------------------------------------------
/codes/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, opt, sampler):
8 | phase = dataset_opt['phase']
9 | if phase == 'train':
10 | if opt['dist']:
11 | world_size = torch.distributed.get_world_size()
12 | num_workers = dataset_opt['n_workers']
13 | assert dataset_opt['batch_size'] % world_size == 0
14 | batch_size = dataset_opt['batch_size'] // world_size
15 | shuffle = False
16 | else:
17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
18 | batch_size = dataset_opt['batch_size']
19 | shuffle = True
20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
21 | num_workers=num_workers, sampler=sampler, drop_last=True,
22 | pin_memory=False)
23 | else:
24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
25 | pin_memory=True)
26 |
27 |
28 | def create_dataset(dataset_opt):
29 | mode = dataset_opt['mode']
30 | if mode == 'Vimeo7':
31 | from data.Vimeo7_dataset import Vimeo7Dataset as D
32 | else:
33 | raise NotImplementedError(
34 | 'Dataset [{:s}] is not recognized.'.format(mode))
35 | dataset = D(dataset_opt)
36 |
37 | logger = logging.getLogger('base')
38 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
39 | dataset_opt['name']))
40 | return dataset
41 |
--------------------------------------------------------------------------------
/codes/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler
3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the
4 | dataloader after each epoch
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError(
34 | "Requires distributed package to be available")
35 | num_replicas = dist.get_world_size()
36 | if rank is None:
37 | if not dist.is_available():
38 | raise RuntimeError(
39 | "Requires distributed package to be available")
40 | rank = dist.get_rank()
41 | self.dataset = dataset
42 | self.num_replicas = num_replicas
43 | self.rank = rank
44 | self.epoch = 0
45 | self.num_samples = int(
46 | math.ceil(len(self.dataset) * ratio / self.num_replicas))
47 | self.total_size = self.num_samples * self.num_replicas
48 |
49 | def __iter__(self):
50 | # deterministically shuffle based on epoch
51 | g = torch.Generator()
52 | g.manual_seed(self.epoch)
53 | indices = torch.randperm(self.total_size, generator=g).tolist()
54 |
55 | dsize = len(self.dataset)
56 | indices = [v % dsize for v in indices]
57 |
58 | # subsample
59 | indices = indices[self.rank:self.total_size:self.num_replicas]
60 | assert len(indices) == self.num_samples
61 |
62 | return iter(indices)
63 |
64 | def __len__(self):
65 | return self.num_samples
66 |
67 | def set_epoch(self, epoch):
68 | self.epoch = epoch
69 |
--------------------------------------------------------------------------------
/codes/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | import numpy as np
5 | import cv2
6 | import math
7 | import torch
8 |
9 | ####################
10 | # Files & IO
11 | ####################
12 |
13 | ###################### get image path list ######################
14 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
16 |
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 |
22 | def _get_paths_from_images(path):
23 | '''get image path list from image folder'''
24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
25 | images = []
26 | for dirpath, _, fnames in sorted(os.walk(path)):
27 | for fname in sorted(fnames):
28 | if is_image_file(fname):
29 | img_path = os.path.join(dirpath, fname)
30 | images.append(img_path)
31 | assert images, '{:s} has no valid image file'.format(path)
32 | return images
33 |
34 |
35 | def _get_paths_from_lmdb(dataroot):
36 | '''get image path list from lmdb meta info'''
37 | meta_info = pickle.load(
38 | open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
39 | paths = meta_info['keys']
40 | sizes = meta_info['resolution']
41 | if len(sizes) == 1:
42 | sizes = sizes * len(paths)
43 | return paths, sizes
44 |
45 |
46 | def get_image_paths(data_type, dataroot):
47 | '''get image path list
48 | support lmdb or image files'''
49 | paths, sizes = None, None
50 | if dataroot is not None:
51 | if data_type == 'lmdb':
52 | paths, sizes = _get_paths_from_lmdb(dataroot)
53 | elif data_type == 'img':
54 | paths = sorted(_get_paths_from_images(dataroot))
55 | else:
56 | raise NotImplementedError(
57 | 'data_type [{:s}] is not recognized.'.format(data_type))
58 | return paths, sizes
59 |
60 |
61 | ###################### read images ######################
62 | def _read_img_lmdb(env, key, size):
63 | '''read image from lmdb with key (w/ and w/o fixed size)
64 | size: (C, H, W) tuple'''
65 | with env.begin(write=False) as txn:
66 | buf = txn.get(key.encode('ascii'))
67 | img_flat = np.frombuffer(buf, dtype=np.uint8)
68 | C, H, W = size
69 | img = img_flat.reshape(H, W, C)
70 | return img
71 |
72 |
73 | def read_img(env, path, size=None):
74 | '''read image by cv2 or from lmdb
75 | return: Numpy float32, HWC, BGR, [0,1]'''
76 | if env is None: # img
77 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
78 | else:
79 | img = _read_img_lmdb(env, path, size)
80 | img = img.astype(np.float32) / 255.
81 | if img.ndim == 2:
82 | img = np.expand_dims(img, axis=2)
83 | # some images have 4 channels
84 | if img.shape[2] > 3:
85 | img = img[:, :, :3]
86 | return img
87 |
88 |
89 | ####################
90 | # image processing
91 | # process on numpy image
92 | ####################
93 |
94 |
95 | def augment(img_list, hflip=True, rot=True):
96 | # random horizontal flip, vertical flip and 90-degree rotation, each applied with probability 0.5
97 | hflip = hflip and random.random() < 0.5
98 | vflip = rot and random.random() < 0.5
99 | rot90 = rot and random.random() < 0.5
100 |
101 | def _augment(img):
102 | if hflip:
103 | img = img[:, ::-1, :]
104 | if vflip:
105 | img = img[::-1, :, :]
106 | if rot90:
107 | img = img.transpose(1, 0, 2)
108 | return img
109 |
110 | return [_augment(img) for img in img_list]
111 |
112 |
113 | def channel_convert(in_c, tar_type, img_list):
114 | # conversion among BGR, gray and y
115 | if in_c == 3 and tar_type == 'gray': # BGR to gray
116 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
117 | return [np.expand_dims(img, axis=2) for img in gray_list]
118 | elif in_c == 3 and tar_type == 'y': # BGR to y
119 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
120 | return [np.expand_dims(img, axis=2) for img in y_list]
121 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
122 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
123 | else:
124 | return img_list
125 |
126 |
127 | def rgb2ycbcr(img, only_y=True):
128 | '''same as matlab rgb2ycbcr
129 | only_y: only return Y channel
130 | Input:
131 | uint8, [0, 255]
132 | float, [0, 1]
133 | '''
134 | in_img_type = img.dtype
135 | img = img.astype(np.float32)
136 | if in_img_type != np.uint8:
137 | img *= 255.
138 | # convert
139 | if only_y:
140 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
141 | else:
142 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
143 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
144 | if in_img_type == np.uint8:
145 | rlt = rlt.round()
146 | else:
147 | rlt /= 255.
148 | return rlt.astype(in_img_type)
149 |
150 |
151 | def bgr2ycbcr(img, only_y=True):
152 | '''bgr version of rgb2ycbcr
153 | only_y: only return Y channel
154 | Input:
155 | uint8, [0, 255]
156 | float, [0, 1]
157 | '''
158 | in_img_type = img.dtype
159 | img = img.astype(np.float32)
160 | if in_img_type != np.uint8:
161 | img *= 255.
162 | # convert
163 | if only_y:
164 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
165 | else:
166 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
167 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
168 | if in_img_type == np.uint8:
169 | rlt = rlt.round()
170 | else:
171 | rlt /= 255.
172 | return rlt.astype(in_img_type)
173 |
174 |
175 | def ycbcr2rgb(img):
176 | '''same as matlab ycbcr2rgb
177 | Input:
178 | uint8, [0, 255]
179 | float, [0, 1]
180 | '''
181 | in_img_type = img.dtype
182 | img = img.astype(np.float32)
183 | if in_img_type != np.uint8:
184 | img *= 255.
185 | # convert
186 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
187 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
188 | if in_img_type == np.uint8:
189 | rlt = rlt.round()
190 | else:
191 | rlt /= 255.
192 | return rlt.astype(in_img_type)
193 |
194 |
195 | def modcrop(img_in, scale):
196 | # img_in: Numpy, HWC or HW
197 | img = np.copy(img_in)
198 | if img.ndim == 2:
199 | H, W = img.shape
200 | H_r, W_r = H % scale, W % scale
201 | img = img[:H - H_r, :W - W_r]
202 | elif img.ndim == 3:
203 | H, W, C = img.shape
204 | H_r, W_r = H % scale, W % scale
205 | img = img[:H - H_r, :W - W_r, :]
206 | else:
207 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
208 | return img
209 |
210 |
211 | def cubic(x):
212 | absx = torch.abs(x)
213 | absx2 = absx**2
214 | absx3 = absx**3
215 | return (1.5 * absx3 - 2.5 * absx2 + 1) * (
216 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
217 | (absx > 1) * (absx <= 2)).type_as(absx))
218 |
219 |
220 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
221 | if (scale < 1) and (antialiasing):
222 | # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
223 | kernel_width = kernel_width / scale
224 |
225 | # Output-space coordinates
226 | x = torch.linspace(1, out_length, out_length)
227 |
228 | # Input-space coordinates. Calculate the inverse mapping such that 0.5
229 | # in output space maps to 0.5 in input space, and 0.5+scale in output
230 | # space maps to 1.5 in input space.
231 | u = x / scale + 0.5 * (1 - 1 / scale)
232 |
233 | # What is the left-most pixel that can be involved in the computation?
234 | left = torch.floor(u - kernel_width / 2)
235 |
236 | # What is the maximum number of pixels that can be involved in the
237 | # computation? Note: it's OK to use an extra pixel here; if the
238 | # corresponding weights are all zero, it will be eliminated at the end
239 | # of this function.
240 | P = math.ceil(kernel_width) + 2
241 |
242 | # The indices of the input pixels involved in computing the k-th output
243 | # pixel are in row k of the indices matrix.
244 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
245 | 1, P).expand(out_length, P)
246 |
247 | # The weights used to compute the k-th output pixel are in row k of the
248 | # weights matrix.
249 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
250 | # apply cubic kernel
251 | if (scale < 1) and (antialiasing):
252 | weights = scale * cubic(distance_to_center * scale)
253 | else:
254 | weights = cubic(distance_to_center)
255 | # Normalize the weights matrix so that each row sums to 1.
256 | weights_sum = torch.sum(weights, 1).view(out_length, 1)
257 | weights = weights / weights_sum.expand(out_length, P)
258 |
259 | # If a column in weights is all zero, get rid of it. only consider the first and last column.
260 | weights_zero_tmp = torch.sum((weights == 0), 0)
261 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
262 | indices = indices.narrow(1, 1, P - 2)
263 | weights = weights.narrow(1, 1, P - 2)
264 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
265 | indices = indices.narrow(1, 0, P - 2)
266 | weights = weights.narrow(1, 0, P - 2)
267 | weights = weights.contiguous()
268 | indices = indices.contiguous()
269 | sym_len_s = -indices.min() + 1
270 | sym_len_e = indices.max() - in_length
271 | indices = indices + sym_len_s - 1
272 | return weights, indices, int(sym_len_s), int(sym_len_e)
273 |
274 |
275 | def imresize_np(img, scale, antialiasing=True):
276 | # Now the scale should be the same for H and W
277 | # input: img: Numpy, HWC BGR [0,1]
278 | # output: HWC BGR [0,1] w/o round
279 | img = torch.from_numpy(img)
280 |
281 | in_H, in_W, in_C = img.size()
282 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
283 | kernel_width = 4
284 | kernel = 'cubic'
285 |
286 | # Return the desired dimension order for performing the resize. The
287 | # strategy is to perform the resize first along the dimension with the
288 | # smallest scale factor.
289 | # Now we do not support this.
290 |
291 | # get weights and indices
292 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
293 | in_H, out_H, scale, kernel, kernel_width, antialiasing)
294 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
295 | in_W, out_W, scale, kernel, kernel_width, antialiasing)
296 | # process H dimension
297 | # symmetric copying
298 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
299 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
300 |
301 | sym_patch = img[:sym_len_Hs, :, :]
302 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
303 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
304 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
305 |
306 | sym_patch = img[-sym_len_He:, :, :]
307 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
308 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
309 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
310 |
311 | out_1 = torch.FloatTensor(out_H, in_W, in_C)
312 | kernel_width = weights_H.size(1)
313 | for i in range(out_H):
314 | idx = int(indices_H[i][0])
315 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width,
316 | :, 0].transpose(0, 1).mv(weights_H[i])
317 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width,
318 | :, 1].transpose(0, 1).mv(weights_H[i])
319 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width,
320 | :, 2].transpose(0, 1).mv(weights_H[i])
321 |
322 | # process W dimension
323 | # symmetric copying
324 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
325 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
326 |
327 | sym_patch = out_1[:, :sym_len_Ws, :]
328 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
329 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
330 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
331 |
332 | sym_patch = out_1[:, -sym_len_We:, :]
333 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
334 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
335 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
336 |
337 | out_2 = torch.FloatTensor(out_H, out_W, in_C)
338 | kernel_width = weights_W.size(1)
339 | for i in range(out_W):
340 | idx = int(indices_W[i][0])
341 | out_2[:, i, 0] = out_1_aug[:, idx:idx +
342 | kernel_width, 0].mv(weights_W[i])
343 | out_2[:, i, 1] = out_1_aug[:, idx:idx +
344 | kernel_width, 1].mv(weights_W[i])
345 | out_2[:, i, 2] = out_1_aug[:, idx:idx +
346 | kernel_width, 2].mv(weights_W[i])
347 |
348 | return out_2.numpy()
349 |
--------------------------------------------------------------------------------
/codes/data_scripts/create_lmdb_mp.py:
--------------------------------------------------------------------------------
1 | '''create lmdb files for Vimeo90K-7 frames training dataset (multiprocessing)
2 | Will read all the images to the memory
3 | '''
4 |
5 | import os
6 | import sys
7 | import os.path as osp
8 | import glob
9 | import pickle
10 | from multiprocessing import Pool
11 | import numpy as np
12 | import lmdb
13 | import cv2
14 | try:
15 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
16 | import data.util as data_util
17 | import utils.util as util
18 | except ImportError:
19 | pass
20 |
21 |
22 | def reading_image_worker(path, key):
23 | '''worker for reading images'''
24 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
25 | return (key, img)
26 |
27 |
28 | def vimeo7():
29 | '''create lmdb for the Vimeo90K-7 frames dataset, each image with fixed size
30 | GT: [3, 256, 448]
31 | Only need the 4th frame currently, e.g., 00001_0001_4
32 | LR: [3, 64, 112]
33 | With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
34 | key:
35 | Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
36 | '''
37 | # configurations
38 | mode = 'GT' # GT | LR
39 | batch = 3000 # TODO: depending on your mem size
40 | if mode == 'GT':
41 | img_folder = '/data/datasets/SR/vimeo_septuplet/sequences/train'
42 | lmdb_save_path = '/data/datasets/SR/vimeo_septuplet/vimeo7_train_GT.lmdb'
43 | txt_file = '/data/datasets/SR/vimeo_septuplet/sep_trainlist.txt'
44 | H_dst, W_dst = 256, 448
45 | elif mode == 'LR':
46 | img_folder = '/data/datasets/SR/vimeo_septuplet/sequences_LR/LR/x4/train'
47 | lmdb_save_path = '/data/datasets/SR/vimeo_septuplet/vimeo7_train_LR7.lmdb'
48 | txt_file = '/data/datasets/SR/vimeo_septuplet/sep_trainlist.txt'
49 | H_dst, W_dst = 64, 112
50 | n_thread = 40
51 | ########################################################
52 | if not lmdb_save_path.endswith('.lmdb'):
53 | raise ValueError("lmdb_save_path must end with \'lmdb\'.")
54 | # whether the lmdb file exist
55 | if osp.exists(lmdb_save_path):
56 | print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
57 | sys.exit(1)
58 |
59 | # read all the image paths to a list
60 | print('Reading image path list ...')
61 | with open(txt_file) as f:
62 | train_l = f.readlines()
63 | train_l = [v.strip() for v in train_l]
64 | all_img_list = []
65 | keys = []
66 | for line in train_l:
67 | folder = line.split('/')[0]
68 | sub_folder = line.split('/')[1]
69 | file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + '/*')
70 | all_img_list.extend(file_l)
71 | for j in range(7):
72 | keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
73 | all_img_list = sorted(all_img_list)
74 | keys = sorted(keys)
75 | if mode == 'GT':
76 | all_img_list = [v for v in all_img_list if v.endswith('.png')]
77 | keys = [v for v in keys]
78 | print('Calculating the total size of images...')
79 | data_size = sum(os.stat(v).st_size for v in all_img_list)
80 |
81 | # read all images to memory (multiprocessing)
82 | print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
83 |
84 | # create lmdb environment
85 | env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
86 | txn = env.begin(write=True) # txn is a Transaction object
87 |
88 | # write data to lmdb
89 | pbar = util.ProgressBar(len(all_img_list))
90 |
91 | i = 0
92 | for path, key in zip(all_img_list, keys):
93 | pbar.update('Write {}'.format(key))
94 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
95 | key_byte = key.encode('ascii')
96 | H, W, C = img.shape # fixed shape
97 | assert H == H_dst and W == W_dst and C == 3, 'different shape.'
98 | txn.put(key_byte, img)
99 | i += 1
100 | if i % batch == 1:
101 | txn.commit()
102 | txn = env.begin(write=True)
103 |
104 | txn.commit()
105 | env.close()
106 | print('Finish reading and writing {} images.'.format(len(all_img_list)))
107 |
108 | print('Finish writing lmdb.')
109 |
110 | # create meta information
111 | meta_info = {}
112 | if mode == 'GT':
113 | meta_info['name'] = 'Vimeo7_train_GT'
114 | elif mode == 'LR':
115 | meta_info['name'] = 'Vimeo7_train_LR7'
116 | meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
117 | key_set = set()  # use a set so each clip key (folder_subfolder) appears only once
118 | for key in keys:
119 | a, b, _ = key.split('_')
120 | key_set.add('{}_{}'.format(a, b))
121 | meta_info['keys'] = sorted(key_set)  # store as a sorted list so it can be indexed downstream
122 | pickle.dump(meta_info, open(
123 | osp.join(lmdb_save_path, 'Vimeo7_train_keys.pkl'), "wb"))
124 | print('Finish creating lmdb meta info.')
125 |
126 |
127 | def test_lmdb(dataroot, dataset='vimeo7'):
128 | env = lmdb.open(dataroot, readonly=True, lock=False,
129 | readahead=False, meminit=False)
130 | meta_info = pickle.load(open(osp.join(dataroot, 'meta_info.pkl'), "rb"))
131 | print('Name: ', meta_info['name'])
132 | print('Resolution: ', meta_info['resolution'])
133 | print('# keys: ', len(meta_info['keys']))
134 | # read one image
135 | if dataset == 'vimeo7':
136 | key = '00001_0001_4'
137 | else:
138 | raise NameError('Please check the filename format.')
139 | print('Reading {} for test.'.format(key))
140 | with env.begin(write=False) as txn:
141 | buf = txn.get(key.encode('ascii'))
142 | img_flat = np.frombuffer(buf, dtype=np.uint8)
143 | C, H, W = [int(s) for s in meta_info['resolution'].split('_')]
144 | img = img_flat.reshape(H, W, C)
145 | cv2.imwrite('test.png', img)
146 |
147 |
148 | if __name__ == "__main__":
149 | vimeo7()
150 | test_lmdb('/data/datasets/SR/vimeo_septuplet/vimeo7_train_GT.lmdb', 'vimeo7')
151 |
--------------------------------------------------------------------------------
/codes/data_scripts/generate_LR_Vimeo90K.m:
--------------------------------------------------------------------------------
1 | function generate_LR_Vimeo90K()
2 | %% MATLAB code to generate bicubic-downsampled LR frames for the Vimeo90K dataset
3 |
4 | up_scale = 4;
5 | mod_scale = 4;
6 | idx = 0;
7 | filepaths = dir('/data/datasets/SR/vimeo_septuplet/sequences/train/*/*/*.png');
8 | for i = 1 : length(filepaths)
9 | [~,imname,ext] = fileparts(filepaths(i).name);
10 | folder_path = filepaths(i).folder;
11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
12 | if ~exist(save_LR_folder, 'dir')
13 | mkdir(save_LR_folder);
14 | end
15 | if isempty(imname)
16 | disp('Ignore . folder.');
17 | elseif strcmp(imname, '.')
18 | disp('Ignore .. folder.');
19 | else
20 | idx = idx + 1;
21 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
22 | fprintf(str_rlt);
23 | % read image
24 | img = imread(fullfile(folder_path, [imname, ext]));
25 | img = im2double(img);
26 | % modcrop
27 | img = modcrop(img, mod_scale);
28 | % LR
29 | im_LR = imresize(img, 1/up_scale, 'bicubic');
30 | if exist('save_LR_folder', 'var')
31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
32 | end
33 | end
34 | end
35 | end
36 |
37 | %% modcrop
38 | function img = modcrop(img, modulo)
39 | if size(img,3) == 1
40 | sz = size(img);
41 | sz = sz - mod(sz, modulo);
42 | img = img(1:sz(1), 1:sz(2));
43 | else
44 | tmpsz = size(img);
45 | sz = tmpsz(1:2);
46 | sz = sz - mod(sz, modulo);
47 | img = img(1:sz(1), 1:sz(2),:);
48 | end
49 | end
50 |
--------------------------------------------------------------------------------
/codes/data_scripts/generate_mod_LR_bic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import numpy as np
5 |
6 | try:
7 | sys.path.append(os.path.dirname(
8 | os.path.dirname(os.path.abspath(__file__))))
9 | from data.util import imresize_np
10 | except ImportError:
11 | pass
12 |
13 |
14 | def generate_mod_LR_bic(up_scale, sourcedir, savedir):
15 | # params: upscale factor, input directory, output directory
16 | saveHRpath = os.path.join(savedir, 'HR', 'x' + str(up_scale))
17 | saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
18 | saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
19 |
20 | if not os.path.isdir(sourcedir):
21 | print('Error: No source data found')
22 | sys.exit(1)  # exit with a non-zero status on error
23 | if not os.path.isdir(savedir):
24 | os.mkdir(savedir)
25 |
26 | if not os.path.isdir(os.path.join(savedir, 'HR')):
27 | os.mkdir(os.path.join(savedir, 'HR'))
28 | if not os.path.isdir(os.path.join(savedir, 'LR')):
29 | os.mkdir(os.path.join(savedir, 'LR'))
30 | if not os.path.isdir(os.path.join(savedir, 'Bic')):
31 | os.mkdir(os.path.join(savedir, 'Bic'))
32 |
33 | if not os.path.isdir(saveHRpath):
34 | os.mkdir(saveHRpath)
35 | else:
36 | print('It will overwrite files in ' + str(saveHRpath))
37 |
38 | if not os.path.isdir(saveLRpath):
39 | os.mkdir(saveLRpath)
40 | else:
41 | print('It will overwrite files in ' + str(saveLRpath))
42 |
43 | if not os.path.isdir(saveBicpath):
44 | os.mkdir(saveBicpath)
45 | else:
46 | print('It will overwrite files in ' + str(saveBicpath))
47 |
48 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
49 | num_files = len(filepaths)
50 |
51 | # generate HR (mod-cropped), LR (bicubic-downsampled) and Bic (bicubic-upsampled) images
52 | for i in range(num_files):
53 | filename = filepaths[i]
54 | print('No.{} -- Processing {}'.format(i, filename))
55 | # read image
56 | image = cv2.imread(os.path.join(sourcedir, filename))
57 |
58 | width = int(np.floor(image.shape[1] / up_scale))
59 | height = int(np.floor(image.shape[0] / up_scale))
60 | # modcrop
61 | if len(image.shape) == 3:
62 | image_HR = image[0:up_scale * height, 0:up_scale * width, :]
63 | else:
64 | image_HR = image[0:up_scale * height, 0:up_scale * width]
65 | # LR
66 | image_LR = imresize_np(image_HR, 1 / up_scale, True)
67 | # bic
68 | image_Bic = imresize_np(image_LR, up_scale, True)
69 |
70 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
71 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
72 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
73 |
74 |
75 | if __name__ == "__main__":
76 | generate_mod_LR_bic(4, 'inPath', 'outPath')
77 |
--------------------------------------------------------------------------------
/codes/data_scripts/sep_vimeo_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | if __name__ == "__main__":
5 | inPath = '/data/datasets/SR/vimeo_septuplet/sequences/'
6 | outPath = '/data/datasets/SR/vimeo_septuplet/sequences/test/'
7 | guide = '/data/datasets/SR/vimeo_septuplet/sep_testlist.txt'
8 |
9 | f = open(guide, "r")
10 | lines = f.readlines()
11 |
12 | if not os.path.isdir(outPath):
13 | os.mkdir(outPath)
14 |
15 | for l in lines:
16 | line = l.replace('\n', '')
17 | this_folder = os.path.join(inPath, line)
18 | dest_folder = os.path.join(outPath, line)
19 | print(this_folder)
20 | shutil.copytree(this_folder, dest_folder)
21 | print('Done')
22 |
--------------------------------------------------------------------------------
/codes/models/VideoSR_base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 | import models.networks as networks
8 | import models.lr_scheduler as lr_scheduler
9 | from .base_model import BaseModel
10 | from models.modules.loss import CharbonnierLoss, LapLoss
11 |
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class VideoSRBaseModel(BaseModel):
16 | def __init__(self, opt):
17 | super(VideoSRBaseModel, self).__init__(opt)
18 |
19 | if opt['dist']:
20 | self.rank = torch.distributed.get_rank()
21 | else:
22 | self.rank = -1 # non dist training
23 | train_opt = opt['train']
24 |
25 | # define network and load pretrained models
26 | self.netG = networks.define_G(opt).to(self.device)
27 |
28 | if opt['dist']:
29 | self.netG = DistributedDataParallel(
30 | self.netG, device_ids=[torch.cuda.current_device()])
31 | else:
32 | self.netG = DataParallel(self.netG)
33 | # print network
34 | self.print_network()
35 | self.load()
36 |
37 | if self.is_train:
38 | self.netG.train()
39 |
40 | # loss
41 | loss_type = train_opt['pixel_criterion']
42 | if loss_type == 'l1':
43 | self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
44 | elif loss_type == 'l2':
45 | self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
46 | elif loss_type == 'cb':
47 | self.cri_pix = CharbonnierLoss().to(self.device)
48 | elif loss_type == 'lp':
49 | self.cri_pix = LapLoss(max_levels=5).to(self.device)
50 | else:
51 | raise NotImplementedError(
52 | 'Loss type [{:s}] is not recognized.'.format(loss_type))
53 | self.l_pix_w = train_opt['pixel_weight']
54 |
55 | # optimizers
56 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
57 | optim_params = []
58 | for k, v in self.netG.named_parameters():
59 | if v.requires_grad:
60 | optim_params.append(v)
61 | else:
62 | if self.rank <= 0:
63 | logger.warning(
64 | 'Params [{:s}] will not optimize.'.format(k))
65 |
66 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
67 | weight_decay=wd_G,
68 | betas=(train_opt['beta1'], train_opt['beta2']))
69 | self.optimizers.append(self.optimizer_G)
70 | # schedulers
71 | if train_opt['lr_scheme'] == 'MultiStepLR':
72 | for optimizer in self.optimizers:
73 | self.schedulers.append(
74 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
75 | restarts=train_opt['restarts'],
76 | weights=train_opt['restart_weights'],
77 | gamma=train_opt['lr_gamma'],
78 | clear_state=train_opt['clear_state']))
79 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
80 | for optimizer in self.optimizers:
81 | self.schedulers.append(
82 | lr_scheduler.CosineAnnealingLR_Restart(
83 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
84 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
85 | else:
86 | raise NotImplementedError()
87 |
88 | self.log_dict = OrderedDict()
89 |
90 | def feed_data(self, data, need_GT=True):
91 | self.var_L = data['LQs'].to(self.device)
92 | if need_GT:
93 | self.real_H = data['GT'].to(self.device)
94 |
95 | def set_params_lr_zero(self):
96 | # fix normal module
97 | self.optimizers[0].param_groups[0]['lr'] = 0
98 |
99 | def optimize_parameters(self, step):
100 | self.optimizer_G.zero_grad()
101 | self.fake_H = self.netG(self.var_L)
102 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
103 | l_pix.backward()
104 | self.optimizer_G.step()
105 |
106 | # set log
107 | self.log_dict['l_pix'] = l_pix.item()
108 |
109 | def test(self):
110 | self.netG.eval()
111 | with torch.no_grad():
112 | self.fake_H = self.netG(self.var_L)
113 | self.netG.train()
114 |
115 | def get_current_log(self):
116 | return self.log_dict
117 |
118 | def get_current_visuals(self, need_GT=True):
119 | out_dict = OrderedDict()
120 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
121 | out_dict['restore'] = self.fake_H.detach()[0].float().cpu()
122 | if need_GT:
123 | out_dict['GT'] = self.real_H.detach()[0].float().cpu()
124 | return out_dict
125 |
126 | def print_network(self):
127 | s, n = self.get_network_description(self.netG)
128 | if isinstance(self.netG, nn.DataParallel):
129 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
130 | self.netG.module.__class__.__name__)
131 | else:
132 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
133 | if self.rank <= 0:
134 | logger.info(
135 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
136 | logger.info(s)
137 |
138 | def load(self):
139 | load_path_G = self.opt['path']['pretrain_model_G']
140 | if load_path_G is not None:
141 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
142 | self.load_network(load_path_G, self.netG,
143 | self.opt['path']['strict_load'])
144 |
145 | def save(self, iter_label):
146 | self.save_network(self.netG, 'G', iter_label)
147 |
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | model = opt['model']
7 | if model == 'VideoSR_base':
8 | from .VideoSR_base_model import VideoSRBaseModel as M
9 | else:
10 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
11 | m = M(opt)
12 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
13 | return m
14 |
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device(
12 | 'cuda' if opt['gpu_ids'] is not None else 'cpu')
13 | self.is_train = opt['is_train']
14 | self.schedulers = []
15 | self.optimizers = []
16 |
17 | def feed_data(self, data):
18 | pass
19 |
20 | def optimize_parameters(self):
21 | pass
22 |
23 | def get_current_visuals(self):
24 | pass
25 |
26 | def get_current_losses(self):
27 | pass
28 |
29 | def print_network(self):
30 | pass
31 |
32 | def save(self, label):
33 | pass
34 |
35 | def load(self):
36 | pass
37 |
38 | def _set_lr(self, lr_groups_l):
39 | '''Set learning rates for warm-up.
40 | lr_groups_l: list of lr groups, one for each optimizer'''
41 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
42 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
43 | param_group['lr'] = lr
44 |
45 | def _get_init_lr(self):
46 | # get the initial lr, which is set by the scheduler
47 | init_lr_groups_l = []
48 | for optimizer in self.optimizers:
49 | init_lr_groups_l.append([v['initial_lr']
50 | for v in optimizer.param_groups])
51 | return init_lr_groups_l
52 |
53 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
54 | for scheduler in self.schedulers:
55 | scheduler.step()
56 | # set up warm up learning rate
57 | if cur_iter < warmup_iter:
58 | # get initial lr for each group
59 | init_lr_g_l = self._get_init_lr()
60 | # modify warming-up learning rates
61 | warm_up_lr_l = []
62 | for init_lr_g in init_lr_g_l:
63 | warm_up_lr_l.append(
64 | [v / warmup_iter * cur_iter for v in init_lr_g])
65 | # set learning rate
66 | self._set_lr(warm_up_lr_l)
67 |
68 | def get_current_learning_rate(self):
69 | lr_l = []
70 | for param_group in self.optimizers[0].param_groups:
71 | lr_l.append(param_group['lr'])
72 | return lr_l
73 |
74 | def get_network_description(self, network):
75 | '''Get the string and total parameters of the network'''
76 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
77 | network = network.module
78 | s = str(network)
79 | n = sum(map(lambda x: x.numel(), network.parameters()))
80 | return s, n
81 |
82 | def save_network(self, network, network_label, iter_label):
83 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
84 | save_path = os.path.join(self.opt['path']['models'], save_filename)
85 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
86 | network = network.module
87 | state_dict = network.state_dict()
88 | for key, param in state_dict.items():
89 | state_dict[key] = param.cpu()
90 | torch.save(state_dict, save_path)
91 |
92 | def load_network(self, load_path, network, strict=True):
93 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
94 | network = network.module
95 | load_net = torch.load(load_path)
96 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
97 | for k, v in load_net.items():
98 | if k.startswith('module.'):
99 | load_net_clean[k[7:]] = v
100 | else:
101 | load_net_clean[k] = v
102 | network.load_state_dict(load_net_clean, strict=strict)
103 |
104 | def save_training_state(self, epoch, iter_step):
105 | '''Saves training state during training, which will be used for resuming'''
106 | state = {'epoch': epoch, 'iter': iter_step,
107 | 'schedulers': [], 'optimizers': []}
108 | for s in self.schedulers:
109 | state['schedulers'].append(s.state_dict())
110 | for o in self.optimizers:
111 | state['optimizers'].append(o.state_dict())
112 | save_filename = '{}.state'.format(iter_step)
113 | save_path = os.path.join(
114 | self.opt['path']['training_state'], save_filename)
115 | torch.save(state, save_path)
116 |
117 | def resume_training(self, resume_state):
118 | '''Resume the optimizers and schedulers for training'''
119 | resume_optimizers = resume_state['optimizers']
120 | resume_schedulers = resume_state['schedulers']
121 | assert len(resume_optimizers) == len(
122 | self.optimizers), 'Wrong lengths of optimizers'
123 | assert len(resume_schedulers) == len(
124 | self.schedulers), 'Wrong lengths of schedulers'
125 | for i, o in enumerate(resume_optimizers):
126 | self.optimizers[i].load_state_dict(o)
127 | for i, s in enumerate(resume_schedulers):
128 | self.schedulers[i].load_state_dict(s)
129 |
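Note: a small sketch (illustration only) of the linear warm-up that `update_learning_rate` applies: while `cur_iter < warmup_iter`, each param group runs at `initial_lr * cur_iter / warmup_iter` before the scheduler's own value takes over:

```python
import torch

optimizer = torch.optim.Adam([torch.zeros(3, requires_grad=True)], lr=4e-4)
optimizer.param_groups[0]['initial_lr'] = 4e-4   # normally stored by the scheduler

warmup_iter = 4000
for cur_iter in (100, 1000, 2000):
    # same ramp as the warm-up branch in BaseModel.update_learning_rate()
    warm_lr = optimizer.param_groups[0]['initial_lr'] / warmup_iter * cur_iter
    print(cur_iter, warm_lr)   # 1e-05, 0.0001, 0.0002
```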
--------------------------------------------------------------------------------
/codes/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restart_weights = weights if weights else [1]
16 | assert len(self.restarts) == len(
17 | self.restart_weights), 'restarts and their weights do not match.'
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
19 |
20 | def get_lr(self):
21 | if self.last_epoch in self.restarts:
22 | if self.clear_state:
23 | self.optimizer.state = defaultdict(dict)
24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
26 | if self.last_epoch not in self.milestones:
27 | return [group['lr'] for group in self.optimizer.param_groups]
28 | return [
29 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
30 | for group in self.optimizer.param_groups
31 | ]
32 |
33 |
34 | class CosineAnnealingLR_Restart(_LRScheduler):
35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
36 | self.T_period = T_period
37 | self.T_max = self.T_period[0] # current T period
38 | self.eta_min = eta_min
39 | self.restarts = restarts if restarts else [0]
40 | self.restart_weights = weights if weights else [1]
41 | self.last_restart = 0
42 | assert len(self.restarts) == len(
43 | self.restart_weights), 'restarts and their weights do not match.'
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
45 |
46 | def get_lr(self):
47 | if self.last_epoch == 0:
48 | return self.base_lrs
49 | elif self.last_epoch in self.restarts:
50 | self.last_restart = self.last_epoch
51 | self.T_max = self.T_period[self.restarts.index(
52 | self.last_epoch) + 1]
53 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
54 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
55 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
56 | return [
57 | group['lr'] + (base_lr - self.eta_min) *
58 | (1 - math.cos(math.pi / self.T_max)) / 2
59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
60 | ]
61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
63 | (group['lr'] - self.eta_min) + self.eta_min
64 | for group in self.optimizer.param_groups]
65 |
66 |
67 | if __name__ == "__main__":
68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
69 | betas=(0.9, 0.99))
70 | ##############################
71 | # MultiStepLR_Restart
72 | ##############################
73 | # Original
74 | lr_steps = [200000, 400000, 600000, 800000]
75 | restarts = None
76 | restart_weights = None
77 |
78 | # two
79 | lr_steps = [100000, 200000, 300000, 400000,
80 | 490000, 600000, 700000, 800000, 900000, 990000]
81 | restarts = [500000]
82 | restart_weights = [1]
83 |
84 | # four
85 | lr_steps = [
86 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
87 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
88 | ]
89 | restarts = [250000, 500000, 750000]
90 | restart_weights = [1, 1, 1]
91 |
92 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
93 | clear_state=False)
94 |
95 | ##############################
96 | # Cosine Annealing Restart
97 | ##############################
98 | # two
99 | T_period = [500000, 500000]
100 | restarts = [500000]
101 | restart_weights = [1]
102 |
103 | # four
104 | T_period = [250000, 250000, 250000, 250000]
105 | restarts = [250000, 500000, 750000]
106 | restart_weights = [1, 1, 1]
107 |
108 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
109 | weights=restart_weights)
110 |
111 | ##############################
112 | # Draw figure
113 | ##############################
114 | N_iter = 1000000
115 | lr_l = list(range(N_iter))
116 | for i in range(N_iter):
117 | scheduler.step()
118 | current_lr = optimizer.param_groups[0]['lr']
119 | lr_l[i] = current_lr
120 |
121 | import matplotlib as mpl
122 | from matplotlib import pyplot as plt
123 | import matplotlib.ticker as mtick
124 | mpl.style.use('default')
125 | import seaborn
126 | seaborn.set(style='whitegrid')
127 | seaborn.set_context('paper')
128 |
129 | plt.figure(1)
130 | plt.subplot(111)
131 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
132 | plt.title('Title', fontsize=16, color='k')
133 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5,
134 | label='learning rate scheme')
135 | legend = plt.legend(loc='upper right', shadow=False)
136 | ax = plt.gca()
137 | labels = ax.get_xticks().tolist()
138 | for k, v in enumerate(labels):
139 | labels[k] = str(int(v / 1000)) + 'K'
140 | ax.set_xticklabels(labels)
141 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
142 |
143 | ax.set_ylabel('Learning rate')
144 | ax.set_xlabel('Iteration')
145 | fig = plt.gcf()
146 | plt.show()
147 |
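Note: a quick sketch (illustration only, assuming the `codes` directory is on `PYTHONPATH`) of how `MultiStepLR_Restart` decays and then restarts the learning rate; the `__main__` block above plots the full-length schedules used for training:

```python
import torch
from models.lr_scheduler import MultiStepLR_Restart

optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
# halve the lr at steps 5 and 8, restart to the initial lr (weight 1) at step 10
scheduler = MultiStepLR_Restart(optimizer, [5, 8], restarts=[10], weights=[1], gamma=0.5)
for it in range(1, 13):
    scheduler.step()
    print(it, optimizer.param_groups[0]['lr'])
# 1e-3 up to step 4, 5e-4 from step 5, 2.5e-4 from step 8, back to 1e-3 at step 10
```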
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .idea
3 | *.so
4 | *.o
5 | *pyc
6 | _ext
7 | build
8 | DCNv2.egg-info
9 | dist
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Charles Shang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/README.md:
--------------------------------------------------------------------------------
1 | ## Deformable Convolutional Networks V2 with Pytorch 1.0
2 |
3 | ### Build
4 | ```bash
5 | ./make.sh # build
6 | python test.py # run examples and gradient check
7 | ```
8 |
9 | ### An Example
10 | - deformable conv
11 | ```python
12 | from dcn_v2 import DCN
13 | input = torch.randn(2, 64, 128, 128).cuda()
14 | # wrap all things (offset and mask) in DCN
15 | dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
16 | output = dcn(input)
17 | print(output.shape)
18 | ```
19 | - deformable roi pooling
20 | ```python
21 | from dcn_v2 import DCNPooling
22 | input = torch.randn(2, 32, 64, 64).cuda()
23 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
24 | x = torch.randint(256, (20, 1)).cuda().float()
25 | y = torch.randint(256, (20, 1)).cuda().float()
26 | w = torch.randint(64, (20, 1)).cuda().float()
27 | h = torch.randint(64, (20, 1)).cuda().float()
28 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
29 |
30 | # modulated deformable pooling (V2)
31 | # wrap all things (offset and mask) in DCNPooling
32 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
33 | pooled_size=7,
34 | output_dim=32,
35 | no_trans=False,
36 | group_size=1,
37 | trans_std=0.1).cuda()
38 |
39 | dout = dpooling(input, rois)
40 | ```
41 | ### Note
42 | The master branch now targets PyTorch 1.0 (the new ATen API); you can switch back to PyTorch 0.4 with:
43 | ```bash
44 | git checkout pytorch_0.4
45 | ```
46 |
47 | ### Known Issues:
48 |
49 | - [x] Gradient check w.r.t offset (solved)
50 | - [ ] Backward is not reentrant (minor)
51 |
52 | This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
53 |
54 | I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
55 | However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
56 | non-differentiable points?
57 |
58 | Update: all gradient checks pass with double precision.
59 |
60 | Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the numerical error is very small (`<1e-7` for
61 | float, `<1e-15` for double),
62 | so it may not be a serious problem.
63 |
64 | Please post an issue or PR if you have any comments.
65 |
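A rough sketch of the kind of gradient check discussed above (illustration only; it assumes the extension has been built with `./make.sh` and a CUDA device is present — `test.py` contains the full checks):

```python
import torch
from torch.autograd import gradcheck
from dcn_v2 import dcn_v2_conv

deformable_groups = 1
N, C_in, C_out, H, W, k = 1, 4, 4, 4, 4, 3

x = torch.rand(N, C_in, H, W, device='cuda', requires_grad=True) * 0.01
offset = torch.randn(N, deformable_groups * 2 * k * k, H, W, device='cuda', requires_grad=True) * 2
mask = torch.rand(N, deformable_groups * k * k, H, W, device='cuda', requires_grad=True)
weight = torch.randn(C_out, C_in, k, k, device='cuda', requires_grad=True)
bias = torch.rand(C_out, device='cuda', requires_grad=True)

# signature: (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups)
print(gradcheck(dcn_v2_conv, (x, offset, mask, weight, bias, 1, 1, 1, deformable_groups),
                eps=1e-3, atol=1e-4, rtol=1e-2))
```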
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/codes/models/modules/DCNv2/__init__.py
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/dcn_v2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import math
4 | import logging
5 | import torch
6 | from torch import nn
7 | from torch.autograd import Function
8 | from torch.nn.modules.utils import _pair
9 | from torch.autograd.function import once_differentiable
10 |
11 | import _ext as _backend
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class _DCNv2(Function):
16 | @staticmethod
17 | def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation,
18 | deformable_groups):
19 | ctx.stride = _pair(stride)
20 | ctx.padding = _pair(padding)
21 | ctx.dilation = _pair(dilation)
22 | ctx.kernel_size = _pair(weight.shape[2:4])
23 | ctx.deformable_groups = deformable_groups
24 | output = _backend.dcn_v2_forward(input, weight, bias, offset, mask, ctx.kernel_size[0],
25 | ctx.kernel_size[1], ctx.stride[0], ctx.stride[1],
26 | ctx.padding[0], ctx.padding[1], ctx.dilation[0],
27 | ctx.dilation[1], ctx.deformable_groups)
28 | ctx.save_for_backward(input, offset, mask, weight, bias)
29 | return output
30 |
31 | @staticmethod
32 | @once_differentiable
33 | def backward(ctx, grad_output):
34 | input, offset, mask, weight, bias = ctx.saved_tensors
35 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \
36 | _backend.dcn_v2_backward(input, weight,
37 | bias,
38 | offset, mask,
39 | grad_output,
40 | ctx.kernel_size[0], ctx.kernel_size[1],
41 | ctx.stride[0], ctx.stride[1],
42 | ctx.padding[0], ctx.padding[1],
43 | ctx.dilation[0], ctx.dilation[1],
44 | ctx.deformable_groups)
45 |
46 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\
47 | None, None, None, None,
48 |
49 |
50 | dcn_v2_conv = _DCNv2.apply
51 |
52 |
53 | class DCNv2(nn.Module):
54 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
55 | deformable_groups=1):
56 | super(DCNv2, self).__init__()
57 | self.in_channels = in_channels
58 | self.out_channels = out_channels
59 | self.kernel_size = _pair(kernel_size)
60 | self.stride = _pair(stride)
61 | self.padding = _pair(padding)
62 | self.dilation = _pair(dilation)
63 | self.deformable_groups = deformable_groups
64 |
65 | self.weight = nn.Parameter(torch.Tensor(
66 | out_channels, in_channels, *self.kernel_size))
67 | self.bias = nn.Parameter(torch.Tensor(out_channels))
68 | self.reset_parameters()
69 |
70 | def reset_parameters(self):
71 | n = self.in_channels
72 | for k in self.kernel_size:
73 | n *= k
74 | stdv = 1. / math.sqrt(n)
75 | self.weight.data.uniform_(-stdv, stdv)
76 | self.bias.data.zero_()
77 |
78 | def forward(self, input, offset, mask):
79 | assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
80 | offset.shape[1]
81 | assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
82 | mask.shape[1]
83 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
84 | self.dilation, self.deformable_groups)
85 |
86 |
87 | class DCN(DCNv2):
88 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
89 | deformable_groups=1):
90 | super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation,
91 | deformable_groups)
92 |
93 | channels_ = self.deformable_groups * 3 * \
94 | self.kernel_size[0] * self.kernel_size[1]
95 | self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size,
96 | stride=self.stride, padding=self.padding, bias=True)
97 | self.init_offset()
98 |
99 | def init_offset(self):
100 | self.conv_offset_mask.weight.data.zero_()
101 | self.conv_offset_mask.bias.data.zero_()
102 |
103 | def forward(self, input):
104 | out = self.conv_offset_mask(input)
105 | o1, o2, mask = torch.chunk(out, 3, dim=1)
106 | offset = torch.cat((o1, o2), dim=1)
107 | mask = torch.sigmoid(mask)
108 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
109 | self.dilation, self.deformable_groups)
110 |
111 |
112 | class DCN_sep(DCNv2):
113 | '''Use other features to generate offsets and masks'''
114 |
115 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
116 | deformable_groups=1):
117 | super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
118 | dilation, deformable_groups)
119 |
120 | channels_ = self.deformable_groups * 3 * \
121 | self.kernel_size[0] * self.kernel_size[1]
122 | self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size,
123 | stride=self.stride, padding=self.padding, bias=True)
124 | self.init_offset()
125 |
126 | def init_offset(self):
127 | self.conv_offset_mask.weight.data.zero_()
128 | self.conv_offset_mask.bias.data.zero_()
129 |
130 | def forward(self, input, fea):
131 | '''input: input features for deformable conv
132 | fea: other features used for generating offsets and mask'''
133 | out = self.conv_offset_mask(fea)
134 | o1, o2, mask = torch.chunk(out, 3, dim=1)
135 | offset = torch.cat((o1, o2), dim=1)
136 |
137 | offset_mean = torch.mean(torch.abs(offset))
138 | if offset_mean > 100:
139 | logger.warning(
140 | 'Offset mean is {}, larger than 100.'.format(offset_mean))
141 |
142 | mask = torch.sigmoid(mask)
143 | return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
144 | self.dilation, self.deformable_groups)
145 |
146 |
147 | class _DCNv2Pooling(Function):
148 | @staticmethod
149 | def forward(ctx, input, rois, offset, spatial_scale, pooled_size, output_dim, no_trans,
150 | group_size=1, part_size=None, sample_per_part=4, trans_std=.0):
151 | ctx.spatial_scale = spatial_scale
152 | ctx.no_trans = int(no_trans)
153 | ctx.output_dim = output_dim
154 | ctx.group_size = group_size
155 | ctx.pooled_size = pooled_size
156 | ctx.part_size = pooled_size if part_size is None else part_size
157 | ctx.sample_per_part = sample_per_part
158 | ctx.trans_std = trans_std
159 |
160 | output, output_count = \
161 | _backend.dcn_v2_psroi_pooling_forward(input, rois, offset,
162 | ctx.no_trans, ctx.spatial_scale,
163 | ctx.output_dim, ctx.group_size,
164 | ctx.pooled_size, ctx.part_size,
165 | ctx.sample_per_part, ctx.trans_std)
166 | ctx.save_for_backward(input, rois, offset, output_count)
167 | return output
168 |
169 | @staticmethod
170 | @once_differentiable
171 | def backward(ctx, grad_output):
172 | input, rois, offset, output_count = ctx.saved_tensors
173 | grad_input, grad_offset = \
174 | _backend.dcn_v2_psroi_pooling_backward(grad_output,
175 | input,
176 | rois,
177 | offset,
178 | output_count,
179 | ctx.no_trans,
180 | ctx.spatial_scale,
181 | ctx.output_dim,
182 | ctx.group_size,
183 | ctx.pooled_size,
184 | ctx.part_size,
185 | ctx.sample_per_part,
186 | ctx.trans_std)
187 |
188 | return grad_input, None, grad_offset, \
189 | None, None, None, None, None, None, None, None
190 |
191 |
192 | dcn_v2_pooling = _DCNv2Pooling.apply
193 |
194 |
195 | class DCNv2Pooling(nn.Module):
196 | def __init__(self, spatial_scale, pooled_size, output_dim, no_trans, group_size=1,
197 | part_size=None, sample_per_part=4, trans_std=.0):
198 | super(DCNv2Pooling, self).__init__()
199 | self.spatial_scale = spatial_scale
200 | self.pooled_size = pooled_size
201 | self.output_dim = output_dim
202 | self.no_trans = no_trans
203 | self.group_size = group_size
204 | self.part_size = pooled_size if part_size is None else part_size
205 | self.sample_per_part = sample_per_part
206 | self.trans_std = trans_std
207 |
208 | def forward(self, input, rois, offset):
209 | assert input.shape[1] == self.output_dim
210 | if self.no_trans:
211 | offset = input.new()
212 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size,
213 | self.output_dim, self.no_trans, self.group_size, self.part_size,
214 | self.sample_per_part, self.trans_std)
215 |
216 |
217 | class DCNPooling(DCNv2Pooling):
218 | def __init__(self, spatial_scale, pooled_size, output_dim, no_trans, group_size=1,
219 | part_size=None, sample_per_part=4, trans_std=.0, deform_fc_dim=1024):
220 | super(DCNPooling, self).__init__(spatial_scale, pooled_size, output_dim, no_trans,
221 | group_size, part_size, sample_per_part, trans_std)
222 |
223 | self.deform_fc_dim = deform_fc_dim
224 |
225 | if not no_trans:
226 | self.offset_mask_fc = nn.Sequential(
227 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim,
228 | self.deform_fc_dim), nn.ReLU(inplace=True),
229 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(
230 | inplace=True),
231 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3))
232 | self.offset_mask_fc[4].weight.data.zero_()
233 | self.offset_mask_fc[4].bias.data.zero_()
234 |
235 | def forward(self, input, rois):
236 | offset = input.new()
237 |
238 | if not self.no_trans:
239 |
240 | # do roi_align first
241 | n = rois.shape[0]
242 | roi = dcn_v2_pooling(
243 | input,
244 | rois,
245 | offset,
246 | self.spatial_scale,
247 | self.pooled_size,
248 | self.output_dim,
249 | True, # no trans
250 | self.group_size,
251 | self.part_size,
252 | self.sample_per_part,
253 | self.trans_std)
254 |
255 | # build mask and offset
256 | offset_mask = self.offset_mask_fc(roi.view(n, -1))
257 | offset_mask = offset_mask.view(
258 | n, 3, self.pooled_size, self.pooled_size)
259 | o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
260 | offset = torch.cat((o1, o2), dim=1)
261 | mask = torch.sigmoid(mask)
262 |
263 | # do pooling with offset and mask
264 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size,
265 | self.output_dim, self.no_trans, self.group_size, self.part_size,
266 | self.sample_per_part, self.trans_std) * mask
267 | # only roi_align
268 | return dcn_v2_pooling(input, rois, offset, self.spatial_scale, self.pooled_size,
269 | self.output_dim, self.no_trans, self.group_size, self.part_size,
270 | self.sample_per_part, self.trans_std)
271 |
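Note: `DCN_sep` is the variant used by the alignment modules in this project (see Sakuya_arch.py) — the offsets and the modulation mask are predicted from a *different* feature map than the one being deformed. A minimal sketch (illustration only, sizes are arbitrary):

```python
import torch
from dcn_v2 import DCN_sep

nf = 64
align = DCN_sep(nf, nf, kernel_size=3, stride=1, padding=1, deformable_groups=8).cuda()

feat_to_warp = torch.randn(2, nf, 32, 32).cuda()   # features that get deformed
offset_feat = torch.randn(2, nf, 32, 32).cuda()    # features that predict offset + mask
aligned = align(feat_to_warp, offset_feat)
print(aligned.shape)   # torch.Size([2, 64, 32, 32])
```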
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # You may need to modify the following paths before compiling.
4 |
5 | # CUDA_HOME=/usr/local/cuda-10.0 \
6 | # CUDNN_INCLUDE_DIR=/usr/local/cuda-10.0/include \
7 | # CUDNN_LIB_DIR=/usr/local/cuda-10.0/lib64 \
8 | python setup.py build develop
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import glob
5 |
6 | import torch
7 |
8 | from torch.utils.cpp_extension import CUDA_HOME
9 | from torch.utils.cpp_extension import CppExtension
10 | from torch.utils.cpp_extension import CUDAExtension
11 |
12 | from setuptools import find_packages
13 | from setuptools import setup
14 |
15 | requirements = ["torch", "torchvision"]
16 |
17 |
18 | def get_extensions():
19 | this_dir = os.path.dirname(os.path.abspath(__file__))
20 | extensions_dir = os.path.join(this_dir, "src")
21 |
22 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
23 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
24 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
25 |
26 | sources = main_file + source_cpu
27 | extension = CppExtension
28 | extra_compile_args = {"cxx": []}
29 | define_macros = []
30 |
31 | if torch.cuda.is_available() and CUDA_HOME is not None:
32 | extension = CUDAExtension
33 | sources += source_cuda
34 | define_macros += [("WITH_CUDA", None)]
35 | extra_compile_args["nvcc"] = [
36 | "-DCUDA_HAS_FP16=1",
37 | "-D__CUDA_NO_HALF_OPERATORS__",
38 | "-D__CUDA_NO_HALF_CONVERSIONS__",
39 | "-D__CUDA_NO_HALF2_OPERATORS__",
40 | ]
41 | else:
42 | raise NotImplementedError('CUDA is not available')
43 |
44 | sources = [os.path.join(extensions_dir, s) for s in sources]
45 | include_dirs = [extensions_dir]
46 | ext_modules = [
47 | extension(
48 | "_ext",
49 | sources,
50 | include_dirs=include_dirs,
51 | define_macros=define_macros,
52 | extra_compile_args=extra_compile_args,
53 | )
54 | ]
55 | return ext_modules
56 |
57 |
58 | setup(
59 | name="DCNv2",
60 | version="0.1",
61 | author="charlesshang",
62 | url="https://github.com/charlesshang/DCNv2",
63 | description="deformable convolutional networks",
64 | packages=find_packages(exclude=(
65 | "configs",
66 | "tests",
67 | )),
68 | # install_requires=requirements,
69 | ext_modules=get_extensions(),
70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
72 |
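A quick post-build smoke test (illustration only; run it from inside `codes/models/modules/DCNv2` after `./make.sh`):

```python
import torch
import _ext  # the compiled extension; an ImportError here means the build failed

print(torch.cuda.is_available())   # the extension above is CUDA-only
```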
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/cpu/dcn_v2_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <vector>
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 |
6 |
7 | at::Tensor
8 | dcn_v2_cpu_forward(const at::Tensor &input,
9 | const at::Tensor &weight,
10 | const at::Tensor &bias,
11 | const at::Tensor &offset,
12 | const at::Tensor &mask,
13 | const int kernel_h,
14 | const int kernel_w,
15 | const int stride_h,
16 | const int stride_w,
17 | const int pad_h,
18 | const int pad_w,
19 | const int dilation_h,
20 | const int dilation_w,
21 | const int deformable_group)
22 | {
23 | AT_ERROR("Not implement on cpu");
24 | }
25 |
26 | std::vector<at::Tensor>
27 | dcn_v2_cpu_backward(const at::Tensor &input,
28 | const at::Tensor &weight,
29 | const at::Tensor &bias,
30 | const at::Tensor &offset,
31 | const at::Tensor &mask,
32 | const at::Tensor &grad_output,
33 | int kernel_h, int kernel_w,
34 | int stride_h, int stride_w,
35 | int pad_h, int pad_w,
36 | int dilation_h, int dilation_w,
37 | int deformable_group)
38 | {
39 | AT_ERROR("Not implement on cpu");
40 | }
41 |
42 | std::tuple<at::Tensor, at::Tensor>
43 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
44 | const at::Tensor &bbox,
45 | const at::Tensor &trans,
46 | const int no_trans,
47 | const float spatial_scale,
48 | const int output_dim,
49 | const int group_size,
50 | const int pooled_size,
51 | const int part_size,
52 | const int sample_per_part,
53 | const float trans_std)
54 | {
55 | AT_ERROR("Not implement on cpu");
56 | }
57 |
58 | std::tuple<at::Tensor, at::Tensor>
59 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
60 | const at::Tensor &input,
61 | const at::Tensor &bbox,
62 | const at::Tensor &trans,
63 | const at::Tensor &top_count,
64 | const int no_trans,
65 | const float spatial_scale,
66 | const int output_dim,
67 | const int group_size,
68 | const int pooled_size,
69 | const int part_size,
70 | const int sample_per_part,
71 | const float trans_std)
72 | {
73 | AT_ERROR("Not implement on cpu");
74 | }
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/cpu/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cpu_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cpu_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/cuda/dcn_v2_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <vector>
2 | #include "cuda/dcn_v2_im2col_cuda.h"
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 |
7 | #include <THC/THC.h>
8 | #include <THC/THCAtomics.cuh>
9 | #include <THC/THCDeviceUtils.cuh>
10 |
11 | extern THCState *state;
12 |
13 | // author: Charles Shang
14 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
15 |
16 | // [batch gemm]
17 | // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
18 |
19 | __global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
20 | float **columns_b, const float **ones_b,
21 | const float **weight_b, const float **bias_b,
22 | float *input, float *output,
23 | float *columns, float *ones,
24 | float *weight, float *bias,
25 | const int input_stride, const int output_stride,
26 | const int columns_stride, const int ones_stride,
27 | const int num_batches)
28 | {
29 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
30 | if (idx < num_batches)
31 | {
32 | input_b[idx] = input + idx * input_stride;
33 | output_b[idx] = output + idx * output_stride;
34 | columns_b[idx] = columns + idx * columns_stride;
35 | ones_b[idx] = ones + idx * ones_stride;
36 | // share weights and bias within a Mini-Batch
37 | weight_b[idx] = weight;
38 | bias_b[idx] = bias;
39 | }
40 | }
41 |
42 | at::Tensor
43 | dcn_v2_cuda_forward(const at::Tensor &input,
44 | const at::Tensor &weight,
45 | const at::Tensor &bias,
46 | const at::Tensor &offset,
47 | const at::Tensor &mask,
48 | const int kernel_h,
49 | const int kernel_w,
50 | const int stride_h,
51 | const int stride_w,
52 | const int pad_h,
53 | const int pad_w,
54 | const int dilation_h,
55 | const int dilation_w,
56 | const int deformable_group)
57 | {
58 | using scalar_t = float;
59 | // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
60 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
61 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
62 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
63 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
64 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
65 |
66 | const int batch = input.size(0);
67 | const int channels = input.size(1);
68 | const int height = input.size(2);
69 | const int width = input.size(3);
70 |
71 | const int channels_out = weight.size(0);
72 | const int channels_kernel = weight.size(1);
73 | const int kernel_h_ = weight.size(2);
74 | const int kernel_w_ = weight.size(3);
75 |
76 | // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
77 | // printf("Channels: %d %d\n", channels, channels_kernel);
78 | // printf("Channels: %d %d\n", channels_out, channels_kernel);
79 |
80 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
81 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
82 |
83 | AT_ASSERTM(channels == channels_kernel,
84 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
85 |
86 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
87 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
88 |
89 | auto ones = at::ones({batch, height_out, width_out}, input.options());
90 | auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
91 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
92 |
93 | // prepare for batch-wise computing, which is significantly faster than instance-wise computing
94 | // when batch size is large.
95 | // launch batch threads
96 | int matrices_size = batch * sizeof(float *);
97 | auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
98 | auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
99 | auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
100 | auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
101 | auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
102 | auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
103 |
104 | const int block = 128;
105 | const int grid = (batch + block - 1) / block;
106 |
107 | createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
108 | input_b, output_b,
109 | columns_b, ones_b,
110 | weight_b, bias_b,
111 | input.data<scalar_t>(),
112 | output.data<scalar_t>(),
113 | columns.data<scalar_t>(),
114 | ones.data<scalar_t>(),
115 | weight.data<scalar_t>(),
116 | bias.data<scalar_t>(),
117 | channels * width * height,
118 | channels_out * width_out * height_out,
119 | channels * kernel_h * kernel_w * height_out * width_out,
120 | height_out * width_out,
121 | batch);
122 |
123 | long m_ = channels_out;
124 | long n_ = height_out * width_out;
125 | long k_ = 1;
126 | THCudaBlas_SgemmBatched(state,
127 | 't',
128 | 'n',
129 | n_,
130 | m_,
131 | k_,
132 | 1.0f,
133 | ones_b, k_,
134 | bias_b, k_,
135 | 0.0f,
136 | output_b, n_,
137 | batch);
138 |
139 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
140 | input.data<scalar_t>(),
141 | offset.data<scalar_t>(),
142 | mask.data<scalar_t>(),
143 | batch, channels, height, width,
144 | height_out, width_out, kernel_h, kernel_w,
145 | pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
146 | deformable_group,
147 | columns.data<scalar_t>());
148 |
149 | long m = channels_out;
150 | long n = height_out * width_out;
151 | long k = channels * kernel_h * kernel_w;
152 | THCudaBlas_SgemmBatched(state,
153 | 'n',
154 | 'n',
155 | n,
156 | m,
157 | k,
158 | 1.0f,
159 | (const float **)columns_b, n,
160 | weight_b, k,
161 | 1.0f,
162 | output_b, n,
163 | batch);
164 |
165 | THCudaFree(state, input_b);
166 | THCudaFree(state, output_b);
167 | THCudaFree(state, columns_b);
168 | THCudaFree(state, ones_b);
169 | THCudaFree(state, weight_b);
170 | THCudaFree(state, bias_b);
171 | return output;
172 | }
173 |
174 | __global__ void createBatchGemmBufferBackward(
175 | float **grad_output_b,
176 | float **columns_b,
177 | float **ones_b,
178 | float **weight_b,
179 | float **grad_weight_b,
180 | float **grad_bias_b,
181 | float *grad_output,
182 | float *columns,
183 | float *ones,
184 | float *weight,
185 | float *grad_weight,
186 | float *grad_bias,
187 | const int grad_output_stride,
188 | const int columns_stride,
189 | const int ones_stride,
190 | const int num_batches)
191 | {
192 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
193 | if (idx < num_batches)
194 | {
195 | grad_output_b[idx] = grad_output + idx * grad_output_stride;
196 | columns_b[idx] = columns + idx * columns_stride;
197 | ones_b[idx] = ones + idx * ones_stride;
198 |
199 | // share weights and bias within a Mini-Batch
200 | weight_b[idx] = weight;
201 | grad_weight_b[idx] = grad_weight;
202 | grad_bias_b[idx] = grad_bias;
203 | }
204 | }
205 |
206 | std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
207 | const at::Tensor &weight,
208 | const at::Tensor &bias,
209 | const at::Tensor &offset,
210 | const at::Tensor &mask,
211 | const at::Tensor &grad_output,
212 | int kernel_h, int kernel_w,
213 | int stride_h, int stride_w,
214 | int pad_h, int pad_w,
215 | int dilation_h, int dilation_w,
216 | int deformable_group)
217 | {
218 |
219 | THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
220 | THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
221 |
222 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
223 | AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
224 | AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
225 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
226 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
227 |
228 | const int batch = input.size(0);
229 | const int channels = input.size(1);
230 | const int height = input.size(2);
231 | const int width = input.size(3);
232 |
233 | const int channels_out = weight.size(0);
234 | const int channels_kernel = weight.size(1);
235 | const int kernel_h_ = weight.size(2);
236 | const int kernel_w_ = weight.size(3);
237 |
238 | AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
239 | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
240 |
241 | AT_ASSERTM(channels == channels_kernel,
242 | "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
243 |
244 | const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
245 | const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
246 |
247 | auto ones = at::ones({height_out, width_out}, input.options());
248 | auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
249 | auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
250 |
251 | auto grad_input = at::zeros_like(input);
252 | auto grad_weight = at::zeros_like(weight);
253 | auto grad_bias = at::zeros_like(bias);
254 | auto grad_offset = at::zeros_like(offset);
255 | auto grad_mask = at::zeros_like(mask);
256 |
257 | using scalar_t = float;
258 |
259 | for (int b = 0; b < batch; b++)
260 | {
261 | auto input_n = input.select(0, b);
262 | auto offset_n = offset.select(0, b);
263 | auto mask_n = mask.select(0, b);
264 | auto grad_output_n = grad_output.select(0, b);
265 | auto grad_input_n = grad_input.select(0, b);
266 | auto grad_offset_n = grad_offset.select(0, b);
267 | auto grad_mask_n = grad_mask.select(0, b);
268 |
269 | long m = channels * kernel_h * kernel_w;
270 | long n = height_out * width_out;
271 | long k = channels_out;
272 |
273 | THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
274 | grad_output_n.data<scalar_t>(), n,
275 | weight.data<scalar_t>(), m, 0.0f,
276 | columns.data<scalar_t>(), n);
277 |
278 | // gradient w.r.t. input coordinate data
279 | modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
280 | columns.data<scalar_t>(),
281 | input_n.data<scalar_t>(),
282 | offset_n.data<scalar_t>(),
283 | mask_n.data<scalar_t>(),
284 | 1, channels, height, width,
285 | height_out, width_out, kernel_h, kernel_w,
286 | pad_h, pad_w, stride_h, stride_w,
287 | dilation_h, dilation_w, deformable_group,
288 | grad_offset_n.data<scalar_t>(),
289 | grad_mask_n.data<scalar_t>());
290 | // gradient w.r.t. input data
291 | modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
292 | columns.data<scalar_t>(),
293 | offset_n.data<scalar_t>(),
294 | mask_n.data<scalar_t>(),
295 | 1, channels, height, width,
296 | height_out, width_out, kernel_h, kernel_w,
297 | pad_h, pad_w, stride_h, stride_w,
298 | dilation_h, dilation_w, deformable_group,
299 | grad_input_n.data<scalar_t>());
300 |
301 | // gradient w.r.t. weight, dWeight should accumulate across the batch and group
302 | modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
303 | input_n.data<scalar_t>(),
304 | offset_n.data<scalar_t>(),
305 | mask_n.data<scalar_t>(),
306 | 1, channels, height, width,
307 | height_out, width_out, kernel_h, kernel_w,
308 | pad_h, pad_w, stride_h, stride_w,
309 | dilation_h, dilation_w, deformable_group,
310 | columns.data<scalar_t>());
311 |
312 | long m_ = channels_out;
313 | long n_ = channels * kernel_h * kernel_w;
314 | long k_ = height_out * width_out;
315 |
316 | THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
317 | columns.data<scalar_t>(), k_,
318 | grad_output_n.data<scalar_t>(), k_, 1.0f,
319 | grad_weight.data<scalar_t>(), n_);
320 |
321 | // gradient w.r.t. bias
322 | // long m_ = channels_out;
323 | // long k__ = height_out * width_out;
324 | THCudaBlas_Sgemv(state,
325 | 't',
326 | k_, m_, 1.0f,
327 | grad_output_n.data<scalar_t>(), k_,
328 | ones.data<scalar_t>(), 1, 1.0f,
329 | grad_bias.data<scalar_t>(), 1);
330 | }
331 |
332 | return {
333 | grad_input, grad_offset, grad_mask, grad_weight, grad_bias
334 | };
335 | }
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/cuda/dcn_v2_im2col_cuda.h:
--------------------------------------------------------------------------------
1 |
2 | /*!
3 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
4 | *
5 | * COPYRIGHT
6 | *
7 | * All contributions by the University of California:
8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
9 | * All rights reserved.
10 | *
11 | * All other contributions:
12 | * Copyright (c) 2014-2017, the respective contributors
13 | * All rights reserved.
14 | *
15 | * Caffe uses a shared copyright model: each contributor holds copyright over
16 | * their contributions to Caffe. The project versioning records all such
17 | * contribution and copyright details. If a contributor wants to further mark
18 | * their specific copyright on a particular contribution, they should indicate
19 | * their copyright solely in the commit message of the change when it is
20 | * committed.
21 | *
22 | * LICENSE
23 | *
24 | * Redistribution and use in source and binary forms, with or without
25 | * modification, are permitted provided that the following conditions are met:
26 | *
27 | * 1. Redistributions of source code must retain the above copyright notice, this
28 | * list of conditions and the following disclaimer.
29 | * 2. Redistributions in binary form must reproduce the above copyright notice,
30 | * this list of conditions and the following disclaimer in the documentation
31 | * and/or other materials provided with the distribution.
32 | *
33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
34 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
35 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
37 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
38 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
40 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
42 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 | *
44 | * CONTRIBUTION AGREEMENT
45 | *
46 | * By contributing to the BVLC/caffe repository through pull-request, comment,
47 | * or otherwise, the contributor releases their content to the
48 | * license and copyright terms herein.
49 | *
50 | ***************** END Caffe Copyright Notice and Disclaimer ********************
51 | *
52 | * Copyright (c) 2018 Microsoft
53 | * Licensed under The MIT License [see LICENSE for details]
54 | * \file modulated_deformable_im2col.h
55 | * \brief Function definitions of converting an image to
56 | * column matrix based on kernel, padding, dilation, and offset.
57 | * These functions are mainly used in deformable convolution operators.
58 | * \ref: https://arxiv.org/abs/1811.11168
59 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
60 | */
61 |
62 | /***************** Adapted by Charles Shang *********************/
63 |
64 | #ifndef DCN_V2_IM2COL_CUDA
65 | #define DCN_V2_IM2COL_CUDA
66 |
67 | #ifdef __cplusplus
68 | extern "C"
69 | {
70 | #endif
71 |
72 | void modulated_deformable_im2col_cuda(cudaStream_t stream,
73 | const float *data_im, const float *data_offset, const float *data_mask,
74 | const int batch_size, const int channels, const int height_im, const int width_im,
75 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
76 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
77 | const int dilation_h, const int dilation_w,
78 | const int deformable_group, float *data_col);
79 |
80 | void modulated_deformable_col2im_cuda(cudaStream_t stream,
81 | const float *data_col, const float *data_offset, const float *data_mask,
82 | const int batch_size, const int channels, const int height_im, const int width_im,
83 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
84 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
85 | const int dilation_h, const int dilation_w,
86 | const int deformable_group, float *grad_im);
87 |
88 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
89 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
90 | const int batch_size, const int channels, const int height_im, const int width_im,
91 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
92 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
93 | const int dilation_h, const int dilation_w,
94 | const int deformable_group,
95 | float *grad_offset, float *grad_mask);
96 |
97 | #ifdef __cplusplus
98 | }
99 | #endif
100 |
101 | #endif
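Note: a conceptual Python sketch (illustration only, not the CUDA implementation) of the per-tap operation behind modulated deformable im2col — bilinearly sample the input at a fractional location that already includes the learned offset, then scale the sample by the learned (sigmoid) mask:

```python
import torch
import torch.nn.functional as F

def modulated_sample(im, y, x, mask):
    """Bilinearly sample im (C, H, W) at fractional (y, x) and scale by mask."""
    H, W = im.shape[-2:]
    # single-point sampling grid in normalized [-1, 1] coordinates
    grid = torch.tensor([[[[2.0 * x / (W - 1) - 1.0,
                            2.0 * y / (H - 1) - 1.0]]]], dtype=im.dtype)
    val = F.grid_sample(im[None], grid, align_corners=True)   # (1, C, 1, 1)
    return val.view(-1) * mask

im = torch.arange(16.).view(1, 4, 4)
# one kernel tap at base position (1, 2) with learned offset (+0.3, +0.7), mask logit 0.2
print(modulated_sample(im, y=1 + 0.3, x=2 + 0.7, mask=torch.sigmoid(torch.tensor(0.2))))
```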
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017 Microsoft
3 | * Licensed under The MIT License [see LICENSE for details]
4 | * \file deformable_psroi_pooling.cu
5 | * \brief
6 | * \author Yi Li, Guodong Zhang, Jifeng Dai
7 | */
8 | /***************** Adapted by Charles Shang *********************/
9 |
10 | #include <cstdio>
11 | #include <algorithm>
12 | #include <cstring>
13 | #include <vector>
14 |
15 | #include <ATen/ATen.h>
16 | #include <ATen/cuda/CUDAContext.h>
17 |
18 | #include <THC/THC.h>
19 | #include <THC/THCAtomics.cuh>
20 | #include <THC/THCDeviceUtils.cuh>
21 |
22 | #define CUDA_KERNEL_LOOP(i, n) \
23 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
24 | i < (n); \
25 | i += blockDim.x * gridDim.x)
26 |
27 | const int CUDA_NUM_THREADS = 1024;
28 | inline int GET_BLOCKS(const int N)
29 | {
30 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
31 | }
32 |
33 | template <typename T>
34 | __device__ T bilinear_interp(
35 | const T *data,
36 | const T x,
37 | const T y,
38 | const int width,
39 | const int height)
40 | {
41 | int x1 = floor(x);
42 | int x2 = ceil(x);
43 | int y1 = floor(y);
44 | int y2 = ceil(y);
45 | T dist_x = static_cast<T>(x - x1);
46 | T dist_y = static_cast<T>(y - y1);
47 | T value11 = data[y1 * width + x1];
48 | T value12 = data[y2 * width + x1];
49 | T value21 = data[y1 * width + x2];
50 | T value22 = data[y2 * width + x2];
51 | T value = (1 - dist_x) * (1 - dist_y) * value11 +
52 | (1 - dist_x) * dist_y * value12 +
53 | dist_x * (1 - dist_y) * value21 +
54 | dist_x * dist_y * value22;
55 | return value;
56 | }
57 |
58 | template <typename T>
59 | __global__ void DeformablePSROIPoolForwardKernel(
60 | const int count,
61 | const T *bottom_data,
62 | const T spatial_scale,
63 | const int channels,
64 | const int height, const int width,
65 | const int pooled_height, const int pooled_width,
66 | const T *bottom_rois, const T *bottom_trans,
67 | const int no_trans,
68 | const T trans_std,
69 | const int sample_per_part,
70 | const int output_dim,
71 | const int group_size,
72 | const int part_size,
73 | const int num_classes,
74 | const int channels_each_class,
75 | T *top_data,
76 | T *top_count)
77 | {
78 | CUDA_KERNEL_LOOP(index, count)
79 | {
80 | // The output is in order (n, ctop, ph, pw)
81 | int pw = index % pooled_width;
82 | int ph = (index / pooled_width) % pooled_height;
83 | int ctop = (index / pooled_width / pooled_height) % output_dim;
84 | int n = index / pooled_width / pooled_height / output_dim;
85 |
86 | // [start, end) interval for spatial sampling
87 | const T *offset_bottom_rois = bottom_rois + n * 5;
88 | int roi_batch_ind = offset_bottom_rois[0];
89 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
90 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
91 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
92 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
93 |
94 | // Force too small ROIs to be 1x1
95 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
96 | T roi_height = max(roi_end_h - roi_start_h, 0.1);
97 |
98 | // Compute w and h at bottom
99 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
100 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
101 |
102 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
103 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
104 |
105 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
106 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
107 | int class_id = ctop / channels_each_class;
108 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
109 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
110 |
111 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
112 | wstart += trans_x * roi_width;
113 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
114 | hstart += trans_y * roi_height;
115 |
116 | T sum = 0;
117 | int count = 0;
118 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
119 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
120 | gw = min(max(gw, 0), group_size - 1);
121 | gh = min(max(gh, 0), group_size - 1);
122 |
123 | const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
124 | for (int ih = 0; ih < sample_per_part; ih++)
125 | {
126 | for (int iw = 0; iw < sample_per_part; iw++)
127 | {
128 | T w = wstart + iw * sub_bin_size_w;
129 | T h = hstart + ih * sub_bin_size_h;
130 | // bilinear interpolation
131 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
132 | {
133 | continue;
134 | }
135 | w = min(max(w, 0.), width - 1.);
136 | h = min(max(h, 0.), height - 1.);
137 | int c = (ctop * group_size + gh) * group_size + gw;
138 | T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
139 | sum += val;
140 | count++;
141 | }
142 | }
143 | top_data[index] = count == 0 ? static_cast<T>(0) : sum / count;
144 | top_count[index] = count;
145 | }
146 | }
147 |
148 | template <typename T>
149 | __global__ void DeformablePSROIPoolBackwardAccKernel(
150 | const int count,
151 | const T *top_diff,
152 | const T *top_count,
153 | const int num_rois,
154 | const T spatial_scale,
155 | const int channels,
156 | const int height, const int width,
157 | const int pooled_height, const int pooled_width,
158 | const int output_dim,
159 | T *bottom_data_diff, T *bottom_trans_diff,
160 | const T *bottom_data,
161 | const T *bottom_rois,
162 | const T *bottom_trans,
163 | const int no_trans,
164 | const T trans_std,
165 | const int sample_per_part,
166 | const int group_size,
167 | const int part_size,
168 | const int num_classes,
169 | const int channels_each_class)
170 | {
171 | CUDA_KERNEL_LOOP(index, count)
172 | {
173 | // The output is in order (n, ctop, ph, pw)
174 | int pw = index % pooled_width;
175 | int ph = (index / pooled_width) % pooled_height;
176 | int ctop = (index / pooled_width / pooled_height) % output_dim;
177 | int n = index / pooled_width / pooled_height / output_dim;
178 |
179 | // [start, end) interval for spatial sampling
180 | const T *offset_bottom_rois = bottom_rois + n * 5;
181 | int roi_batch_ind = offset_bottom_rois[0];
182 | T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
183 | T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
184 | T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
185 | T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
186 |
187 | // Force too small ROIs to be 1x1
188 | T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
189 | T roi_height = max(roi_end_h - roi_start_h, 0.1);
190 |
191 | // Compute w and h at bottom
192 | T bin_size_h = roi_height / static_cast<T>(pooled_height);
193 | T bin_size_w = roi_width / static_cast<T>(pooled_width);
194 |
195 | T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
196 | T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
197 |
198 | int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
199 | int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
200 | int class_id = ctop / channels_each_class;
201 | T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
202 | T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
203 |
204 | T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
205 | wstart += trans_x * roi_width;
206 | T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
207 | hstart += trans_y * roi_height;
208 |
209 | if (top_count[index] <= 0)
210 | {
211 | continue;
212 | }
213 | T diff_val = top_diff[index] / top_count[index];
214 | const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
215 | T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
216 | int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
217 | int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
218 | gw = min(max(gw, 0), group_size - 1);
219 | gh = min(max(gh, 0), group_size - 1);
220 |
221 | for (int ih = 0; ih < sample_per_part; ih++)
222 | {
223 | for (int iw = 0; iw < sample_per_part; iw++)
224 | {
225 | T w = wstart + iw * sub_bin_size_w;
226 | T h = hstart + ih * sub_bin_size_h;
227 | // bilinear interpolation
228 | if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
229 | {
230 | continue;
231 | }
232 | w = min(max(w, 0.), width - 1.);
233 | h = min(max(h, 0.), height - 1.);
234 | int c = (ctop * group_size + gh) * group_size + gw;
235 | // backward on feature
236 | int x0 = floor(w);
237 | int x1 = ceil(w);
238 | int y0 = floor(h);
239 | int y1 = ceil(h);
240 | T dist_x = w - x0, dist_y = h - y0;
241 | T q00 = (1 - dist_x) * (1 - dist_y);
242 | T q01 = (1 - dist_x) * dist_y;
243 | T q10 = dist_x * (1 - dist_y);
244 | T q11 = dist_x * dist_y;
245 | int bottom_index_base = c * height * width;
246 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
247 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
248 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
249 | atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
250 |
251 | if (no_trans)
252 | {
253 | continue;
254 | }
255 | T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
256 | T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
257 | T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
258 | T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
259 | T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
260 | diff_x *= roi_width;
261 | T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
262 | diff_y *= roi_height;
263 |
264 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
265 | atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
266 | }
267 | }
268 | }
269 | }
270 |
271 | std::tuple<at::Tensor, at::Tensor>
272 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
273 | const at::Tensor &bbox,
274 | const at::Tensor &trans,
275 | const int no_trans,
276 | const float spatial_scale,
277 | const int output_dim,
278 | const int group_size,
279 | const int pooled_size,
280 | const int part_size,
281 | const int sample_per_part,
282 | const float trans_std)
283 | {
284 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
285 | AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
286 | AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
287 |
288 | const int batch = input.size(0);
289 | const int channels = input.size(1);
290 | const int height = input.size(2);
291 | const int width = input.size(3);
292 | const int channels_trans = no_trans ? 2 : trans.size(1);
293 | const int num_bbox = bbox.size(0);
294 |
295 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
296 | auto pooled_height = pooled_size;
297 | auto pooled_width = pooled_size;
298 |
299 | auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
300 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
301 | auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
302 |
303 | const int num_classes = no_trans ? 1 : channels_trans / 2;
304 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
305 |
306 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
307 |
308 | if (out.numel() == 0)
309 | {
310 | THCudaCheck(cudaGetLastError());
311 | return std::make_tuple(out, top_count);
312 | }
313 |
314 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
315 | dim3 block(512);
316 |
317 | AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
318 | DeformablePSROIPoolForwardKernel<scalar_t><<<grid, block, 0, stream>>>(
319 | out_size,
320 | input.contiguous().data<scalar_t>(),
321 | spatial_scale,
322 | channels,
323 | height, width,
324 | pooled_height,
325 | pooled_width,
326 | bbox.contiguous().data<scalar_t>(),
327 | trans.contiguous().data<scalar_t>(),
328 | no_trans,
329 | trans_std,
330 | sample_per_part,
331 | output_dim,
332 | group_size,
333 | part_size,
334 | num_classes,
335 | channels_each_class,
336 | out.data<scalar_t>(),
337 | top_count.data<scalar_t>());
338 | });
339 | THCudaCheck(cudaGetLastError());
340 | return std::make_tuple(out, top_count);
341 | }
342 |
343 | std::tuple<at::Tensor, at::Tensor>
344 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
345 | const at::Tensor &input,
346 | const at::Tensor &bbox,
347 | const at::Tensor &trans,
348 | const at::Tensor &top_count,
349 | const int no_trans,
350 | const float spatial_scale,
351 | const int output_dim,
352 | const int group_size,
353 | const int pooled_size,
354 | const int part_size,
355 | const int sample_per_part,
356 | const float trans_std)
357 | {
358 | AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
359 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
360 | AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
361 | AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
362 | AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");
363 |
364 | const int batch = input.size(0);
365 | const int channels = input.size(1);
366 | const int height = input.size(2);
367 | const int width = input.size(3);
368 | const int channels_trans = no_trans ? 2 : trans.size(1);
369 | const int num_bbox = bbox.size(0);
370 |
371 | AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
372 | auto pooled_height = pooled_size;
373 | auto pooled_width = pooled_size;
374 | long out_size = num_bbox * output_dim * pooled_height * pooled_width;
375 | const int num_classes = no_trans ? 1 : channels_trans / 2;
376 | const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
377 |
378 | auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
379 | auto trans_grad = at::zeros_like(trans);
380 |
381 | if (input_grad.numel() == 0)
382 | {
383 | THCudaCheck(cudaGetLastError());
384 | return std::make_tuple(input_grad, trans_grad);
385 | }
386 |
387 | dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
388 | dim3 block(512);
389 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
390 |
391 | AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
392 | DeformablePSROIPoolBackwardAccKernel<scalar_t><<<grid, block, 0, stream>>>(
393 | out_size,
394 | out_grad.contiguous().data<scalar_t>(),
395 | top_count.contiguous().data<scalar_t>(),
396 | num_bbox,
397 | spatial_scale,
398 | channels,
399 | height,
400 | width,
401 | pooled_height,
402 | pooled_width,
403 | output_dim,
404 | input_grad.contiguous().data<scalar_t>(),
405 | trans_grad.contiguous().data<scalar_t>(),
406 | input.contiguous().data<scalar_t>(),
407 | bbox.contiguous().data<scalar_t>(),
408 | trans.contiguous().data<scalar_t>(),
409 | no_trans,
410 | trans_std,
411 | sample_per_part,
412 | group_size,
413 | part_size,
414 | num_classes,
415 | channels_each_class);
416 | });
417 | THCudaCheck(cudaGetLastError());
418 | return std::make_tuple(input_grad, trans_grad);
419 | }
--------------------------------------------------------------------------------
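The kernels above back the deformable PS-ROI pooling op; on the Python side they are reached through the DCNv2Pooling wrapper exercised in test.py further below. A minimal, hedged sketch of the expected tensor layout (shapes follow the kernel indexing above; the values are illustrative only):

    import torch
    from dcn_v2 import DCNv2Pooling  # wrapper around dcn_v2_psroi_pooling_forward/backward

    # feature map: [N, C, H, W]; rois: [K, 5] rows of (batch_index, x1, y1, x2, y2)
    feats = torch.randn(2, 16, 64, 64).cuda()
    rois = torch.tensor([[0, 4, 4, 40, 40],
                         [1, 8, 8, 56, 56]]).float().cuda()

    pool = DCNv2Pooling(spatial_scale=1.0 / 4, pooled_size=7, output_dim=16,
                        no_trans=True, group_size=1, trans_std=0.0).cuda()
    out = pool(feats, rois, feats.new())  # with no_trans=True the offset argument may be empty
    print(out.shape)  # [K, output_dim, pooled_size, pooled_size] -> [2, 16, 7, 7]
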
/codes/models/modules/DCNv2/src/cuda/vision.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor
5 | dcn_v2_cuda_forward(const at::Tensor &input,
6 | const at::Tensor &weight,
7 | const at::Tensor &bias,
8 | const at::Tensor &offset,
9 | const at::Tensor &mask,
10 | const int kernel_h,
11 | const int kernel_w,
12 | const int stride_h,
13 | const int stride_w,
14 | const int pad_h,
15 | const int pad_w,
16 | const int dilation_h,
17 | const int dilation_w,
18 | const int deformable_group);
19 |
20 | std::vector<at::Tensor>
21 | dcn_v2_cuda_backward(const at::Tensor &input,
22 | const at::Tensor &weight,
23 | const at::Tensor &bias,
24 | const at::Tensor &offset,
25 | const at::Tensor &mask,
26 | const at::Tensor &grad_output,
27 | int kernel_h, int kernel_w,
28 | int stride_h, int stride_w,
29 | int pad_h, int pad_w,
30 | int dilation_h, int dilation_w,
31 | int deformable_group);
32 |
33 |
34 | std::tuple<at::Tensor, at::Tensor>
35 | dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
36 | const at::Tensor &bbox,
37 | const at::Tensor &trans,
38 | const int no_trans,
39 | const float spatial_scale,
40 | const int output_dim,
41 | const int group_size,
42 | const int pooled_size,
43 | const int part_size,
44 | const int sample_per_part,
45 | const float trans_std);
46 |
47 | std::tuple<at::Tensor, at::Tensor>
48 | dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
49 | const at::Tensor &input,
50 | const at::Tensor &bbox,
51 | const at::Tensor &trans,
52 | const at::Tensor &top_count,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/dcn_v2.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "cpu/vision.h"
4 |
5 | #ifdef WITH_CUDA
6 | #include "cuda/vision.h"
7 | #endif
8 |
9 | at::Tensor
10 | dcn_v2_forward(const at::Tensor &input,
11 | const at::Tensor &weight,
12 | const at::Tensor &bias,
13 | const at::Tensor &offset,
14 | const at::Tensor &mask,
15 | const int kernel_h,
16 | const int kernel_w,
17 | const int stride_h,
18 | const int stride_w,
19 | const int pad_h,
20 | const int pad_w,
21 | const int dilation_h,
22 | const int dilation_w,
23 | const int deformable_group)
24 | {
25 | if (input.type().is_cuda())
26 | {
27 | #ifdef WITH_CUDA
28 | return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
29 | kernel_h, kernel_w,
30 | stride_h, stride_w,
31 | pad_h, pad_w,
32 | dilation_h, dilation_w,
33 | deformable_group);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | dcn_v2_backward(const at::Tensor &input,
43 | const at::Tensor &weight,
44 | const at::Tensor &bias,
45 | const at::Tensor &offset,
46 | const at::Tensor &mask,
47 | const at::Tensor &grad_output,
48 | int kernel_h, int kernel_w,
49 | int stride_h, int stride_w,
50 | int pad_h, int pad_w,
51 | int dilation_h, int dilation_w,
52 | int deformable_group)
53 | {
54 | if (input.type().is_cuda())
55 | {
56 | #ifdef WITH_CUDA
57 | return dcn_v2_cuda_backward(input,
58 | weight,
59 | bias,
60 | offset,
61 | mask,
62 | grad_output,
63 | kernel_h, kernel_w,
64 | stride_h, stride_w,
65 | pad_h, pad_w,
66 | dilation_h, dilation_w,
67 | deformable_group);
68 | #else
69 | AT_ERROR("Not compiled with GPU support");
70 | #endif
71 | }
72 | AT_ERROR("Not implemented on the CPU");
73 | }
74 |
75 | std::tuple<at::Tensor, at::Tensor>
76 | dcn_v2_psroi_pooling_forward(const at::Tensor &input,
77 | const at::Tensor &bbox,
78 | const at::Tensor &trans,
79 | const int no_trans,
80 | const float spatial_scale,
81 | const int output_dim,
82 | const int group_size,
83 | const int pooled_size,
84 | const int part_size,
85 | const int sample_per_part,
86 | const float trans_std)
87 | {
88 | if (input.type().is_cuda())
89 | {
90 | #ifdef WITH_CUDA
91 | return dcn_v2_psroi_pooling_cuda_forward(input,
92 | bbox,
93 | trans,
94 | no_trans,
95 | spatial_scale,
96 | output_dim,
97 | group_size,
98 | pooled_size,
99 | part_size,
100 | sample_per_part,
101 | trans_std);
102 | #else
103 | AT_ERROR("Not compiled with GPU support");
104 | #endif
105 | }
106 | AT_ERROR("Not implemented on the CPU");
107 | }
108 |
109 | std::tuple<at::Tensor, at::Tensor>
110 | dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
111 | const at::Tensor &input,
112 | const at::Tensor &bbox,
113 | const at::Tensor &trans,
114 | const at::Tensor &top_count,
115 | const int no_trans,
116 | const float spatial_scale,
117 | const int output_dim,
118 | const int group_size,
119 | const int pooled_size,
120 | const int part_size,
121 | const int sample_per_part,
122 | const float trans_std)
123 | {
124 | if (input.type().is_cuda())
125 | {
126 | #ifdef WITH_CUDA
127 | return dcn_v2_psroi_pooling_cuda_backward(out_grad,
128 | input,
129 | bbox,
130 | trans,
131 | top_count,
132 | no_trans,
133 | spatial_scale,
134 | output_dim,
135 | group_size,
136 | pooled_size,
137 | part_size,
138 | sample_per_part,
139 | trans_std);
140 | #else
141 | AT_ERROR("Not compiled with GPU support");
142 | #endif
143 | }
144 | AT_ERROR("Not implemented on the CPU");
145 | }
--------------------------------------------------------------------------------
/codes/models/modules/DCNv2/src/vision.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "dcn_v2.h"
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 | m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
6 | m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
7 | m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
8 | m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
9 | }
10 |
--------------------------------------------------------------------------------
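vision.cpp only registers the four C++ entry points with pybind11; the autograd Functions in dcn_v2.py (imported by test.py below) call them under the hood. A hedged sketch of a raw call to dcn_v2_forward, assuming the compiled extension is importable as _ext (a hypothetical name; the real module name is set in setup.py, which is not shown here):

    import torch
    import _ext  # hypothetical import name for the compiled extension

    N, C_in, C_out, H, W, k = 1, 4, 8, 16, 16, 3
    x = torch.randn(N, C_in, H, W).cuda()
    weight = torch.randn(C_out, C_in, k, k).cuda()
    bias = torch.zeros(C_out).cuda()
    offset = torch.zeros(N, 2 * k * k, H, W).cuda()  # one deformable group: 2*k*k offset channels
    mask = torch.ones(N, k * k, H, W).cuda()         # and k*k mask channels per location

    out = _ext.dcn_v2_forward(x, weight, bias, offset, mask,
                              k, k,   # kernel_h, kernel_w
                              1, 1,   # stride_h, stride_w
                              1, 1,   # pad_h, pad_w
                              1, 1,   # dilation_h, dilation_w
                              1)      # deformable_group
    print(out.shape)  # [N, C_out, H, W] with stride 1 and padding 1
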
/codes/models/modules/DCNv2/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import gradcheck
10 |
11 | from dcn_v2 import dcn_v2_conv, DCNv2, DCN
12 | from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling
13 |
14 | deformable_groups = 1
15 | N, inC, inH, inW = 2, 2, 4, 4
16 | outC = 2
17 | kH, kW = 3, 3
18 |
19 |
20 | def conv_identify(weight, bias):
21 | weight.data.zero_()
22 | bias.data.zero_()
23 | o, i, h, w = weight.shape
24 | y = h//2
25 | x = w//2
26 | for p in range(i):
27 | for q in range(o):
28 | if p == q:
29 | weight.data[q, p, y, x] = 1.0
30 |
31 |
32 | def check_zero_offset():
33 | conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,
34 | kernel_size=(kH, kW),
35 | stride=(1, 1),
36 | padding=(1, 1),
37 | bias=True).cuda()
38 |
39 | conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,
40 | kernel_size=(kH, kW),
41 | stride=(1, 1),
42 | padding=(1, 1),
43 | bias=True).cuda()
44 |
45 | dcn_v2 = DCNv2(inC, outC, (kH, kW),
46 | stride=1, padding=1, dilation=1,
47 | deformable_groups=deformable_groups).cuda()
48 |
49 | conv_offset.weight.data.zero_()
50 | conv_offset.bias.data.zero_()
51 | conv_mask.weight.data.zero_()
52 | conv_mask.bias.data.zero_()
53 | conv_identify(dcn_v2.weight, dcn_v2.bias)
54 |
55 | input = torch.randn(N, inC, inH, inW).cuda()
56 | offset = conv_offset(input)
57 | mask = conv_mask(input)
58 | mask = torch.sigmoid(mask)
59 | output = dcn_v2(input, offset, mask)
60 | output *= 2
61 | d = (input - output).abs().max()
62 | if d < 1e-10:
63 | print('Zero offset passed')
64 | else:
65 | print('Zero offset failed')
66 | print(input)
67 | print(output)
68 |
69 |
70 | def check_gradient_dconv():
71 |
72 | input = torch.rand(N, inC, inH, inW).cuda() * 0.01
73 | input.requires_grad = True
74 |
75 | offset = torch.randn(N, deformable_groups * 2 *
76 | kW * kH, inH, inW).cuda() * 2
77 | # offset.data.zero_()
78 | # offset.data -= 0.5
79 | offset.requires_grad = True
80 |
81 | mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
82 | # mask.data.zero_()
83 | mask.requires_grad = True
84 | mask = torch.sigmoid(mask)
85 |
86 | weight = torch.randn(outC, inC, kH, kW).cuda()
87 | weight.requires_grad = True
88 |
89 | bias = torch.rand(outC).cuda()
90 | bias.requires_grad = True
91 |
92 | stride = 1
93 | padding = 1
94 | dilation = 1
95 |
96 | print('check_gradient_dconv: ',
97 | gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,
98 | stride, padding, dilation, deformable_groups),
99 | eps=1e-3, atol=1e-4, rtol=1e-2))
100 |
101 |
102 | def check_pooling_zero_offset():
103 |
104 | input = torch.randn(2, 16, 64, 64).cuda().zero_()
105 | input[0, :, 16:26, 16:26] = 1.
106 | input[1, :, 10:20, 20:30] = 2.
107 | rois = torch.tensor([
108 | [0, 65, 65, 103, 103],
109 | [1, 81, 41, 119, 79],
110 | ]).cuda().float()
111 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
112 | pooled_size=7,
113 | output_dim=16,
114 | no_trans=True,
115 | group_size=1,
116 | trans_std=0.0).cuda()
117 |
118 | out = pooling(input, rois, input.new())
119 | s = ', '.join(['%f' % out[i, :, :, :].mean().item()
120 | for i in range(rois.shape[0])])
121 | print(s)
122 |
123 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
124 | pooled_size=7,
125 | output_dim=16,
126 | no_trans=False,
127 | group_size=1,
128 | trans_std=0.0).cuda()
129 | offset = torch.randn(20, 2, 7, 7).cuda().zero_()
130 | dout = dpooling(input, rois, offset)
131 | s = ', '.join(['%f' % dout[i, :, :, :].mean().item()
132 | for i in range(rois.shape[0])])
133 | print(s)
134 |
135 |
136 | def check_gradient_dpooling():
137 | input = torch.randn(2, 3, 5, 5).cuda() * 0.01
138 | N = 4
139 | batch_inds = torch.randint(2, (N, 1)).cuda().float()
140 | x = torch.rand((N, 1)).cuda().float() * 15
141 | y = torch.rand((N, 1)).cuda().float() * 15
142 | w = torch.rand((N, 1)).cuda().float() * 10
143 | h = torch.rand((N, 1)).cuda().float() * 10
144 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
145 | offset = torch.randn(N, 2, 3, 3).cuda()
146 | input.requires_grad = True
147 | offset.requires_grad = True
148 |
149 | spatial_scale = 1.0 / 4
150 | pooled_size = 3
151 | output_dim = 3
152 | no_trans = 0
153 | group_size = 1
154 | trans_std = 0.0
155 | sample_per_part = 4
156 | part_size = pooled_size
157 |
158 | print('check_gradient_dpooling:',
159 | gradcheck(dcn_v2_pooling, (input, rois, offset,
160 | spatial_scale,
161 | pooled_size,
162 | output_dim,
163 | no_trans,
164 | group_size,
165 | part_size,
166 | sample_per_part,
167 | trans_std),
168 | eps=1e-4))
169 |
170 |
171 | def example_dconv():
172 | input = torch.randn(2, 64, 128, 128).cuda()
173 | # wrap all things (offset and mask) in DCN
174 | dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,
175 | padding=1, deformable_groups=2).cuda()
176 | # print(dcn.weight.shape, input.shape)
177 | output = dcn(input)
178 |     target = output.new(*output.size())
179 |     target.data.uniform_(-0.01, 0.01)
180 |     error = (target - output).mean()
181 | error.backward()
182 | print(output.shape)
183 |
184 |
185 | def example_dpooling():
186 | input = torch.randn(2, 32, 64, 64).cuda()
187 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
188 | x = torch.randint(256, (20, 1)).cuda().float()
189 | y = torch.randint(256, (20, 1)).cuda().float()
190 | w = torch.randint(64, (20, 1)).cuda().float()
191 | h = torch.randint(64, (20, 1)).cuda().float()
192 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
193 | offset = torch.randn(20, 2, 7, 7).cuda()
194 | input.requires_grad = True
195 | offset.requires_grad = True
196 |
197 | # normal roi_align
198 | pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
199 | pooled_size=7,
200 | output_dim=32,
201 | no_trans=True,
202 | group_size=1,
203 | trans_std=0.1).cuda()
204 |
205 | # deformable pooling
206 | dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
207 | pooled_size=7,
208 | output_dim=32,
209 | no_trans=False,
210 | group_size=1,
211 | trans_std=0.1).cuda()
212 |
213 | out = pooling(input, rois, offset)
214 | dout = dpooling(input, rois, offset)
215 | print(out.shape)
216 | print(dout.shape)
217 |
218 | target_out = out.new(*out.size())
219 | target_out.data.uniform_(-0.01, 0.01)
220 | target_dout = dout.new(*dout.size())
221 | target_dout.data.uniform_(-0.01, 0.01)
222 | e = (target_out - out).mean()
223 | e.backward()
224 | e = (target_dout - dout).mean()
225 | e.backward()
226 |
227 |
228 | def example_mdpooling():
229 | input = torch.randn(2, 32, 64, 64).cuda()
230 | input.requires_grad = True
231 | batch_inds = torch.randint(2, (20, 1)).cuda().float()
232 | x = torch.randint(256, (20, 1)).cuda().float()
233 | y = torch.randint(256, (20, 1)).cuda().float()
234 | w = torch.randint(64, (20, 1)).cuda().float()
235 | h = torch.randint(64, (20, 1)).cuda().float()
236 | rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
237 |
238 | # mdformable pooling (V2)
239 | dpooling = DCNPooling(spatial_scale=1.0 / 4,
240 | pooled_size=7,
241 | output_dim=32,
242 | no_trans=False,
243 | group_size=1,
244 | trans_std=0.1,
245 | deform_fc_dim=1024).cuda()
246 |
247 | dout = dpooling(input, rois)
248 | target = dout.new(*dout.size())
249 | target.data.uniform_(-0.1, 0.1)
250 | error = (target - dout).mean()
251 | error.backward()
252 | print(dout.shape)
253 |
254 |
255 | if __name__ == '__main__':
256 |
257 | example_dconv()
258 | example_dpooling()
259 | example_mdpooling()
260 |
261 | check_pooling_zero_offset()
262 | # zero offset check
263 | if inC == outC:
264 | check_zero_offset()
265 |
266 | check_gradient_dpooling()
267 | check_gradient_dconv()
268 | # """
269 |     # ****** Note: the "backward is not reentrant" error may not be a serious problem,
270 |     # ****** since the max error is less than 1e-7.
271 |     # ****** Still looking into what triggers this problem.
272 | # """
273 |
--------------------------------------------------------------------------------
/codes/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/codes/models/modules/__init__.py
--------------------------------------------------------------------------------
/codes/models/modules/convlstm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.autograd import Variable
3 | import torch
4 |
5 |
6 | class ConvLSTMCell(nn.Module):
7 |
8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
9 | """
10 | Initialize ConvLSTM cell.
11 |
12 | Parameters
13 | ----------
14 | input_size: (int, int)
15 | Height and width of input tensor as (height, width).
16 | input_dim: int
17 | Number of channels of input tensor.
18 | hidden_dim: int
19 | Number of channels of hidden state.
20 | kernel_size: (int, int)
21 | Size of the convolutional kernel.
22 | bias: bool
23 | Whether or not to add the bias.
24 | """
25 |
26 | super(ConvLSTMCell, self).__init__()
27 |
28 | self.height, self.width = input_size
29 | self.input_dim = input_dim
30 | self.hidden_dim = hidden_dim
31 |
32 | self.kernel_size = kernel_size
33 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
34 | self.bias = bias
35 |
36 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
37 | out_channels=4 * self.hidden_dim,
38 | kernel_size=self.kernel_size,
39 | padding=self.padding,
40 | bias=self.bias)
41 |
42 | def forward(self, input_tensor, cur_state):
43 |
44 | h_cur, c_cur = cur_state
45 |
46 | # concatenate along channel axis
47 | combined = torch.cat([input_tensor, h_cur], dim=1)
48 |
49 | combined_conv = self.conv(combined)
50 | cc_i, cc_f, cc_o, cc_g = torch.split(
51 | combined_conv, self.hidden_dim, dim=1)
52 | i = torch.sigmoid(cc_i)
53 | f = torch.sigmoid(cc_f)
54 | o = torch.sigmoid(cc_o)
55 | g = torch.tanh(cc_g)
56 |
57 | c_next = f * c_cur + i * g
58 | h_next = o * torch.tanh(c_next)
59 |
60 | return h_next, c_next
61 |
62 | def init_hidden(self, batch_size, tensor_size):
63 | height, width = tensor_size
64 | return (Variable(torch.zeros(batch_size, self.hidden_dim, height, width)).cuda(),
65 | Variable(torch.zeros(batch_size, self.hidden_dim, height, width)).cuda())
66 |
67 |
68 | class ConvLSTM(nn.Module):
69 |
70 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
71 | batch_first=False, bias=True, return_all_layers=False):
72 | super(ConvLSTM, self).__init__()
73 |
74 | self._check_kernel_size_consistency(kernel_size)
75 |
76 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
77 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
78 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
79 | if not len(kernel_size) == len(hidden_dim) == num_layers:
80 | raise ValueError('Inconsistent list length.')
81 |
82 | self.height, self.width = input_size
83 |
84 | self.input_dim = input_dim
85 | self.hidden_dim = hidden_dim
86 | self.kernel_size = kernel_size
87 | self.num_layers = num_layers
88 | self.batch_first = batch_first
89 | self.bias = bias
90 | self.return_all_layers = return_all_layers
91 |
92 | cell_list = []
93 | for i in range(0, self.num_layers):
94 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]
95 |
96 | cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
97 | input_dim=cur_input_dim,
98 | hidden_dim=self.hidden_dim[i],
99 | kernel_size=self.kernel_size[i],
100 | bias=self.bias))
101 |
102 | self.cell_list = nn.ModuleList(cell_list)
103 |
104 | def forward(self, input_tensor, hidden_state=None):
105 | """
106 |
107 | Parameters
108 | ----------
109 | input_tensor: todo
110 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
111 | hidden_state: todo
112 | None. todo implement stateful
113 |
114 | Returns
115 | -------
116 | last_state_list, layer_output
117 | """
118 | if not self.batch_first:
119 | # (t, b, c, h, w) -> (b, t, c, h, w)
120 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
121 |
122 | # Implement stateful ConvLSTM
123 | if hidden_state is not None:
124 | raise NotImplementedError()
125 | else:
126 | tensor_size = (input_tensor.size(3), input_tensor.size(4))
127 | hidden_state = self._init_hidden(
128 | batch_size=input_tensor.size(0), tensor_size=tensor_size)
129 |
130 | layer_output_list = []
131 | last_state_list = []
132 |
133 | seq_len = input_tensor.size(1)
134 | cur_layer_input = input_tensor
135 |
136 | for layer_idx in range(self.num_layers):
137 |
138 | h, c = hidden_state[layer_idx]
139 | output_inner = []
140 | for t in range(seq_len):
141 |
142 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
143 | cur_state=[h, c])
144 | output_inner.append(h)
145 |
146 | layer_output = torch.stack(output_inner, dim=1)
147 | cur_layer_input = layer_output
148 |
149 | layer_output_list.append(layer_output)
150 | last_state_list.append([h, c])
151 |
152 | if not self.return_all_layers:
153 | layer_output_list = layer_output_list[-1:]
154 | last_state_list = last_state_list[-1:]
155 |
156 | return layer_output_list, last_state_list
157 |
158 | def _init_hidden(self, batch_size, tensor_size):
159 | init_states = []
160 | for i in range(self.num_layers):
161 | init_states.append(
162 | self.cell_list[i].init_hidden(batch_size, tensor_size))
163 | return init_states
164 |
165 | @staticmethod
166 | def _check_kernel_size_consistency(kernel_size):
167 | if not (isinstance(kernel_size, tuple) or
168 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
169 | raise ValueError('`kernel_size` must be tuple or list of tuples')
170 |
171 | @staticmethod
172 | def _extend_for_multilayer(param, num_layers):
173 | if not isinstance(param, list):
174 | param = [param] * num_layers
175 | return param
176 |
177 |
178 | class ConvBLSTM(nn.Module):
179 | # Constructor
180 | def __init__(self, input_size, input_dim, hidden_dim,
181 | kernel_size, num_layers, batch_first=False, bias=True, return_all_layers=False):
182 |
183 | super(ConvBLSTM, self).__init__()
184 |         self.forward_net = ConvLSTM(input_size, input_dim, hidden_dim//2, kernel_size,
185 | num_layers, batch_first=batch_first, bias=bias,
186 | return_all_layers=return_all_layers)
187 |         self.reverse_net = ConvLSTM(input_size, input_dim, hidden_dim//2, kernel_size,
188 | num_layers, batch_first=batch_first, bias=bias,
189 | return_all_layers=return_all_layers)
190 |         self.return_all_layers = return_all_layers
191 | def forward(self, xforward, xreverse):
192 | """
193 | xforward, xreverse = B T C H W tensors.
194 | """
195 |
196 | y_out_fwd, _ = self.forward_net(xforward)
197 | y_out_rev, _ = self.reverse_net(xreverse)
198 |
199 | if not self.return_all_layers:
200 | # outputs of last CLSTM layer = B, T, C, H, W
201 | y_out_fwd = y_out_fwd[-1]
202 | # outputs of last CLSTM layer = B, T, C, H, W
203 | y_out_rev = y_out_rev[-1]
204 |
205 | reversed_idx = list(reversed(range(y_out_rev.shape[1])))
206 | # reverse temporal outputs.
207 | y_out_rev = y_out_rev[:, reversed_idx, ...]
208 | ycat = torch.cat((y_out_fwd, y_out_rev), dim=2)
209 |
210 | return ycat
211 |
--------------------------------------------------------------------------------
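A hedged usage sketch for the ConvLSTM above (shapes follow the forward() docstring; run from the codes/ directory, and note that init_hidden allocates its states on the GPU):

    import torch
    from models.modules.convlstm import ConvLSTM

    # batch_first=True -> input of shape (b, t, c, h, w)
    clstm = ConvLSTM(input_size=(32, 32), input_dim=16, hidden_dim=[32, 32],
                     kernel_size=[(3, 3), (3, 3)], num_layers=2,
                     batch_first=True, bias=True, return_all_layers=False).cuda()

    x = torch.randn(4, 5, 16, 32, 32).cuda()   # (b, t, c, h, w)
    layer_outputs, last_states = clstm(x)
    print(layer_outputs[0].shape)              # (4, 5, 32, 32, 32): last layer, all time steps
    h, c = last_states[0]
    print(h.shape, c.shape)                    # (4, 32, 32, 32) each
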
/codes/models/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as fnn
5 | from torch.autograd import Variable
6 |
7 |
8 | class CharbonnierLoss(nn.Module):
9 | """Charbonnier Loss (L1)"""
10 |
11 | def __init__(self, eps=1e-6):
12 | super(CharbonnierLoss, self).__init__()
13 | self.eps = eps
14 |
15 | def forward(self, x, y):
16 | diff = x - y
17 | loss = torch.sum(torch.sqrt(diff * diff + self.eps))
18 | return loss
19 |
20 |
21 | def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
22 | if size % 2 != 1:
23 |         raise ValueError("kernel size must be odd")
24 | grid = np.float32(np.mgrid[0:size, 0:size].T)
25 | def gaussian(x): return np.exp((x - size//2)**2/(-2*sigma**2))**2
26 | kernel = np.sum(gaussian(grid), axis=2)
27 | kernel /= np.sum(kernel)
28 | # repeat same kernel across depth dimension
29 | kernel = np.tile(kernel, (n_channels, 1, 1))
30 | # conv weight should be (out_channels, groups/in_channels, h, w),
31 | # and since we have depth-separable convolution we want the groups dimension to be 1
32 | kernel = torch.FloatTensor(kernel[:, None, :, :])
33 | if cuda:
34 | kernel = kernel.cuda()
35 | return Variable(kernel, requires_grad=False)
36 |
37 |
38 | def conv_gauss(img, kernel):
39 | """ convolve img with a gaussian kernel that has been built with build_gauss_kernel """
40 | n_channels, _, kw, kh = kernel.shape
41 | img = fnn.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
42 | return fnn.conv2d(img, kernel, groups=n_channels)
43 |
44 |
45 | def laplacian_pyramid(img, kernel, max_levels=5):
46 | current = img
47 | pyr = []
48 |
49 | for level in range(max_levels):
50 | filtered = conv_gauss(current, kernel)
51 | diff = current - filtered
52 | pyr.append(diff)
53 | current = fnn.avg_pool2d(filtered, 2)
54 |
55 | pyr.append(current)
56 | return pyr
57 |
58 |
59 | class LapLoss(nn.Module):
60 | def __init__(self, max_levels=5, k_size=5, sigma=2.0):
61 | super(LapLoss, self).__init__()
62 | self.max_levels = max_levels
63 | self.k_size = k_size
64 | self.sigma = sigma
65 | self._gauss_kernel = None
66 |
67 | def forward(self, input, target):
68 | # input shape :[B, N, C, H, W]
69 | if len(input.shape) == 5:
70 | B, N, C, H, W = input.size()
71 | input = input.view(-1, C, H, W)
72 | target = target.view(-1, C, H, W)
73 | if self._gauss_kernel is None or self._gauss_kernel.shape[1] != input.shape[1]:
74 | self._gauss_kernel = build_gauss_kernel(
75 | size=self.k_size, sigma=self.sigma,
76 | n_channels=input.shape[1], cuda=input.is_cuda
77 | )
78 | pyr_input = laplacian_pyramid(
79 | input, self._gauss_kernel, self.max_levels)
80 | pyr_target = laplacian_pyramid(
81 | target, self._gauss_kernel, self.max_levels)
82 | return sum(fnn.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
83 |
84 | # if __name__ == "__main__":
85 |
--------------------------------------------------------------------------------
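A hedged sketch exercising both losses above on dummy tensors (illustrative values; run from the codes/ directory):

    import torch
    from models.modules.loss import CharbonnierLoss, LapLoss

    pred = torch.rand(2, 3, 64, 64)
    gt = torch.rand(2, 3, 64, 64)

    cb = CharbonnierLoss(eps=1e-6)
    print(cb(pred, gt))   # sum over all elements of sqrt((pred - gt)^2 + eps)

    lap = LapLoss(max_levels=5, k_size=5, sigma=2.0)
    print(lap(pred, gt))  # sum of L1 distances between matching Laplacian pyramid levels
    # LapLoss also accepts 5-D input [B, N, C, H, W]; it is reshaped to [B*N, C, H, W] internally
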
/codes/models/modules/module_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def make_layer(block, n_layers):
28 | layers = []
29 | for _ in range(n_layers):
30 | layers.append(block())
31 | return nn.Sequential(*layers)
32 |
33 |
34 | class ResidualBlock_noBN(nn.Module):
35 | '''Residual block w/o BN
36 | ---Conv-ReLU-Conv-+-
37 | |________________|
38 | '''
39 |
40 | def __init__(self, nf=64):
41 | super(ResidualBlock_noBN, self).__init__()
42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
44 |
45 | # initialization
46 | initialize_weights([self.conv1, self.conv2], 0.1)
47 |
48 | def forward(self, x):
49 | identity = x
50 | out = F.relu(self.conv1(x), inplace=True)
51 | out = self.conv2(out)
52 | return identity + out
53 |
54 |
55 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
56 | """Warp an image or feature map with optical flow
57 | Args:
58 | x (Tensor): size (N, C, H, W)
59 | flow (Tensor): size (N, H, W, 2), normal value
60 | interp_mode (str): 'nearest' or 'bilinear'
61 | padding_mode (str): 'zeros' or 'border' or 'reflection'
62 |
63 | Returns:
64 | Tensor: warped image or feature map
65 | """
66 | print(x.size()[-2:])
67 | print(flow.size()[1:3])
68 | assert x.size()[-2:] == flow.size()[1:3]
69 | B, C, H, W = x.size()
70 | # mesh grid
71 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
72 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
73 | grid.requires_grad = False
74 | grid = grid.type_as(x)
75 | vgrid = grid + flow
76 | # scale grid to [-1,1]
77 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
78 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
79 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
80 | output = F.grid_sample(
81 | x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
82 | return output
83 |
--------------------------------------------------------------------------------
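A hedged sketch of flow_warp with an all-zero flow field, i.e. a (near-)identity warp (whether it is exact depends on grid_sample's align_corners default in the installed PyTorch version):

    import torch
    from models.modules.module_util import flow_warp

    x = torch.rand(1, 3, 16, 16)
    flow = torch.zeros(1, 16, 16, 2)   # (N, H, W, 2): zero displacement everywhere
    warped = flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros')
    print(warped.shape)                # (1, 3, 16, 16), same spatial size as the input
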
/codes/models/networks.py:
--------------------------------------------------------------------------------
1 | import models.modules.Sakuya_arch as Sakuya_arch
2 |
3 | ####################
4 | # define network
5 | ####################
6 | # Generator
7 |
8 |
9 | def define_G(opt):
10 | opt_net = opt['network_G']
11 | which_model = opt_net['which_model_G']
12 |
13 | if which_model == 'LunaTokis':
14 | netG = Sakuya_arch.LunaTokis(nf=opt_net['nf'], nframes=opt_net['nframes'],
15 | groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
16 | back_RBs=opt_net['back_RBs'])
17 | else:
18 | raise NotImplementedError(
19 | 'Generator model [{:s}] not recognized'.format(which_model))
20 |
21 | return netG
22 |
--------------------------------------------------------------------------------
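define_G only reads the network_G sub-dictionary of the options; a hedged sketch using the same values as the training YAML below:

    from models.networks import define_G

    opt = {'network_G': {'which_model_G': 'LunaTokis',
                         'nf': 64, 'nframes': 7, 'groups': 8,
                         'front_RBs': 5, 'back_RBs': 40}}
    netG = define_G(opt)
    print(type(netG))  # models.modules.Sakuya_arch.LunaTokis
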
/codes/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/codes/options/__init__.py
--------------------------------------------------------------------------------
/codes/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | if opt['distortion'] == 'sr' or opt['distortion'] == 'isr':
19 | scale = opt['scale']
20 |
21 | # datasets
22 | for phase, dataset in opt['datasets'].items():
23 | phase = phase.split('_')[0]
24 | dataset['phase'] = phase
25 | if opt['distortion'] == 'sr' or opt['distortion'] == 'isr':
26 | dataset['scale'] = scale
27 | is_lmdb = False
28 | if dataset.get('dataroot_GT', None) is not None:
29 | dataset['dataroot_GT'] = os.path.expanduser(dataset['dataroot_GT'])
30 | if dataset['dataroot_GT'].endswith('lmdb'):
31 | is_lmdb = True
32 | if dataset.get('dataroot_LQ', None) is not None:
33 | dataset['dataroot_LQ'] = os.path.expanduser(dataset['dataroot_LQ'])
34 | if dataset['dataroot_LQ'].endswith('lmdb'):
35 | is_lmdb = True
36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
37 | if dataset['mode'].endswith('mc'): # for memcached
38 | dataset['data_type'] = 'mc'
39 | dataset['mode'] = dataset['mode'].replace('_mc', '')
40 |
41 | # path
42 | for key, path in opt['path'].items():
43 | if path and key in opt['path'] and key != 'strict_load':
44 | opt['path'][key] = osp.expanduser(path)
45 | opt['path']['root'] = osp.abspath(
46 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
47 | if is_train:
48 | experiments_root = os.path.join(
49 | opt['path']['root'], 'experiments', opt['name'])
50 | opt['path']['experiments_root'] = experiments_root
51 | opt['path']['models'] = os.path.join(experiments_root, 'models')
52 | opt['path']['training_state'] = os.path.join(
53 | experiments_root, 'training_state')
54 | opt['path']['log'] = experiments_root
55 | opt['path']['val_images'] = os.path.join(
56 | experiments_root, 'val_images')
57 |
58 | # change some options for debug mode
59 | if 'debug' in opt['name']:
60 | opt['train']['val_freq'] = 8
61 | opt['logger']['print_freq'] = 1
62 | opt['logger']['save_checkpoint_freq'] = 8
63 | else: # test
64 | results_root = os.path.join(
65 | opt['path']['root'], 'results', opt['name'])
66 | opt['path']['results_root'] = results_root
67 | opt['path']['log'] = results_root
68 |
69 | # network
70 | if opt['distortion'] == 'sr' or opt['distortion'] == 'isr':
71 | opt['network_G']['scale'] = scale
72 |
73 | return opt
74 |
75 |
76 | def dict2str(opt, indent_l=1):
77 | '''dict to string for logger'''
78 | msg = ''
79 | for k, v in opt.items():
80 | if isinstance(v, dict):
81 | msg += ' ' * (indent_l * 2) + k + ':[\n'
82 | msg += dict2str(v, indent_l + 1)
83 | msg += ' ' * (indent_l * 2) + ']\n'
84 | else:
85 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
86 | return msg
87 |
88 |
89 | # convert to NoneDict, which returns None for missing keys.
90 | class NoneDict(dict):
91 | def __missing__(self, key):
92 | return None
93 |
94 |
95 | def dict_to_nonedict(opt):
96 | if isinstance(opt, dict):
97 | new_opt = dict()
98 | for key, sub_opt in opt.items():
99 | new_opt[key] = dict_to_nonedict(sub_opt)
100 | return NoneDict(**new_opt)
101 | elif isinstance(opt, list):
102 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
103 | else:
104 | return opt
105 |
106 |
107 | def check_resume(opt, resume_iter):
108 | '''Check resume states and pretrain_model paths'''
109 | logger = logging.getLogger('base')
110 | if opt['path']['resume_state']:
111 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
112 | 'pretrain_model_D', None) is not None:
113 | logger.warning(
114 | 'pretrain_model path will be ignored when resuming training.')
115 |
116 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
117 | '{}_G.pth'.format(resume_iter))
118 | logger.info('Set [pretrain_model_G] to ' +
119 | opt['path']['pretrain_model_G'])
120 | if 'gan' in opt['model']:
121 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
122 | '{}_D.pth'.format(resume_iter))
123 | logger.info('Set [pretrain_model_D] to ' +
124 | opt['path']['pretrain_model_D'])
125 |
--------------------------------------------------------------------------------
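dict_to_nonedict wraps the parsed options in NoneDict so that missing keys return None instead of raising KeyError; a small hedged illustration (run from the codes/ directory):

    from options.options import dict_to_nonedict, dict2str

    opt = dict_to_nonedict({'scale': 4, 'network_G': {'nf': 64}})
    print(opt['scale'])             # 4
    print(opt['network_G']['nf'])   # 64
    print(opt['missing_key'])       # None rather than a KeyError
    print(dict2str(opt))            # indented multi-line string for the logger
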
/codes/options/train/train_zsm.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: LunaTokis_scratch_b16p32f5b40n7l1_600k_Vimeo
3 | use_tb_logger: false #true
4 | model: VideoSR_base
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [0, 1, 2, 3]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: Vimeo7
13 | mode: Vimeo7
14 | interval_list: [1]
15 | random_reverse: true #false
16 | border_mode: false
17 | dataroot_GT: /data/datasets/SR/vimeo_septuplet/vimeo7_train_GT.lmdb
18 | dataroot_LQ: /data/datasets/SR/vimeo_septuplet/vimeo7_train_LR7.lmdb
19 | cache_keys: Vimeo7_train_keys.pkl
20 |
21 | N_frames: 7
22 | use_shuffle: true
23 | n_workers: 3 # per GPU
24 | batch_size: 16
25 | GT_size: 128
26 | LQ_size: 32
27 | use_flip: true
28 | use_rot: true
29 | color: RGB
30 |
31 | #### network structures
32 | network_G:
33 | which_model_G: LunaTokis
34 | nf: 64
35 | nframes: 7
36 | groups: 8
37 | front_RBs: 5
38 | mid_RBs: 0
39 | back_RBs: 40
40 | HR_in: false
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ~
45 | strict_load: false #true #
46 | resume_state: ~
47 |
48 | #### training settings: learning rate scheme, loss
49 | train:
50 | lr_G: !!float 4e-4
51 | lr_scheme: CosineAnnealingLR_Restart
52 | beta1: 0.9
53 | beta2: 0.99
54 | niter: 600000
55 | warmup_iter: -1 #4000 # -1: no warm up
56 | T_period: [150000, 150000, 150000, 150000]
57 | restarts: [150000, 300000, 450000]
58 | restart_weights: [1, 1, 1]
59 | eta_min: !!float 1e-7
60 |
61 | pixel_criterion: cb
62 | pixel_weight: 1.0
63 | val_freq: !!float 5e3
64 |
65 | manual_seed: 0
66 |
67 | #### logger
68 | logger:
69 | print_freq: 100
70 | save_checkpoint_freq: !!float 5e3
71 |
--------------------------------------------------------------------------------
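The train block above uses cosine annealing with warm restarts: the learning rate decays from lr_G to eta_min over each 150k-iteration T_period and jumps back to the full lr_G at each restart (restart_weights are all 1), covering niter = 600k in four cycles. A hedged sketch of the generic formula those fields parameterize (the repo's actual scheduler implementation may differ in detail):

    import math

    def cosine_restart_lr(step, lr_max=4e-4, eta_min=1e-7, period=150000):
        # generic cosine annealing with hard restarts every `period` iterations
        t = step % period
        return eta_min + 0.5 * (lr_max - eta_min) * (1 + math.cos(math.pi * t / period))

    for s in (0, 75000, 149999, 150000, 599999):
        print(s, cosine_restart_lr(s))
    # 0      -> ~4e-4  (start of the first cycle)
    # 75000  -> ~2e-4  (halfway through a cycle)
    # 149999 -> ~1e-7  (eta_min at the end of a cycle)
    # 150000 -> ~4e-4  (restart; restart_weight = 1 restores the full rate)
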
/codes/test.py:
--------------------------------------------------------------------------------
1 | '''
2 | test Zooming Slow-Mo models on arbitrary datasets
3 | write to txt log file
4 | [kosame] TODO: update the test script to the newest version
5 | '''
6 |
7 | import os
8 | import os.path as osp
9 | import glob
10 | import logging
11 | import numpy as np
12 | import cv2
13 | import torch
14 |
15 | import utils.util as util
16 | import data.util as data_util
17 | import models.modules.Sakuya_arch as Sakuya_arch
18 |
19 | def main():
20 | scale = 4
21 | N_ot = 7 #3
22 | N_in = 1+ N_ot // 2
23 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
24 |
25 | #### model
26 | #### TODO: change your model path here
27 | model_path = '../experiments/pretrained_models/xiang2020zooming.pth'
28 | model = Sakuya_arch.LunaTokis(64, N_ot, 8, 5, 40)
29 |
30 | #### dataset
31 | data_mode = 'Custom' #'Vid4' #'SPMC'#'Middlebury'#
32 |
33 | if data_mode == 'Vid4':
34 | test_dataset_folder = '/data/xiang/SR/Vid4/LR/*'
35 | if data_mode == 'SPMC':
36 | test_dataset_folder = '/data/xiang/SR/spmc/*'
37 | if data_mode == 'Custom':
38 | test_dataset_folder = '../test_example/*' # TODO: put your own data path here
39 |
40 | #### evaluation
41 | flip_test = False #True#
42 | crop_border = 0
43 |
44 | # temporal padding mode
45 | padding = 'replicate'
46 | save_imgs = False #True#
47 | if 'Custom' in data_mode: save_imgs = True
48 | ############################################################################
49 | if torch.cuda.is_available():
50 | device = torch.device('cuda')
51 | else:
52 | device = torch.device('cpu')
53 | save_folder = '../results/{}'.format(data_mode)
54 | util.mkdirs(save_folder)
55 | util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True)
56 | logger = logging.getLogger('base')
57 | model_params = util.get_model_total_params(model)
58 |
59 | #### log info
60 | logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
61 | logger.info('Padding mode: {}'.format(padding))
62 | logger.info('Model path: {}'.format(model_path))
63 | logger.info('Model parameters: {} M'.format(model_params))
64 | logger.info('Save images: {}'.format(save_imgs))
65 | logger.info('Flip Test: {}'.format(flip_test))
66 |
67 |
68 | def single_forward(model, imgs_in):
69 | with torch.no_grad():
70 | # imgs_in.size(): [1,n,3,h,w]
71 | b,n,c,h,w = imgs_in.size()
72 | h_n = int(4*np.ceil(h/4))
73 | w_n = int(4*np.ceil(w/4))
74 | imgs_temp = imgs_in.new_zeros(b,n,c,h_n,w_n)
75 | imgs_temp[:,:,:,0:h,0:w] = imgs_in
76 |
77 | model_output = model(imgs_temp)
78 | # model_output.size(): torch.Size([1, 3, 4h, 4w])
79 | model_output = model_output[:, :, :, 0:scale*h, 0:scale*w]
80 | if isinstance(model_output, list) or isinstance(model_output, tuple):
81 | output = model_output[0]
82 | else:
83 | output = model_output
84 | return output
85 |
86 | sub_folder_l = sorted(glob.glob(test_dataset_folder))
87 |
88 | model.load_state_dict(torch.load(model_path), strict=True)
89 |
90 | model.eval()
91 | model = model.to(device)
92 |
93 | avg_psnr_l = []
94 | avg_psnr_y_l = []
95 | sub_folder_name_l = []
96 | # total_time = []
97 | # for each sub-folder
98 | for sub_folder in sub_folder_l:
99 | gt_tested_list = []
100 | sub_folder_name = sub_folder.split('/')[-1]
101 | sub_folder_name_l.append(sub_folder_name)
102 | save_sub_folder = osp.join(save_folder, sub_folder_name)
103 |
104 | if data_mode == 'SPMC':
105 | sub_folder = sub_folder + '/LR/'
106 | img_LR_l = sorted(glob.glob(sub_folder + '/*'))
107 |
108 | if save_imgs:
109 | util.mkdirs(save_sub_folder)
110 |
111 | #### read LR images
112 | imgs = util.read_seq_imgs(sub_folder)
113 | #### read GT images
114 | img_GT_l = []
115 | if data_mode == 'SPMC':
116 | sub_folder_GT = osp.join(sub_folder.replace('/LR/', '/truth/'))
117 | else:
118 | sub_folder_GT = osp.join(sub_folder.replace('/LR/', '/HR/'))
119 |
120 | if 'Custom' not in data_mode:
121 | for img_GT_path in sorted(glob.glob(osp.join(sub_folder_GT,'*'))):
122 | img_GT_l.append(util.read_image(img_GT_path))
123 |
124 | avg_psnr, avg_psnr_sum, cal_n = 0,0,0
125 | avg_psnr_y, avg_psnr_sum_y = 0,0
126 |
127 | if len(img_LR_l) == len(img_GT_l):
128 | skip = True
129 | else:
130 | skip = False
131 |
132 | if 'Custom' in data_mode:
133 | select_idx_list = util.test_index_generation(False, N_ot, len(img_LR_l))
134 | else:
135 | select_idx_list = util.test_index_generation(skip, N_ot, len(img_LR_l))
136 | # process each image
137 | for select_idxs in select_idx_list:
138 | # get input images
139 | select_idx = select_idxs[0]
140 | gt_idx = select_idxs[1]
141 | imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device)
142 |
143 | output = single_forward(model, imgs_in)
144 |
145 | outputs = output.data.float().cpu().squeeze(0)
146 |
147 | if flip_test:
148 | # flip W
149 | output = single_forward(model, torch.flip(imgs_in, (-1, )))
150 | output = torch.flip(output, (-1, ))
151 | output = output.data.float().cpu().squeeze(0)
152 | outputs = outputs + output
153 | # flip H
154 | output = single_forward(model, torch.flip(imgs_in, (-2, )))
155 | output = torch.flip(output, (-2, ))
156 | output = output.data.float().cpu().squeeze(0)
157 | outputs = outputs + output
158 | # flip both H and W
159 | output = single_forward(model, torch.flip(imgs_in, (-2, -1)))
160 | output = torch.flip(output, (-2, -1))
161 | output = output.data.float().cpu().squeeze(0)
162 | outputs = outputs + output
163 |
164 | outputs = outputs / 4
165 |
166 | # save imgs
167 | for idx, name_idx in enumerate(gt_idx):
168 | if name_idx in gt_tested_list:
169 | continue
170 | gt_tested_list.append(name_idx)
171 | output_f = outputs[idx,:,:,:].squeeze(0)
172 |
173 | output = util.tensor2img(output_f)
174 | if save_imgs:
175 | cv2.imwrite(osp.join(save_sub_folder, '{:08d}.png'.format(name_idx+1)), output)
176 |
177 | if 'Custom' not in data_mode:
178 | #### calculate PSNR
179 | output = output / 255.
180 |
181 | GT = np.copy(img_GT_l[name_idx])
182 |
183 | if crop_border == 0:
184 | cropped_output = output
185 | cropped_GT = GT
186 | else:
187 | cropped_output = output[crop_border:-crop_border, crop_border:-crop_border, :]
188 | cropped_GT = GT[crop_border:-crop_border, crop_border:-crop_border, :]
189 | crt_psnr = util.calculate_psnr(cropped_output * 255, cropped_GT * 255)
190 | cropped_GT_y = data_util.bgr2ycbcr(cropped_GT, only_y=True)
191 | cropped_output_y = data_util.bgr2ycbcr(cropped_output, only_y=True)
192 | crt_psnr_y = util.calculate_psnr(cropped_output_y * 255, cropped_GT_y * 255)
193 | logger.info('{:3d} - {:25}.png \tPSNR: {:.6f} dB PSNR-Y: {:.6f} dB'.format(name_idx + 1, name_idx+1, crt_psnr, crt_psnr_y))
194 | avg_psnr_sum += crt_psnr
195 | avg_psnr_sum_y += crt_psnr_y
196 | cal_n += 1
197 |
198 | if 'Custom' not in data_mode:
199 | avg_psnr = avg_psnr_sum / cal_n
200 | avg_psnr_y = avg_psnr_sum_y / cal_n
201 |
202 | logger.info('Folder {} - Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB for {} frames; '.format(sub_folder_name, avg_psnr, avg_psnr_y, cal_n))
203 |
204 | avg_psnr_l.append(avg_psnr)
205 | avg_psnr_y_l.append(avg_psnr_y)
206 |
207 | if 'Custom' not in data_mode:
208 | logger.info('################ Tidy Outputs ################')
209 | for name, psnr, psnr_y in zip(sub_folder_name_l, avg_psnr_l, avg_psnr_y_l):
210 | logger.info('Folder {} - Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB. '
211 | .format(name, psnr, psnr_y))
212 | logger.info('################ Final Results ################')
213 | logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
214 | logger.info('Padding mode: {}'.format(padding))
215 | logger.info('Model path: {}'.format(model_path))
216 | logger.info('Save images: {}'.format(save_imgs))
217 | logger.info('Flip Test: {}'.format(flip_test))
218 | logger.info('Total Average PSNR: {:.6f} dB PSNR-Y: {:.6f} dB for {} clips. '
219 | .format(
220 | sum(avg_psnr_l) / len(avg_psnr_l), sum(avg_psnr_y_l) / len(avg_psnr_y_l), len(sub_folder_l)))
221 | # logger.info('Total Runtime: {:.6f} s Average Runtime: {:.6f} for {} images.'
222 | # .format(sum(total_time), sum(total_time)/171, 171))
223 |
224 | if __name__ == '__main__':
225 | main()
--------------------------------------------------------------------------------
/codes/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 | from data.data_sampler import DistIterSampler
11 |
12 | import options.options as option
13 | from utils import util
14 | from data import create_dataloader, create_dataset
15 | from models import create_model
16 |
17 |
18 | def init_dist(backend='nccl', **kwargs):
19 | ''' initialization for distributed training'''
20 | # if mp.get_start_method(allow_none=True) is None:
21 | if mp.get_start_method(allow_none=True) != 'spawn':
22 | mp.set_start_method('spawn')
23 | rank = int(os.environ['RANK'])
24 | num_gpus = torch.cuda.device_count()
25 | torch.cuda.set_device(rank % num_gpus)
26 | dist.init_process_group(backend=backend, **kwargs)
27 |
28 |
29 | def main():
30 | #### options
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
33 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
34 | help='job launcher')
35 | parser.add_argument('--local_rank', type=int, default=0)
36 | args = parser.parse_args()
37 | opt = option.parse(args.opt, is_train=True)
38 |
39 | #### distributed training settings
40 | if args.launcher == 'none': # disabled distributed training
41 | opt['dist'] = False
42 | rank = -1
43 | print('Disabled distributed training.')
44 | else:
45 | opt['dist'] = True
46 | init_dist()
47 | world_size = torch.distributed.get_world_size()
48 | rank = torch.distributed.get_rank()
49 |
50 | #### loading resume state if exists
51 | if opt['path'].get('resume_state', None):
52 | # distributed resuming: all load into default GPU
53 | device_id = torch.cuda.current_device()
54 | resume_state = torch.load(opt['path']['resume_state'],
55 | map_location=lambda storage, loc: storage.cuda(device_id))
56 | option.check_resume(opt, resume_state['iter']) # check resume options
57 | else:
58 | resume_state = None
59 |
60 | #### mkdir and loggers
61 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
62 | if resume_state is None:
63 | util.mkdir_and_rename(
64 | opt['path']['experiments_root']) # rename experiment folder if exists
65 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
66 | and 'pretrain_model' not in key and 'resume' not in key))
67 |
68 | # config loggers. Before it, the log will not work
69 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
70 | screen=True, tofile=True)
71 | logger = logging.getLogger('base')
72 | logger.info(option.dict2str(opt))
73 | # tensorboard logger
74 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
75 | version = float(torch.__version__[0:3])
76 | if version >= 1.1: # PyTorch 1.1
77 | from torch.utils.tensorboard import SummaryWriter
78 | else:
79 | logger.info(
80 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
81 | from tensorboardX import SummaryWriter
82 | tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
83 | else:
84 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
85 | logger = logging.getLogger('base')
86 |
87 | # convert to NoneDict, which returns None for missing keys
88 | opt = option.dict_to_nonedict(opt)
89 |
90 | #### random seed
91 | seed = opt['train']['manual_seed']
92 | if seed is None:
93 | seed = random.randint(1, 10000)
94 | if rank <= 0:
95 | logger.info('Random seed: {}'.format(seed))
96 | util.set_random_seed(seed)
97 |
98 |     torch.backends.cudnn.benchmark = True
99 | # torch.backends.cudnn.deterministic = True
100 |
101 | #### create train and val dataloader
102 | dataset_ratio = 200 # enlarge the size of each epoch
103 | for phase, dataset_opt in opt['datasets'].items():
104 | if phase == 'train':
105 | train_set = create_dataset(dataset_opt)
106 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
107 | total_iters = int(opt['train']['niter'])
108 | total_epochs = int(math.ceil(total_iters / train_size))
109 | if opt['dist']:
110 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
111 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
112 | else:
113 | train_sampler = None
114 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
115 | if rank <= 0:
116 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
117 | len(train_set), train_size))
118 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
119 | total_epochs, total_iters))
120 | elif phase == 'val':
121 | pass
122 | '''
123 | val_set = create_dataset(dataset_opt)
124 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
125 | if rank <= 0:
126 | logger.info('Number of val images in [{:s}]: {:d}'.format(
127 | dataset_opt['name'], len(val_set)))
128 | '''
129 | else:
130 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
131 | assert train_loader is not None
132 |
133 | #### create model
134 | model = create_model(opt)
135 |
136 | #### resume training
137 | if resume_state:
138 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
139 | resume_state['epoch'], resume_state['iter']))
140 |
141 | start_epoch = resume_state['epoch']
142 | current_step = resume_state['iter']
143 | model.resume_training(resume_state) # handle optimizers and schedulers
144 | else:
145 | current_step = 0
146 | start_epoch = 0
147 |
148 | #### training
149 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
150 | for epoch in range(start_epoch, total_epochs + 1):
151 | if opt['dist']:
152 | train_sampler.set_epoch(epoch)
153 | for _, train_data in enumerate(train_loader):
154 | current_step += 1
155 | if current_step > total_iters:
156 | break
157 | #### update learning rate
158 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
159 |
160 | #### training
161 | model.feed_data(train_data)
162 | model.optimize_parameters(current_step)
163 |
164 | #### log
165 | if current_step % opt['logger']['print_freq'] == 0:
166 | logs = model.get_current_log()
167 | message = ''
171 | for k, v in logs.items():
172 | message += '{:s}: {:.4e} '.format(k, v)
173 | # tensorboard logger
174 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
175 | if rank <= 0:
176 | tb_logger.add_scalar(k, v, current_step)
177 | if rank <= 0:
178 | logger.info(message)
179 | #### validation
180 | # currently, it does not support validation during training
181 |
182 | #### save models and training states
183 | if current_step % opt['logger']['save_checkpoint_freq'] == 0:
184 | if rank <= 0:
185 | logger.info('Saving models and training states.')
186 | model.save(current_step)
187 | model.save_training_state(epoch, current_step)
188 |
189 | if rank <= 0:
190 | logger.info('Saving the final model.')
191 | model.save('latest')
192 | logger.info('End of training.')
193 |
194 |
195 | if __name__ == '__main__':
196 | main()
197 |
--------------------------------------------------------------------------------
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/codes/utils/__init__.py
--------------------------------------------------------------------------------
/codes/utils/make_video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | from os.path import isfile, join
5 | import re
6 |
7 | # extract frames from a video
8 |
9 |
10 | def extract_frames(pathIn, pathOut, cnt=0):
11 | vidcap = cv2.VideoCapture(pathIn)
12 | success, image = vidcap.read()
13 | print('Start to extract frames from {}.'.format(os.path.basename(pathIn)))
14 | while success:
15 | cv2.imwrite(join(pathOut, "{:06d}.png".format(cnt)), image)
16 | success, image = vidcap.read()
17 | cnt += 1
18 | print('Successfully extracted {} frames from {}.'.format(
19 | cnt, os.path.basename(pathIn)))
20 |
21 | # combine frames to a video
22 |
23 |
24 | def combine_frames(pathIn, pathOut, fps):
25 | frame_array = []
26 | files = [f for f in os.listdir(pathIn) if isfile(join(pathIn, f))]
27 | # for sorting the file names properly
28 | files.sort(key=lambda x: int(re.search(r'\d+', x).group()))
29 | for i in range(len(files)):
30 | filename = join(pathIn, files[i])
31 | # read each file
32 | img = cv2.imread(filename)
33 | height, width, layers = img.shape
34 | size = (width, height)
35 | print(filename)
36 | # inserting the frames into an image array
37 | frame_array.append(img)
38 | out = cv2.VideoWriter(pathOut, cv2.VideoWriter_fourcc(*'DIVX'), fps, size)
39 | for i in range(len(frame_array)):
40 | # write each frame to the output video
41 | out.write(frame_array[i])
42 | out.release()
43 |
44 |
45 | if __name__ == "__main__":
46 | pathIn = 'demo_vid/'
47 | pathOut = 'out.mp4'
48 | fps = 29.98
49 | combine_frames(pathIn, pathOut, fps)
50 |
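51 | # Going the other way (video -> frames) uses extract_frames from this module; a hypothetical
52 | # call with placeholder paths is sketched below (kept commented out so the script's behaviour is unchanged):
53 | # extract_frames('input.mp4', 'demo_vid/')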
--------------------------------------------------------------------------------
/codes/utils/util.py:
--------------------------------------------------------------------------------
1 | # this code is modified from https://github.com/xinntao/EDVR/blob/master/codes/utils/util.py
2 | import os
3 | import sys
4 | import time
5 | import math
6 | import torch.nn.functional as F
7 | from datetime import datetime
8 | import random
9 | import logging
10 | from collections import OrderedDict
11 | import numpy as np
12 | import cv2
13 | import torch
14 | from torchvision.utils import make_grid
15 | from shutil import get_terminal_size
16 | import glob
17 | import re
18 |
19 | import yaml
20 | try:
21 | from yaml import CLoader as Loader, CDumper as Dumper
22 | except ImportError:
23 | from yaml import Loader, Dumper
24 |
25 |
26 | def OrderedYaml():
27 | '''yaml orderedDict support'''
28 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
29 |
30 | def dict_representer(dumper, data):
31 | return dumper.represent_dict(data.items())
32 |
33 | def dict_constructor(loader, node):
34 | return OrderedDict(loader.construct_pairs(node))
35 |
36 | Dumper.add_representer(OrderedDict, dict_representer)
37 | Loader.add_constructor(_mapping_tag, dict_constructor)
38 | return Loader, Dumper
39 |
40 |
41 | ####################
42 | # miscellaneous
43 | ####################
44 | def get_model_total_params(model):
45 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
46 | params = sum([np.prod(p.size()) for p in model_parameters])
47 | return (1.0*params/(1000*1000))  # number of trainable parameters, in millions
48 |
49 |
50 | def get_timestamp():
51 | return datetime.now().strftime('%y%m%d-%H%M%S')
52 |
53 |
54 | def mkdir(path):
55 | if not os.path.exists(path):
56 | os.makedirs(path)
57 |
58 |
59 | def mkdirs(paths):
60 | if isinstance(paths, str):
61 | mkdir(paths)
62 | else:
63 | for path in paths:
64 | mkdir(path)
65 |
66 |
67 | def mkdir_and_rename(path):
68 | if os.path.exists(path):
69 | new_name = path + '_archived_' + get_timestamp()
70 | print('Path already exists. Renaming it to [{:s}]'.format(new_name))
71 | logger = logging.getLogger('base')
72 | logger.info(
73 | 'Path already exists. Renaming it to [{:s}]'.format(new_name))
74 | os.rename(path, new_name)
75 | os.makedirs(path)
76 |
77 |
78 | def set_random_seed(seed):
79 | random.seed(seed)
80 | np.random.seed(seed)
81 | torch.manual_seed(seed)
82 | torch.cuda.manual_seed_all(seed)
83 |
84 |
85 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
86 | '''set up logger'''
87 | lg = logging.getLogger(logger_name)
88 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
89 | datefmt='%y-%m-%d %H:%M:%S')
90 | lg.setLevel(level)
91 | if tofile:
92 | log_file = os.path.join(
93 | root, phase + '_{}.log'.format(get_timestamp()))
94 | fh = logging.FileHandler(log_file, mode='w')
95 | fh.setFormatter(formatter)
96 | lg.addHandler(fh)
97 | if screen:
98 | sh = logging.StreamHandler()
99 | sh.setFormatter(formatter)
100 | lg.addHandler(sh)
101 |
102 |
103 | ####################
104 | # image convert
105 | ####################
106 |
107 |
108 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
109 | '''
110 | Converts a torch Tensor into an image Numpy array
111 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
112 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
113 | '''
114 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
115 | tensor = (tensor - min_max[0]) / \
116 | (min_max[1] - min_max[0]) # to range [0,1]
117 | n_dim = tensor.dim()
118 | if n_dim == 4:
119 | n_img = len(tensor)
120 | img_np = make_grid(tensor, nrow=int(
121 | math.sqrt(n_img)), normalize=False).numpy()
122 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
123 | elif n_dim == 3:
124 | img_np = tensor.numpy()
125 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
126 | elif n_dim == 2:
127 | img_np = tensor.numpy()
128 | else:
129 | raise TypeError(
130 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
131 | if out_type == np.uint8:
132 | img_np = (img_np * 255.0).round()
133 | # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
134 | return img_np.astype(out_type)
135 |
136 |
137 | def save_img(img, img_path, mode='RGB'):
138 | cv2.imwrite(img_path, img)  # img is expected in HWC, BGR order (OpenCV convention); the 'mode' argument is unused
139 |
140 | ####################
141 | # metric
142 | ####################
143 |
144 |
145 | def calculate_psnr(img1, img2):
146 | # img1 and img2 have range [0, 255]
147 | img1 = img1.astype(np.float64)
148 | img2 = img2.astype(np.float64)
149 | # print(img1)
150 | # print('img1-2')
151 | # print(img2)
152 | mse = np.mean((img1 - img2)**2)
153 | # print(mse)
154 | if mse == 0:
155 | return float('inf')
156 | return 20 * math.log10(255.0 / math.sqrt(mse))
157 |
158 |
159 | def ssim(img1, img2):
160 | C1 = (0.01 * 255)**2
161 | C2 = (0.03 * 255)**2
162 |
163 | img1 = img1.astype(np.float64)
164 | img2 = img2.astype(np.float64)
165 | kernel = cv2.getGaussianKernel(11, 1.5)
166 | window = np.outer(kernel, kernel.transpose())
167 |
168 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
169 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
170 | mu1_sq = mu1**2
171 | mu2_sq = mu2**2
172 | mu1_mu2 = mu1 * mu2
173 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
174 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
175 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
176 |
177 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
178 | (sigma1_sq + sigma2_sq + C2))
179 | return ssim_map.mean()
180 |
181 |
182 | def calculate_ssim(img1, img2):
183 | '''calculate SSIM
184 | the same outputs as MATLAB's
185 | img1, img2: [0, 255]
186 | '''
187 | if not img1.shape == img2.shape:
188 | raise ValueError('Input images must have the same dimensions.')
189 | if img1.ndim == 2:
190 | return ssim(img1, img2)
191 | elif img1.ndim == 3:
192 | if img1.shape[2] == 3:
193 | ssims = []
194 | for i in range(3):
195 | ssims.append(ssim(img1[..., i], img2[..., i]))  # per-channel SSIM
196 | return np.array(ssims).mean()
197 | elif img1.shape[2] == 1:
198 | return ssim(np.squeeze(img1), np.squeeze(img2))
199 | else:
200 | raise ValueError('Wrong input image dimensions.')
201 |
202 |
203 | class ProgressBar(object):
204 | '''A progress bar which can print the progress
205 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
206 | '''
207 |
208 | def __init__(self, task_num=0, bar_width=50, start=True):
209 | self.task_num = task_num
210 | max_bar_width = self._get_max_bar_width()
211 | self.bar_width = (bar_width if bar_width <=
212 | max_bar_width else max_bar_width)
213 | self.completed = 0
214 | if start:
215 | self.start()
216 |
217 | def _get_max_bar_width(self):
218 | terminal_width, _ = get_terminal_size()
219 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
220 | if max_bar_width < 10:
221 | print('terminal width is too small ({}), please consider widening the terminal for better '
222 | 'progressbar visualization'.format(terminal_width))
223 | max_bar_width = 10
224 | return max_bar_width
225 |
226 | def start(self):
227 | if self.task_num > 0:
228 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
229 | ' ' * self.bar_width, self.task_num, 'Start...'))
230 | else:
231 | sys.stdout.write('completed: 0, elapsed: 0s')
232 | sys.stdout.flush()
233 | self.start_time = time.time()
234 |
235 | def update(self, msg='In progress...'):
236 | self.completed += 1
237 | elapsed = time.time() - self.start_time
238 | fps = self.completed / elapsed
239 | if self.task_num > 0:
240 | percentage = self.completed / float(self.task_num)
241 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
242 | mark_width = int(self.bar_width * percentage)
243 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
244 | sys.stdout.write('\033[2F') # cursor up 2 lines
245 | # clean the output (remove extra chars since last display)
246 | sys.stdout.write('\033[J')
247 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
248 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
249 | else:
250 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
251 | self.completed, int(elapsed + 0.5), fps))
252 | sys.stdout.flush()
253 |
254 | ####################
255 | # read image
256 | ####################
257 |
258 |
259 | def read_image(img_path):
260 | '''read one image from img_path
261 | Return img: HWC, BGR, [0,1], numpy
262 | '''
263 | img_GT = cv2.imread(img_path)
264 | img = img_GT.astype(np.float32) / 255.
265 | return img
266 |
267 |
268 | def read_seq_imgs(img_seq_path):
269 | '''read a sequence of images'''
270 | img_path_l = glob.glob(img_seq_path + '/*')
271 | img_path_l.sort(key=lambda x: int(
272 | re.search(r'\d+', os.path.basename(x)).group()))
273 | return read_seq_imgs_by_list(img_path_l)
274 |
275 |
276 | def read_seq_imgs_by_list(img_path_l):
277 | '''read a sequence of images from the given list'''
278 | img_l = [read_image(v) for v in img_path_l]
279 | # stack to TCHW, RGB, [0,1], torch
280 | imgs = np.stack(img_l, axis=0)
281 | imgs = imgs[:, :, :, [2, 1, 0]]
282 | imgs = torch.from_numpy(np.ascontiguousarray(
283 | np.transpose(imgs, (0, 3, 1, 2)))).float()
284 | return imgs
285 |
286 |
287 | def test_index_generation(skip, N_out, len_in):
288 | '''
289 | params:
290 | skip: whether the low-frame-rate input skips every other frame;
291 | N_out: number of output frames produced by the network per forward pass;
292 | len_in: total number of input frames
293 |
294 | example:
295 | len_in | N_out | times | (no skip) | (skip)
296 | 5 | 3 | 4/2 | [0,1], [1,2], [2,3], [3,4] | [0,2],[2,4]
297 | 7 | 3 | 5/3 | [0,1],[1,2][2,3]...[5,6] | [0,2],[2,4],[4,6]
298 | 5 | 5 | 2/1 | [0,1,2] [2,3,4] | [0,2,4]
299 | '''
300 | # number of input frames for the network
301 | N_in = 1 + N_out // 2
302 | # input length should be enough to generate the output frames
303 | assert N_in <= len_in
304 |
305 | sele_list = []
306 | if skip:
307 | right = N_out # init
308 | while (right <= len_in):
309 | h_list = [right-N_out+x for x in range(N_out)]
310 | l_list = h_list[::2]
311 | right += (N_out - 1)
312 | sele_list.append([l_list, h_list])
313 | else:
314 | right = N_out # init
315 | right_in = N_in
316 | while (right_in <= len_in):
317 | h_list = [right-N_out+x for x in range(N_out)]
318 | l_list = [right_in-N_in+x for x in range(N_in)]
319 | right += (N_out - 1)
320 | right_in += (N_in - 1)
321 | sele_list.append([l_list, h_list])
322 | # check if it covers the last image, if not, we should cover it
323 | if (skip) and (right < len_in - 1):
324 | h_list = [len_in - N_out + x for x in range(N_out)]
325 | l_list = h_list[::2]
326 | sele_list.append([l_list, h_list])
327 | if (not skip) and (right_in < len_in - 1):
328 | right = len_in * 2 - 1
329 | h_list = [right-N_out+x for x in range(N_out)]
330 | l_list = [len_in - N_in + x for x in range(N_in)]
331 | sele_list.append([l_list, h_list])
332 | return sele_list
333 |
334 |
335 | ####################
336 | # video
337 | ####################
338 |
339 | def extract_frames(ffmpeg_dir, video, outDir):
340 | """
341 | Converts the `video` to images.
342 | Parameters
343 | ----------
344 | video : string
345 | full path to the video file.
346 | outDir : string
347 | path to directory to output the extracted images.
348 | Returns
349 | -------
350 | error : string
351 | Error message if error occurs otherwise blank string.
352 | """
353 |
354 | error = ""
355 | print('{} -i {} -vsync 0 {}/%06d.png'.format(os.path.join(ffmpeg_dir,
356 | "ffmpeg"), video, outDir))
357 | retn = os.system('{} -i "{}" -vsync 0 {}/%06d.png'.format(
358 | os.path.join(ffmpeg_dir, "ffmpeg"), video, outDir))
359 | if retn:
360 | error = "Error converting file:{}. Exiting.".format(video)
361 | return error
362 |
363 |
364 | def create_video(ffmpeg_dir, dir, output, fps):
365 | error = ""
366 | print('{} -r {} -f image2 -i {}/%06d.png {}'.format(os.path.join(ffmpeg_dir,
367 | "ffmpeg"), fps, dir, output))
368 | retn = os.system('{} -r {} -f image2 -i {}/%06d.png {}'.format(
369 | os.path.join(ffmpeg_dir, "ffmpeg"), fps, dir, output))
370 | if retn:
371 | error = "Error creating output video. Exiting."
372 | return error
373 |
374 |
375 | # combine frames to a video
376 | def combine_frames(pathIn, pathOut, fps):
377 | frame_array = []
378 | files = [f for f in os.listdir(
379 | pathIn) if os.path.isfile(os.path.join(pathIn, f))]
380 | # for sorting the file names properly
381 | files.sort(key=lambda x: int(re.search(r'\d+', x).group()))
382 | for i in range(len(files)):
383 | filename = os.path.join(pathIn, files[i])
384 | # read each file
385 | img = cv2.imread(filename)
386 | height, width, layers = img.shape
387 | size = (width, height)
388 | # inserting the frames into an image array
389 | frame_array.append(img)
390 | out = cv2.VideoWriter(pathOut, cv2.VideoWriter_fourcc(*'DIVX'), fps, size)
391 | for i in range(len(frame_array)):
392 | # write each frame to the output video
393 | out.write(frame_array[i])
394 | out.release()
395 |
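396 | # Illustrative (commented-out) usage of test_index_generation: index groups for a 7-frame
397 | # input with N_out=3, without and with frame skipping. The argument values are placeholders.
398 | # print(test_index_generation(False, 3, 7))
399 | # print(test_index_generation(True, 3, 7))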
--------------------------------------------------------------------------------
/codes/video_to_zsm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import os.path as osp
4 | import glob
5 | import logging
6 | import numpy as np
7 | import cv2
8 | import torch
9 | import re
10 |
11 | import utils.util as util
12 | import data.util as data_util
13 | import models.modules.Sakuya_arch as Sakuya_arch
14 |
15 | import argparse
16 | from shutil import rmtree
17 |
18 | # For parsing commandline arguments
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--ffmpeg_dir", type=str, default="",
21 | help='path to ffmpeg.exe')
22 | parser.add_argument("--video", type=str, required=True,
23 | help='path of video to be converted')
24 | parser.add_argument("--model", type=str, required=True,
25 | help='path of pretrained model')
26 | parser.add_argument("--fps", type=float, default=24,
27 | help='specify fps of output video. Default: 24.')
28 | parser.add_argument("--N_out", type=int, default=3,
29 | help='Number of output frames the network produces per forward pass (3, 5, or 7); larger values need more CPU/GPU memory. Default: 3.')
30 | parser.add_argument("--output", type=str, default="output.mp4",
31 | help='Specify output file name. Default: output.mp4')
32 | args = parser.parse_args()
33 |
34 |
35 | def check():
36 | """
37 | Checks the validity of commandline arguments.
38 | Parameters
39 | ----------
40 | None
41 | Returns
42 | -------
43 | error : string
44 | Error message if error occurs otherwise blank string.
45 | """
46 |
47 | error = ""
48 | if (args.N_out not in [3, 5, 7]):
49 | error = "Error: --N_out has to be 3 or 5 or 7"
50 | # if ".mkv" not in args.output:
51 | # error = "output needs to have a video container"
52 | return error
53 |
54 |
55 | def main():
56 | scale = 4
57 | N_ot = args.N_out
58 | N_in = 1 + N_ot//2
59 |
60 | # model
61 | model_path = args.model
62 | model = Sakuya_arch.LunaTokis(64, N_ot, 8, 5, 40)
63 |
64 | # extract the input video to temporary folder
65 | save_folder = osp.join(osp.dirname(args.output), '.delme')
66 | save_out_folder = osp.join(osp.dirname(args.output), '.hr_delme')
67 | if os.path.isdir(save_folder):
68 | rmtree(save_folder)
69 | util.mkdirs(save_folder)
70 | if os.path.isdir(save_out_folder):
71 | rmtree(save_out_folder)
72 | util.mkdirs(save_out_folder)
73 | error = util.extract_frames(args.ffmpeg_dir, args.video, save_folder)
74 | if error:
75 | print(error)
76 | exit(1)
77 |
78 | # temporal padding mode
79 | padding = 'replicate'
80 | save_imgs = True
81 |
82 | ############################################################################
83 | if torch.cuda.is_available():
84 | device = torch.device('cuda')
85 | else:
86 | device = torch.device('cpu')
87 |
88 | def single_forward(model, imgs_in):
89 | with torch.no_grad():
90 | # print(imgs_in.size()) # [1,5,3,270,480]
91 | b, n, c, h, w = imgs_in.size()
92 | h_n = int(4*np.ceil(h/4))
93 | w_n = int(4*np.ceil(w/4))
94 | imgs_temp = imgs_in.new_zeros(b, n, c, h_n, w_n)
95 | imgs_temp[:, :, :, 0:h, 0:w] = imgs_in
96 | model_output = model(imgs_temp)
97 | model_output = model_output[:, :, :, 0:scale*h, 0:scale*w]
98 | if isinstance(model_output, list) or isinstance(model_output, tuple):
99 | output = model_output[0]
100 | else:
101 | output = model_output
102 | return output
103 |
104 | model.load_state_dict(torch.load(model_path), strict=True)
105 |
106 | model.eval()
107 | model = model.to(device)
108 | # zsm images
109 | img_path_l = glob.glob(save_folder + '/*')
110 | img_path_l.sort(key=lambda x: int(
111 | re.search(r'\d+', os.path.basename(x)).group()))
112 | select_idx_list = util.test_index_generation(False, N_ot, len(img_path_l))
113 | for select_idxs in select_idx_list:
114 | # get input images
115 | select_idx = select_idxs[0]
116 | imgs_in = util.read_seq_imgs_by_list(
117 | [img_path_l[x] for x in select_idx]).unsqueeze(0).to(device)
118 | output = single_forward(model, imgs_in)
119 | outputs = output.data.float().cpu().squeeze(0)
120 | # save imgs
121 | out_idx = select_idxs[1]
122 | for idx, name_idx in enumerate(out_idx):
123 | output_f = outputs[idx, ...].squeeze(0)
124 | if save_imgs:
125 | output = util.tensor2img(output_f)
126 | cv2.imwrite(osp.join(save_out_folder,
127 | '{:06d}.png'.format(name_idx)), output)
128 |
129 | # now turn output images to video
130 | # generate mp4
131 | util.combine_frames(save_out_folder,
132 | args.output, args.fps)
133 |
134 | # remove tmp folder
135 | rmtree(save_folder)
136 | rmtree(save_out_folder)
137 |
138 | exit(0)
139 |
140 |
141 | if __name__ == '__main__':
142 | main()
143 |
--------------------------------------------------------------------------------
/codes/zsm_my_video.sh:
--------------------------------------------------------------------------------
1 | python video_to_zsm.py --video PATH/TO/VIDEO.mp4 --model ../experiments/pretrained_models/xiang2020zooming.pth --output PATH/TO/OUTPUT.mp4
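2 | # Optional flags (defined in video_to_zsm.py): --fps sets the output frame rate (default 24),
3 | # --N_out sets the number of frames produced per forward pass (3, 5, or 7), and
4 | # --ffmpeg_dir points to the directory containing the ffmpeg binary.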
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | In this directory we provide the meta-info of the datasets used in this paper.
2 |
3 | ## Vimeo
4 |
5 | The official page of the Vimeo dataset is [here](http://toflow.csail.mit.edu/). It has 64,612 training samples and 7,824 test samples. The complete dataset can be downloaded [here](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip).
6 |
7 | In this paper we use Vimeo as the training set; the keys for the LMDB files can be found at [Vimeo7_train_keys.pkl](./meta_info/Vimeo7_train_keys.pkl).
8 |
9 | The Vimeo test set is split into 3 subsets according to motion speed: fast, medium, and slow. The video indices for each subset can be found in [slow_testset.txt](./meta_info/slow_testset.txt), [medium_testset.txt](./meta_info/medium_testset.txt), and [fast_testset.txt](./meta_info/fast_testset.txt); a minimal loading sketch is given at the end of this file. Note that we remove videos with entirely black frames to avoid NaN results during evaluation.
10 |
11 |
12 | ## Vid4
13 |
14 | Vid4 is a 4-clip test set that has 171 frames in total. It can be downloaded [here](https://drive.google.com/drive/folders/10-gUO6zBeOpWEamrWKCtSkkUFukB9W5m).
15 |
16 | In this paper, we use this dataset to measure the performance and runtime of different methods.
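17 | 
18 | ## Loading the meta-info
19 | 
20 | A minimal sketch of how these meta-info files might be read; the layout of the pickle file is an assumption, so adjust it to whatever `create_lmdb_mp.py` actually stores:
21 | 
22 | ```python
23 | import pickle  # on Python < 3.8, `import pickle5 as pickle` may be needed (see requirements.txt)
24 | 
25 | # Training keys for the Vimeo LMDB: assumed to be either a plain list of keys
26 | # or a dict with a 'keys' entry.
27 | with open('meta_info/Vimeo7_train_keys.pkl', 'rb') as f:
28 |     meta = pickle.load(f)
29 | keys = meta['keys'] if isinstance(meta, dict) else meta
30 | 
31 | # Test-set splits are plain text files with one clip identifier per line.
32 | with open('meta_info/fast_testset.txt') as f:
33 |     fast_clips = [line.strip() for line in f if line.strip()]
34 | print(len(keys), len(fast_clips))
35 | ```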
--------------------------------------------------------------------------------
/datasets/meta_info/Vimeo7_train_keys.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/datasets/meta_info/Vimeo7_train_keys.pkl
--------------------------------------------------------------------------------
/dump/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/dump/4539-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/dump/4539-teaser.gif
--------------------------------------------------------------------------------
/dump/demo720.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/dump/demo720.gif
--------------------------------------------------------------------------------
/dump/demo_thumbnail.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/dump/demo_thumbnail.PNG
--------------------------------------------------------------------------------
/dump/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/dump/framework.png
--------------------------------------------------------------------------------
/experiments/pretrained_models/readme.md:
--------------------------------------------------------------------------------
1 | Our pretrained model is provided in this GitHub folder.
2 |
3 | It can also be downloaded from [Google Drive](https://drive.google.com/open?id=1xeOoZclGeSI1urY6mVCcApfCqOPgxMBK).
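4 | 
5 | For reference, a minimal loading sketch that mirrors how `codes/video_to_zsm.py` uses this checkpoint (the constructor arguments are the ones that script passes; treat anything else here as an assumption):
6 | 
7 | ```python
8 | import torch
9 | import models.modules.Sakuya_arch as Sakuya_arch  # run from within the codes/ directory
10 | 
11 | N_out = 3  # output frames per forward pass (3, 5, or 7)
12 | model = Sakuya_arch.LunaTokis(64, N_out, 8, 5, 40)
13 | state = torch.load('../experiments/pretrained_models/xiang2020zooming.pth', map_location='cpu')
14 | model.load_state_dict(state, strict=True)
15 | model.eval()
16 | ```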
--------------------------------------------------------------------------------
/experiments/pretrained_models/xiang2020zooming.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/experiments/pretrained_models/xiang2020zooming.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.2.0
2 | torchvision>=0.4.0
3 | numpy
4 | opencv-python
5 | lmdb
6 | pyyaml
7 | pickle5
8 | matplotlib
9 | seaborn
10 |
--------------------------------------------------------------------------------
/test_example/0625/im1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im1.png
--------------------------------------------------------------------------------
/test_example/0625/im2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im2.png
--------------------------------------------------------------------------------
/test_example/0625/im3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im3.png
--------------------------------------------------------------------------------
/test_example/0625/im4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im4.png
--------------------------------------------------------------------------------
/test_example/0625/im5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im5.png
--------------------------------------------------------------------------------
/test_example/0625/im6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im6.png
--------------------------------------------------------------------------------
/test_example/0625/im7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/a053e08bb0bb5509f634b523256718f502637667/test_example/0625/im7.png
--------------------------------------------------------------------------------