├── LICENSE
├── README.md
├── configs
│   ├── config.txt
│   ├── config_Balloon1.txt
│   ├── config_Balloon2.txt
│   ├── config_Jumping.txt
│   ├── config_Playground.txt
│   ├── config_Skating.txt
│   ├── config_Truck.txt
│   └── config_Umbrella.txt
├── load_llff.py
├── render_utils.py
├── run_nerf.py
├── run_nerf_helpers.py
└── utils
    ├── RAFT
    │   ├── __init__.py
    │   ├── corr.py
    │   ├── datasets.py
    │   ├── demo.py
    │   ├── extractor.py
    │   ├── raft.py
    │   ├── update.py
    │   └── utils
    │       ├── __init__.py
    │       ├── augmentor.py
    │       ├── flow_viz.py
    │       ├── frame_utils.py
    │       └── utils.py
    ├── colmap_utils.py
    ├── evaluation.py
    ├── flow_utils.py
    ├── generate_data.py
    ├── generate_depth.py
    ├── generate_flow.py
    ├── generate_motion_mask.py
    ├── generate_pose.py
    └── midas
        ├── base_model.py
        ├── blocks.py
        ├── midas_net.py
        ├── transforms.py
        └── vit.py
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2020 Virginia Tech Vision and Learning Lab
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | --------------------------- LICENSE FOR EdgeConnect --------------------------------
25 |
26 | Attribution-NonCommercial 4.0 International
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic View Synthesis from Dynamic Monocular Video
2 |
3 | [arXiv](https://arxiv.org/abs/2105.06468)
4 |
5 | [Project Website](https://free-view-video.github.io/) | [Video](https://youtu.be/j8CUzIR0f8M) | [Paper](https://arxiv.org/abs/2105.06468)
6 |
7 | > **Dynamic View Synthesis from Dynamic Monocular Video**
8 | > [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Johannes Kopf](https://johanneskopf.de/), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/)
9 | > in ICCV 2021
10 |
11 | ## Setup
12 | The code is tested with
13 | * Linux (tested on CentOS Linux release 7.4.1708)
14 | * Anaconda 3
15 | * Python 3.7.11
16 | * CUDA 10.1
17 | * 1 V100 GPU
18 |
19 |
20 | To get started, please create the conda environment `dnerf` by running
21 | ```
22 | conda create --name dnerf python=3.7
23 | conda activate dnerf
24 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
25 | pip install imageio scikit-image configargparse timm lpips
26 | ```
27 | and install [COLMAP](https://colmap.github.io/install.html) manually. Then download MiDaS and RAFT weights
28 | ```
29 | ROOT_PATH=/path/to/the/DynamicNeRF/folder
30 | cd $ROOT_PATH
31 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/weights.zip
32 | unzip weights.zip
33 | rm weights.zip
34 | ```
35 |
36 | ## Dynamic Scene Dataset
37 | The [Dynamic Scene Dataset](https://www-users.cse.umn.edu/~jsyoon/dynamic_synth/) is used to
38 | quantitatively evaluate our method. Please download the pre-processed data by running:
39 | ```
40 | cd $ROOT_PATH
41 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/data.zip
42 | unzip data.zip
43 | rm data.zip
44 | ```
45 |
46 | ### Training
47 | You can train a model from scratch by running:
48 | ```
49 | cd $ROOT_PATH/
50 | python run_nerf.py --config configs/config_Balloon2.txt
51 | ```
52 |
53 | Every 100k iterations, you should get videos like the following examples.
54 |
55 | The novel view-time synthesis results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/novelviewtime`.
56 | 
57 |
58 |
59 | The reconstruction results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset`.
60 | 
61 |
62 | The fix-view-change-time results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_view000`.
63 | 
64 |
65 | The fix-time-change-view results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_time000`.
66 | 
67 |
68 |
69 | ### Rendering from pre-trained models
70 | We also provide pre-trained models. You can download them by running:
71 | ```
72 | cd $ROOT_PATH/
73 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/logs.zip
74 | unzip logs.zip
75 | rm logs.zip
76 | ```
77 |
78 | Then you can render the results directly by running:
79 | ```
80 | python run_nerf.py --config configs/config_Balloon2.txt --render_only --ft_path $ROOT_PATH/logs/Balloon2_H270_DyNeRF_pretrain/300000.tar
81 | ```
82 |
83 | ### Evaluating our method and others
84 | Our goal is to make the evaluation as simple as possible for you. We have collected the fix-view-change-time results of the following methods:
85 |
86 | `NeRF` \
87 | `NeRF + t` \
88 | `Yoon et al.` \
89 | `Non-Rigid NeRF` \
90 | `NSFF` \
91 | `DynamicNeRF (ours)`
92 |
93 | Please download the results by running:
94 | ```
95 | cd $ROOT_PATH/
96 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/results.zip
97 | unzip results.zip
98 | rm results.zip
99 | ```
100 |
101 | Then you can calculate the PSNR/SSIM/LPIPS by running:
102 | ```
103 | cd $ROOT_PATH/utils
104 | python evaluation.py
105 | ```
106 |
107 | | PSNR / LPIPS | Jumping | Skating | Truck | Umbrella | Balloon1 | Balloon2 | Playground | Average |
108 | |:-------------|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
109 | | NeRF | 20.99 / 0.305 | 23.67 / 0.311 | 22.73 / 0.229 | 21.29 / 0.440 | 19.82 / 0.205 | 24.37 / 0.098 | 21.07 / 0.165 | 21.99 / 0.250 |
110 | | NeRF + t | 18.04 / 0.455 | 20.32 / 0.512 | 18.33 / 0.382 | 17.69 / 0.728 | 18.54 / 0.275 | 20.69 / 0.216 | 14.68 / 0.421 | 18.33 / 0.427 |
111 | | NR NeRF | 20.09 / 0.287 | 23.95 / 0.227 | 19.33 / 0.446 | 19.63 / 0.421 | 17.39 / 0.348 | 22.41 / 0.213 | 15.06 / 0.317 | 19.69 / 0.323 |
112 | | NSFF | 24.65 / 0.151 | 29.29 / 0.129 | 25.96 / 0.167 | 22.97 / 0.295 | 21.96 / 0.215 | 24.27 / 0.222 | 21.22 / 0.212 | 24.33 / 0.199 |
113 | | Ours | 24.68 / 0.090 | 32.66 / 0.035 | 28.56 / 0.082 | 23.26 / 0.137 | 22.36 / 0.104 | 27.06 / 0.049 | 24.15 / 0.080 | 26.10 / 0.082 |
114 |
115 |
116 | Please note:
117 | 1. The numbers reported in the paper were calculated with the original TensorFlow code; the numbers here come from this improved PyTorch version.
118 | 2. Yoon et al.'s results are missing the first and last frames, so comparisons against them must omit those two frames. To do so, please uncomment line 72 and comment line 73 in `evaluation.py`.
119 | 3. We obtain the results of NSFF and NR NeRF using the official implementations with default parameters.
120 |
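As a rough illustration of what the metrics in the table above involve, the snippet below computes PSNR/SSIM/LPIPS for a single image pair using the `scikit-image` and `lpips` packages installed during setup. It is only a sketch: the reported numbers come from `utils/evaluation.py`, and details such as the LPIPS backbone may differ there.

```
# Illustrative only -- not the actual evaluation script (see utils/evaluation.py).
import numpy as np
import torch
import lpips
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

lpips_fn = lpips.LPIPS(net='alex')  # backbone chosen here only as an example


def to_tensor(x):
    # H x W x 3 in [0, 1]  ->  1 x 3 x H x W in [-1, 1], as LPIPS expects
    return torch.from_numpy(x).permute(2, 0, 1)[None].float() * 2 - 1


def image_metrics(pred, gt):
    """pred, gt: float32 H x W x 3 arrays with values in [0, 1]."""
    psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
    ssim = structural_similarity(gt, pred, data_range=1.0, multichannel=True)
    lpips_val = lpips_fn(to_tensor(pred), to_tensor(gt)).item()
    return psnr, ssim, lpips_val
```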
121 |
122 | ## Train a model on your sequence
123 | 0. Set some paths
124 |
125 | ```
126 | ROOT_PATH=/path/to/the/DynamicNeRF/folder
127 | DATASET_NAME=name_of_the_video_without_extension
128 | DATASET_PATH=$ROOT_PATH/data/$DATASET_NAME
129 | ```
130 |
131 | 1. Prepare training images and background masks from a video.
132 |
133 | ```
134 | cd $ROOT_PATH/utils
135 | python generate_data.py --videopath /path/to/the/video
136 | ```
137 |
138 | 2. Use COLMAP to obtain camera poses.
139 |
140 | ```
141 | colmap feature_extractor \
142 | --database_path $DATASET_PATH/database.db \
143 | --image_path $DATASET_PATH/images_colmap \
144 | --ImageReader.mask_path $DATASET_PATH/background_mask \
145 | --ImageReader.single_camera 1
146 |
147 | colmap exhaustive_matcher \
148 | --database_path $DATASET_PATH/database.db
149 |
150 | mkdir $DATASET_PATH/sparse
151 | colmap mapper \
152 | --database_path $DATASET_PATH/database.db \
153 | --image_path $DATASET_PATH/images_colmap \
154 | --output_path $DATASET_PATH/sparse \
155 | --Mapper.num_threads 16 \
156 | --Mapper.init_min_tri_angle 4 \
157 | --Mapper.multiple_models 0 \
158 | --Mapper.extract_colors 0
159 | ```
160 |
161 | 3. Save camera poses into the format that NeRF reads.
162 |
163 | ```
164 | cd $ROOT_PATH/utils
165 | python generate_pose.py --dataset_path $DATASET_PATH
166 | ```
167 |
168 | 4. Estimate monocular depth.
169 |
170 | ```
171 | cd $ROOT_PATH/utils
172 | python generate_depth.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/midas_v21-f6b98070.pt
173 | ```
174 |
175 | 5. Predict optical flows.
176 |
177 | ```
178 | cd $ROOT_PATH/utils
179 | python generate_flow.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/raft-things.pth
180 | ```
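
Based on how `load_llff.py` reads them during training, each file in `$DATASET_PATH/flow` is an `.npz` archive named like `000_fwd.npz` / `001_bwd.npz` containing a two-channel `flow` array and a `mask` array. A quick sanity check (the path below is a placeholder):

```
# Illustrative check of one generated flow file; adjust the path to your dataset.
import numpy as np

data = np.load('/path/to/data/DATASET_NAME/flow/000_fwd.npz')
print(data['flow'].shape)  # per-pixel (u, v) flow, H x W x 2
print(data['mask'].shape)  # per-pixel validity mask, H x W
```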
181 |
182 | 6. Obtain motion mask (code adapted from NSFF).
183 |
184 | ```
185 | cd $ROOT_PATH/utils
186 | python generate_motion_mask.py --dataset_path $DATASET_PATH
187 | ```
188 |
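At this point the dataset folder should contain everything `load_llff.py` expects for training: `images/`, `disp/`, `flow/`, `motion_masks/`, and `poses_bounds.npy` (the COLMAP `database.db`, `images_colmap/`, and `sparse/` folders are only intermediate). A small check along these lines can catch a missed step; it is a sketch, not part of the repo:

```
# Hypothetical pre-training sanity check; it only relies on what load_llff.py reads.
import os
import numpy as np

dataset_path = '/path/to/data/DATASET_NAME'  # same folder as $DATASET_PATH above

for sub in ['images', 'disp', 'flow', 'motion_masks']:
    num_files = len(os.listdir(os.path.join(dataset_path, sub)))
    print('{}: {} files'.format(sub, num_files))

poses_bounds = np.load(os.path.join(dataset_path, 'poses_bounds.npy'))
# One row per frame: a flattened 3x5 pose/intrinsics matrix plus near/far bounds.
print('poses_bounds:', poses_bounds.shape)  # expected (num_frames, 17)
```
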
189 | 7. Train a model. Please change `expname` and `datadir` in `configs/config.txt`.
190 |
191 | ```
192 | cd $ROOT_PATH/
193 | python run_nerf.py --config configs/config.txt
194 | ```
195 |
196 | Explanation of each parameter (see the parser sketch after this list):
197 |
198 | - `expname`: experiment name
199 | - `basedir`: where to store ckpts and logs
200 | - `datadir`: input data directory
201 | - `factor`: downsample factor for the input images
202 | - `N_rand`: number of random rays per gradient step
203 | - `N_samples`: number of samples per ray
204 | - `netwidth`: channels per layer
205 | - `use_viewdirs`: whether to enable view dependence for StaticNeRF
206 | - `use_viewdirsDyn`: whether to enable view dependence for DynamicNeRF
207 | - `raw_noise_std`: std dev of noise added to regularize sigma_a output
208 | - `no_ndc`: do not use normalized device coordinates
209 | - `lindisp`: sampling linearly in disparity rather than depth
210 | - `i_video`: frequency of novel view-time synthesis video saving
211 | - `i_testset`: frequency of testset video saving
212 | - `N_iters`: number of training iterations
213 | - `i_img`: frequency of tensorboard image logging
214 | - `DyNeRF_blending`: whether to use DynamicNeRF to predict the blending weight
215 | - `pretrain`: whether to pre-train StaticNeRF
216 |
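These options follow the `configargparse` convention used by the NeRF codebases this repo builds on: every `key = value` line in a config file corresponds to a command-line flag of the same name, so any value can also be overridden on the command line (e.g. `--N_rand 512`). Below is a rough sketch of how such a parser is set up; the real argument definitions live in `run_nerf.py` and may differ in detail.

```
# Rough configargparse sketch -- the actual flags and defaults are defined in run_nerf.py.
import configargparse

parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument('--expname', type=str, help='experiment name')
parser.add_argument('--basedir', type=str, default='./logs', help='where to store ckpts and logs')
parser.add_argument('--datadir', type=str, help='input data directory')
parser.add_argument('--N_rand', type=int, default=1024, help='number of random rays per gradient step')
args = parser.parse_args()
# `python run_nerf.py --config configs/config.txt --N_rand 512` takes every value
# from the config file except N_rand, which is overridden on the command line.
```
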
217 | ## License
218 | This work is licensed under the MIT License. See [LICENSE](LICENSE) for details.
219 |
220 | If you find this code useful for your research, please consider citing the following paper:
221 |
222 | @inproceedings{Gao-ICCV-DynNeRF,
223 | author = {Gao, Chen and Saraf, Ayush and Kopf, Johannes and Huang, Jia-Bin},
224 | title = {Dynamic View Synthesis from Dynamic Monocular Video},
225 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision},
226 | year = {2021}
227 | }
228 |
229 | ## Acknowledgments
230 | Our training code is built upon
231 | [NeRF](https://github.com/bmild/nerf),
232 | [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch), and
233 | [NSFF](https://github.com/zl548/Neural-Scene-Flow-Fields).
234 | Our flow prediction code is modified from [RAFT](https://github.com/princeton-vl/RAFT).
235 | Our depth prediction code is modified from [MiDaS](https://github.com/isl-org/MiDaS).
236 |
--------------------------------------------------------------------------------
/configs/config.txt:
--------------------------------------------------------------------------------
1 | expname = xxxxxx_DyNeRF_pretrain_test
2 | basedir = ./logs
3 | datadir = ./data/xxxxxx/
4 |
5 | dataset_type = llff
6 |
7 | factor = 4
8 | N_rand = 1024
9 | N_samples = 64
10 | netwidth = 256
11 |
12 | i_video = 100000
13 | i_testset = 100000
14 | N_iters = 500001
15 | i_img = 500
16 |
17 | use_viewdirs = True
18 | use_viewdirsDyn = True
19 | raw_noise_std = 1e0
20 | no_ndc = False
21 | lindisp = False
22 |
23 | dynamic_loss_lambda = 1.0
24 | static_loss_lambda = 1.0
25 | full_loss_lambda = 3.0
26 | depth_loss_lambda = 0.04
27 | order_loss_lambda = 0.1
28 | flow_loss_lambda = 0.02
29 | slow_loss_lambda = 0.01
30 | smooth_loss_lambda = 0.1
31 | consistency_loss_lambda = 1.0
32 | mask_loss_lambda = 0.01
33 | sparse_loss_lambda = 0.001
34 | DyNeRF_blending = True
35 | pretrain = True
36 |
--------------------------------------------------------------------------------
/configs/config_Balloon1.txt:
--------------------------------------------------------------------------------
1 | expname = Balloon1_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Balloon1/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = False
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Balloon2.txt:
--------------------------------------------------------------------------------
1 | expname = Balloon2_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Balloon2/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Jumping.txt:
--------------------------------------------------------------------------------
1 | expname = Jumping_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Jumping/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = False
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Playground.txt:
--------------------------------------------------------------------------------
1 | expname = Playground_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Playground/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Skating.txt:
--------------------------------------------------------------------------------
1 | expname = Skating_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Skating/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Truck.txt:
--------------------------------------------------------------------------------
1 | expname = Truck_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Truck/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Umbrella.txt:
--------------------------------------------------------------------------------
1 | expname = Umbrella_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Umbrella/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/load_llff.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import imageio
4 | import numpy as np
5 |
6 | from utils.flow_utils import resize_flow
7 | from run_nerf_helpers import get_grid
8 |
9 |
10 | def _minify(basedir, factors=[], resolutions=[]):
11 | needtoload = False
12 | for r in factors:
13 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
14 | if not os.path.exists(imgdir):
15 | needtoload = True
16 | for r in resolutions:
17 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
18 | if not os.path.exists(imgdir):
19 | needtoload = True
20 | if not needtoload:
21 | return
22 |
23 | from shutil import copy
24 | from subprocess import check_output
25 |
26 | imgdir = os.path.join(basedir, 'images')
27 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
28 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
29 | imgdir_orig = imgdir
30 |
31 | wd = os.getcwd()
32 |
33 | for r in factors + resolutions:
34 | if isinstance(r, int):
35 | name = 'images_{}'.format(r)
36 | resizearg = '{}%'.format(100./r)
37 | else:
38 | name = 'images_{}x{}'.format(r[1], r[0])
39 | resizearg = '{}x{}'.format(r[1], r[0])
40 | imgdir = os.path.join(basedir, name)
41 | if os.path.exists(imgdir):
42 | continue
43 |
44 | print('Minifying', r, basedir)
45 |
46 | os.makedirs(imgdir)
47 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
48 |
49 | ext = imgs[0].split('.')[-1]
50 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
51 | print(args)
52 | os.chdir(imgdir)
53 | check_output(args, shell=True)
54 | os.chdir(wd)
55 |
56 | if ext != 'png':
57 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
58 | print('Removed duplicates')
59 | print('Done')
60 |
61 |
62 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
63 | print('factor ', factor)
64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
66 | bds = poses_arr[:, -2:].transpose([1,0])
67 |
68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
70 | sh = imageio.imread(img0).shape
71 |
72 | sfx = ''
73 |
74 | if factor is not None:
75 | sfx = '_{}'.format(factor)
76 | _minify(basedir, factors=[factor])
77 | factor = factor
78 | elif height is not None:
79 | factor = sh[0] / float(height)
80 | width = int(sh[1] / factor)
81 | if width % 2 == 1:
82 | width -= 1
83 | _minify(basedir, resolutions=[[height, width]])
84 | sfx = '_{}x{}'.format(width, height)
85 | elif width is not None:
86 | factor = sh[1] / float(width)
87 | height = int(sh[0] / factor)
88 | if height % 2 == 1:
89 | height -= 1
90 | _minify(basedir, resolutions=[[height, width]])
91 | sfx = '_{}x{}'.format(width, height)
92 | else:
93 | factor = 1
94 |
95 | imgdir = os.path.join(basedir, 'images' + sfx)
96 | if not os.path.exists(imgdir):
97 | print( imgdir, 'does not exist, returning' )
98 | return
99 |
100 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) \
101 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
102 | if poses.shape[-1] != len(imgfiles):
103 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
104 | return
105 |
106 | sh = imageio.imread(imgfiles[0]).shape
107 | num_img = len(imgfiles)
108 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
109 | poses[2, 4, :] = poses[2, 4, :] * 1./factor
110 |
111 | if not load_imgs:
112 | return poses, bds
113 |
114 | def imread(f):
115 | if f.endswith('png'):
116 | return imageio.imread(f, ignoregamma=True)
117 | else:
118 | return imageio.imread(f)
119 |
120 | imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
121 | imgs = np.stack(imgs, -1)
122 |
123 | assert imgs.shape[0] == sh[0]
124 | assert imgs.shape[1] == sh[1]
125 |
126 | disp_dir = os.path.join(basedir, 'disp')
127 |
128 | dispfiles = [os.path.join(disp_dir, f) \
129 | for f in sorted(os.listdir(disp_dir)) if f.endswith('npy')]
130 |
131 | disp = [cv2.resize(np.load(f),
132 | (sh[1], sh[0]),
133 | interpolation=cv2.INTER_NEAREST) for f in dispfiles]
134 | disp = np.stack(disp, -1)
135 |
136 | mask_dir = os.path.join(basedir, 'motion_masks')
137 | maskfiles = [os.path.join(mask_dir, f) \
138 | for f in sorted(os.listdir(mask_dir)) if f.endswith('png')]
139 |
140 | masks = [cv2.resize(imread(f)/255., (sh[1], sh[0]),
141 | interpolation=cv2.INTER_NEAREST) for f in maskfiles]
142 | masks = np.stack(masks, -1)
143 | masks = np.float32(masks > 1e-3)
144 |
145 | flow_dir = os.path.join(basedir, 'flow')
146 | flows_f = []
147 | flow_masks_f = []
148 | flows_b = []
149 | flow_masks_b = []
150 | for i in range(num_img):
151 | if i == num_img - 1:
152 | fwd_flow, fwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
153 | else:
154 | fwd_flow_path = os.path.join(flow_dir, '%03d_fwd.npz'%i)
155 | fwd_data = np.load(fwd_flow_path)
156 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
157 | fwd_flow = resize_flow(fwd_flow, sh[0], sh[1])
158 | fwd_mask = np.float32(fwd_mask)
159 | fwd_mask = cv2.resize(fwd_mask, (sh[1], sh[0]),
160 | interpolation=cv2.INTER_NEAREST)
161 | flows_f.append(fwd_flow)
162 | flow_masks_f.append(fwd_mask)
163 |
164 | if i == 0:
165 | bwd_flow, bwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
166 | else:
167 | bwd_flow_path = os.path.join(flow_dir, '%03d_bwd.npz'%i)
168 | bwd_data = np.load(bwd_flow_path)
169 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
170 | bwd_flow = resize_flow(bwd_flow, sh[0], sh[1])
171 | bwd_mask = np.float32(bwd_mask)
172 | bwd_mask = cv2.resize(bwd_mask, (sh[1], sh[0]),
173 | interpolation=cv2.INTER_NEAREST)
174 | flows_b.append(bwd_flow)
175 | flow_masks_b.append(bwd_mask)
176 |
177 | flows_f = np.stack(flows_f, -1)
178 | flow_masks_f = np.stack(flow_masks_f, -1)
179 | flows_b = np.stack(flows_b, -1)
180 | flow_masks_b = np.stack(flow_masks_b, -1)
181 |
182 | print(imgs.shape)
183 | print(disp.shape)
184 | print(masks.shape)
185 | print(flows_f.shape)
186 | print(flow_masks_f.shape)
187 |
188 | assert(imgs.shape[0] == disp.shape[0])
189 | assert(imgs.shape[0] == masks.shape[0])
190 | assert(imgs.shape[0] == flows_f.shape[0])
191 | assert(imgs.shape[0] == flow_masks_f.shape[0])
192 |
193 | assert(imgs.shape[1] == disp.shape[1])
194 | assert(imgs.shape[1] == masks.shape[1])
195 |
196 | return poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b
197 |
198 |
199 | def normalize(x):
200 | return x / np.linalg.norm(x)
201 |
202 | def viewmatrix(z, up, pos):
203 | vec2 = normalize(z)
204 | vec1_avg = up
205 | vec0 = normalize(np.cross(vec1_avg, vec2))
206 | vec1 = normalize(np.cross(vec2, vec0))
207 | m = np.stack([vec0, vec1, vec2, pos], 1)
208 | return m
209 |
210 |
211 | def poses_avg(poses):
212 |
213 | hwf = poses[0, :3, -1:]
214 |
215 | center = poses[:, :3, 3].mean(0)
216 | vec2 = normalize(poses[:, :3, 2].sum(0))
217 | up = poses[:, :3, 1].sum(0)
218 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
219 |
220 | return c2w
221 |
222 |
223 |
224 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
225 | render_poses = []
226 | rads = np.array(list(rads) + [1.])
227 | hwf = c2w[:,4:5]
228 |
229 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
230 | c = np.dot(c2w[:3, :4],
231 | np.array([np.cos(theta),
232 | -np.sin(theta),
233 | -np.sin(theta*zrate),
234 | 1.]) * rads)
235 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
236 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
237 | return render_poses
238 |
239 |
240 |
241 | def recenter_poses(poses):
242 |
243 | poses_ = poses+0
244 | bottom = np.reshape([0,0,0,1.], [1,4])
245 | c2w = poses_avg(poses)
246 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
247 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
248 | poses = np.concatenate([poses[:,:3,:4], bottom], -2)
249 |
250 | poses = np.linalg.inv(c2w) @ poses
251 | poses_[:,:3,:4] = poses[:,:3,:4]
252 | poses = poses_
253 | return poses
254 |
255 |
256 | def spherify_poses(poses, bds):
257 |
258 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
259 |
260 | rays_d = poses[:,:3,2:3]
261 | rays_o = poses[:,:3,3:4]
262 |
263 | def min_line_dist(rays_o, rays_d):
264 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
265 | b_i = -A_i @ rays_o
266 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
267 | return pt_mindist
268 |
269 | pt_mindist = min_line_dist(rays_o, rays_d)
270 |
271 | center = pt_mindist
272 | up = (poses[:,:3,3] - center).mean(0)
273 |
274 | vec0 = normalize(up)
275 | vec1 = normalize(np.cross([.1,.2,.3], vec0))
276 | vec2 = normalize(np.cross(vec0, vec1))
277 | pos = center
278 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
279 |
280 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
281 |
282 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
283 |
284 | sc = 1./rad
285 | poses_reset[:,:3,3] *= sc
286 | bds *= sc
287 | rad *= sc
288 |
289 | centroid = np.mean(poses_reset[:,:3,3], 0)
290 | zh = centroid[2]
291 | radcircle = np.sqrt(rad**2-zh**2)
292 | new_poses = []
293 |
294 | for th in np.linspace(0.,2.*np.pi, 120):
295 |
296 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
297 | up = np.array([0,0,-1.])
298 |
299 | vec2 = normalize(camorigin)
300 | vec0 = normalize(np.cross(vec2, up))
301 | vec1 = normalize(np.cross(vec2, vec0))
302 | pos = camorigin
303 | p = np.stack([vec0, vec1, vec2, pos], 1)
304 |
305 | new_poses.append(p)
306 |
307 | new_poses = np.stack(new_poses, 0)
308 |
309 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
310 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
311 |
312 | return poses_reset, new_poses, bds
313 |
314 |
315 | def load_llff_data(args, basedir,
316 | factor=2,
317 | recenter=True, bd_factor=.75,
318 | spherify=False, path_zflat=False,
319 | frame2dolly=10):
320 |
321 | poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b = \
322 | _load_data(basedir, factor=factor) # factor=2 downsamples original imgs by 2x
323 |
324 | print('Loaded', basedir, bds.min(), bds.max())
325 |
326 | # Correct rotation matrix ordering and move variable dim to axis 0
327 | poses = np.concatenate([poses[:, 1:2, :],
328 | -poses[:, 0:1, :],
329 | poses[:, 2:, :]], 1)
330 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
331 | images = np.moveaxis(imgs, -1, 0).astype(np.float32)
332 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
333 | disp = np.moveaxis(disp, -1, 0).astype(np.float32)
334 | masks = np.moveaxis(masks, -1, 0).astype(np.float32)
335 | flows_f = np.moveaxis(flows_f, -1, 0).astype(np.float32)
336 | flow_masks_f = np.moveaxis(flow_masks_f, -1, 0).astype(np.float32)
337 | flows_b = np.moveaxis(flows_b, -1, 0).astype(np.float32)
338 | flow_masks_b = np.moveaxis(flow_masks_b, -1, 0).astype(np.float32)
339 |
340 | # Rescale if bd_factor is provided
341 | sc = 1. if bd_factor is None else 1./(np.percentile(bds[:, 0], 5) * bd_factor)
342 |
343 | poses[:, :3, 3] *= sc
344 | bds *= sc
345 |
346 | if recenter:
347 | poses = recenter_poses(poses)
348 |
349 | # Only for rendering
350 | if frame2dolly == -1:
351 | c2w = poses_avg(poses)
352 | else:
353 | c2w = poses[frame2dolly, :, :]
354 |
355 | H, W, _ = c2w[:, -1]
356 |
357 | # Generate poses for novel views
358 | render_poses, render_focals = generate_path(c2w, args)
359 | render_poses = np.array(render_poses).astype(np.float32)
360 |
361 | grids = get_grid(int(H), int(W), len(poses), flows_f, flow_masks_f, flows_b, flow_masks_b) # [N, H, W, 8]
362 |
363 | return images, disp, masks, poses, bds,\
364 | render_poses, render_focals, grids
365 |
366 |
367 | def generate_path(c2w, args):
368 | hwf = c2w[:, 4:5]
369 | num_novelviews = args.num_novelviews
370 | max_disp = 48.0
371 | H, W, focal = hwf[:, 0]
372 |
373 | max_trans = max_disp / focal
374 | output_poses = []
375 | output_focals = []
376 |
377 | # Rendering teaser. Add translation.
378 | for i in range(num_novelviews):
379 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.x_trans_multiplier
380 | y_trans = max_trans * (np.cos(2.0 * np.pi * float(i) / float(num_novelviews)) - 1.) * args.y_trans_multiplier
381 | z_trans = 0.
382 |
383 | i_pose = np.concatenate([
384 | np.concatenate(
385 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
386 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
387 | ],axis=0)
388 |
389 | i_pose = np.linalg.inv(i_pose)
390 |
391 | ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
392 |
393 | render_pose = np.dot(ref_pose, i_pose)
394 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
395 | output_focals.append(focal)
396 |
397 | # Rendering teaser. Add zooming.
398 | if args.frame2dolly != -1:
399 | for i in range(num_novelviews // 2 + 1):
400 | x_trans = 0.
401 | y_trans = 0.
402 | # z_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.z_trans_multiplier
403 | z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
404 | i_pose = np.concatenate([
405 | np.concatenate(
406 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
407 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
408 | ],axis=0)
409 |
410 | i_pose = np.linalg.inv(i_pose) #torch.tensor(np.linalg.inv(i_pose)).float()
411 |
412 | ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
413 |
414 | render_pose = np.dot(ref_pose, i_pose)
415 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
416 | output_focals.append(focal)
417 | print(z_trans / max_trans / args.z_trans_multiplier)
418 |
419 | # Rendering teaser. Add dolly zoom.
420 | if args.frame2dolly != -1:
421 | for i in range(num_novelviews // 2 + 1):
422 | x_trans = 0.
423 | y_trans = 0.
424 | z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
425 | i_pose = np.concatenate([
426 | np.concatenate(
427 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
428 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
429 | ],axis=0)
430 |
431 | i_pose = np.linalg.inv(i_pose)
432 |
433 | ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
434 |
435 | render_pose = np.dot(ref_pose, i_pose)
436 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
437 | new_focal = focal - args.focal_decrease * z_trans / max_trans / args.z_trans_multiplier
438 | output_focals.append(new_focal)
439 | print(z_trans / max_trans / args.z_trans_multiplier, new_focal)
440 |
441 | return output_poses, output_focals
442 |
--------------------------------------------------------------------------------
/run_nerf_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import imageio
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | # Misc utils
12 | def img2mse(x, y, M=None):
13 |     if M is None:
14 | return torch.mean((x - y) ** 2)
15 | else:
16 | return torch.sum((x - y) ** 2 * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
17 |
18 |
19 | def img2mae(x, y, M=None):
20 |     if M is None:
21 | return torch.mean(torch.abs(x - y))
22 | else:
23 | return torch.sum(torch.abs(x - y) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
24 |
25 |
26 | def L1(x, M=None):
27 |     if M is None:
28 | return torch.mean(torch.abs(x))
29 | else:
30 | return torch.sum(torch.abs(x) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
31 |
32 |
33 | def L2(x, M=None):
34 |     if M is None:
35 | return torch.mean(x ** 2)
36 | else:
37 | return torch.sum((x ** 2) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
38 |
39 |
40 | def entropy(x):
41 | return -torch.sum(x * torch.log(x + 1e-19)) / x.shape[0]
42 |
43 |
44 | def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
45 |
46 |
47 | def to8b(x): return (255 * np.clip(x, 0, 1)).astype(np.uint8)
48 |
49 |
50 | class Embedder:
51 |
52 | def __init__(self, **kwargs):
53 |
54 | self.kwargs = kwargs
55 | self.create_embedding_fn()
56 |
57 | def create_embedding_fn(self):
58 |
59 | embed_fns = []
60 | d = self.kwargs['input_dims']
61 | out_dim = 0
62 | if self.kwargs['include_input']:
63 | embed_fns.append(lambda x: x)
64 | out_dim += d
65 |
66 | max_freq = self.kwargs['max_freq_log2']
67 | N_freqs = self.kwargs['num_freqs']
68 |
69 | if self.kwargs['log_sampling']:
70 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
71 | else:
72 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
73 |
74 | for freq in freq_bands:
75 | for p_fn in self.kwargs['periodic_fns']:
76 | embed_fns.append(lambda x, p_fn=p_fn,
77 | freq=freq : p_fn(x * freq))
78 | out_dim += d
79 |
80 | self.embed_fns = embed_fns
81 | self.out_dim = out_dim
82 |
83 | def embed(self, inputs):
84 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
85 |
86 |
87 | def get_embedder(multires, i=0, input_dims=3):
88 |
89 | if i == -1:
90 | return nn.Identity(), 3
91 |
92 | embed_kwargs = {
93 | 'include_input': True,
94 | 'input_dims': input_dims,
95 | 'max_freq_log2': multires-1,
96 | 'num_freqs': multires,
97 | 'log_sampling': True,
98 | 'periodic_fns': [torch.sin, torch.cos],
99 | }
100 |
101 | embedder_obj = Embedder(**embed_kwargs)
102 | def embed(x, eo=embedder_obj): return eo.embed(x)
103 | return embed, embedder_obj.out_dim
104 |
105 |
106 | # Dynamic NeRF model architecture
107 | class NeRF_d(nn.Module):
108 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirsDyn=True):
109 |         """Dynamic NeRF MLP: predicts RGB, density, forward/backward 3D scene
110 |         flow, and a blending weight for each embedded (x, y, z, t) input."""
111 | super(NeRF_d, self).__init__()
112 | self.D = D
113 | self.W = W
114 | self.input_ch = input_ch
115 | self.input_ch_views = input_ch_views
116 | self.skips = skips
117 | self.use_viewdirsDyn = use_viewdirsDyn
118 |
119 | self.pts_linears = nn.ModuleList(
120 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
121 |
122 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
123 |
124 | if self.use_viewdirsDyn:
125 | self.feature_linear = nn.Linear(W, W)
126 | self.alpha_linear = nn.Linear(W, 1)
127 | self.rgb_linear = nn.Linear(W//2, 3)
128 | else:
129 | self.output_linear = nn.Linear(W, output_ch)
130 |
131 | self.sf_linear = nn.Linear(W, 6)
132 | self.weight_linear = nn.Linear(W, 1)
133 |
134 | def forward(self, x):
135 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
136 | h = input_pts
137 | for i, l in enumerate(self.pts_linears):
138 | h = self.pts_linears[i](h)
139 | h = F.relu(h)
140 | if i in self.skips:
141 | h = torch.cat([input_pts, h], -1)
142 |
143 | # Scene flow should be unbounded. However, in NDC space the coordinate is
144 | # bounded in [-1, 1].
145 | sf = torch.tanh(self.sf_linear(h))
146 | blending = torch.sigmoid(self.weight_linear(h))
147 |
148 | if self.use_viewdirsDyn:
149 | alpha = self.alpha_linear(h)
150 | feature = self.feature_linear(h)
151 | h = torch.cat([feature, input_views], -1)
152 |
153 | for i, l in enumerate(self.views_linears):
154 | h = self.views_linears[i](h)
155 | h = F.relu(h)
156 |
157 | rgb = self.rgb_linear(h)
158 | outputs = torch.cat([rgb, alpha], -1)
159 | else:
160 | outputs = self.output_linear(h)
161 |
162 | return torch.cat([outputs, sf, blending], dim=-1)
163 |
164 |
165 | # Static NeRF model architecture
166 | class NeRF_s(nn.Module):
167 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=True):
168 | """
169 | """
170 | super(NeRF_s, self).__init__()
171 | self.D = D
172 | self.W = W
173 | self.input_ch = input_ch
174 | self.input_ch_views = input_ch_views
175 | self.skips = skips
176 | self.use_viewdirs = use_viewdirs
177 |
178 | self.pts_linears = nn.ModuleList(
179 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
180 |
181 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
182 |
183 | if self.use_viewdirs:
184 | self.feature_linear = nn.Linear(W, W)
185 | self.alpha_linear = nn.Linear(W, 1)
186 | self.rgb_linear = nn.Linear(W//2, 3)
187 | else:
188 | self.output_linear = nn.Linear(W, output_ch)
189 |
190 | self.weight_linear = nn.Linear(W, 1)
191 |
192 | def forward(self, x):
193 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
194 | h = input_pts
195 | for i, l in enumerate(self.pts_linears):
196 | h = self.pts_linears[i](h)
197 | h = F.relu(h)
198 | if i in self.skips:
199 | h = torch.cat([input_pts, h], -1)
200 |
201 | blending = torch.sigmoid(self.weight_linear(h))
202 | if self.use_viewdirs:
203 | alpha = self.alpha_linear(h)
204 | feature = self.feature_linear(h)
205 | h = torch.cat([feature, input_views], -1)
206 |
207 | for i, l in enumerate(self.views_linears):
208 | h = self.views_linears[i](h)
209 | h = F.relu(h)
210 |
211 | rgb = self.rgb_linear(h)
212 | outputs = torch.cat([rgb, alpha], -1)
213 | else:
214 | outputs = self.output_linear(h)
215 |
216 | return torch.cat([outputs, blending], -1)
217 |
218 |
219 | def batchify(fn, chunk):
220 | """Constructs a version of 'fn' that applies to smaller batches.
221 | """
222 | if chunk is None:
223 | return fn
224 |
225 | def ret(inputs):
226 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
227 | return ret
228 |
229 |
230 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
231 | """Prepares inputs and applies network 'fn'.
232 | """
233 |
234 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
235 |
236 | embedded = embed_fn(inputs_flat)
237 | if viewdirs is not None:
238 | input_dirs = viewdirs[:, None].expand(inputs[:, :, :3].shape)
239 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
240 | embedded_dirs = embeddirs_fn(input_dirs_flat)
241 | embedded = torch.cat([embedded, embedded_dirs], -1)
242 |
243 | outputs_flat = batchify(fn, netchunk)(embedded)
244 | outputs = torch.reshape(outputs_flat, list(
245 | inputs.shape[:-1]) + [outputs_flat.shape[-1]])
246 | return outputs
247 |
248 |
249 | def create_nerf(args):
250 | """Instantiate NeRF's MLP model.
251 | """
252 |
253 | embed_fn_d, input_ch_d = get_embedder(args.multires, args.i_embed, 4)
254 | # 10 * 2 * 4 + 4 = 84
255 | # L * (sin, cos) * (x, y, z, t) + (x, y, z, t)
256 |
257 | input_ch_views = 0
258 | embeddirs_fn = None
259 | if args.use_viewdirs:
260 | embeddirs_fn, input_ch_views = get_embedder(
261 | args.multires_views, args.i_embed, 3)
262 | # 4 * 2 * 3 + 3 = 27
263 | # L * (sin, cos) * (3 Cartesian viewing direction unit vector from [theta, phi]) + (3 Cartesian viewing direction unit vector from [theta, phi])
264 | output_ch = 5 if args.N_importance > 0 else 4
265 | skips = [4]
266 | model_d = NeRF_d(D=args.netdepth, W=args.netwidth,
267 | input_ch=input_ch_d, output_ch=output_ch, skips=skips,
268 | input_ch_views=input_ch_views,
269 | use_viewdirsDyn=args.use_viewdirsDyn).to(device)
270 |
271 | device_ids = list(range(torch.cuda.device_count()))
272 | model_d = torch.nn.DataParallel(model_d, device_ids=device_ids)
273 | grad_vars = list(model_d.parameters())
274 |
275 | embed_fn_s, input_ch_s = get_embedder(args.multires, args.i_embed, 3)
276 | # 10 * 2 * 3 + 3 = 63
277 | # L * (sin, cos) * (x, y, z) + (x, y, z)
278 |
279 | model_s = NeRF_s(D=args.netdepth, W=args.netwidth,
280 | input_ch=input_ch_s, output_ch=output_ch, skips=skips,
281 | input_ch_views=input_ch_views,
282 | use_viewdirs=args.use_viewdirs).to(device)
283 |
284 | model_s = torch.nn.DataParallel(model_s, device_ids=device_ids)
285 | grad_vars += list(model_s.parameters())
286 |
287 | model_fine = None
288 | if args.N_importance > 0:
289 | raise NotImplementedError
290 |
291 | def network_query_fn_d(inputs, viewdirs, network_fn): return run_network(
292 | inputs, viewdirs, network_fn,
293 | embed_fn=embed_fn_d,
294 | embeddirs_fn=embeddirs_fn,
295 | netchunk=args.netchunk)
296 |
297 | def network_query_fn_s(inputs, viewdirs, network_fn): return run_network(
298 | inputs, viewdirs, network_fn,
299 | embed_fn=embed_fn_s,
300 | embeddirs_fn=embeddirs_fn,
301 | netchunk=args.netchunk)
302 |
303 | render_kwargs_train = {
304 | 'network_query_fn_d': network_query_fn_d,
305 | 'network_query_fn_s': network_query_fn_s,
306 | 'network_fn_d': model_d,
307 | 'network_fn_s': model_s,
308 | 'perturb': args.perturb,
309 | 'N_importance': args.N_importance,
310 | 'N_samples': args.N_samples,
311 | 'use_viewdirs': args.use_viewdirs,
312 | 'raw_noise_std': args.raw_noise_std,
313 | 'inference': False,
314 | 'DyNeRF_blending': args.DyNeRF_blending,
315 | }
316 |
317 | # NDC only good for LLFF-style forward facing data
318 | if args.dataset_type != 'llff' or args.no_ndc:
319 | print('Not ndc!')
320 | render_kwargs_train['ndc'] = False
321 | render_kwargs_train['lindisp'] = args.lindisp
322 | else:
323 | render_kwargs_train['ndc'] = True
324 |
325 | render_kwargs_test = {
326 | k: render_kwargs_train[k] for k in render_kwargs_train}
327 | render_kwargs_test['perturb'] = False
328 | render_kwargs_test['raw_noise_std'] = 0.
329 | render_kwargs_test['inference'] = True
330 |
331 | # Create optimizer
332 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
333 |
334 | start = 0
335 | basedir = args.basedir
336 | expname = args.expname
337 |
338 | if args.ft_path is not None and args.ft_path != 'None':
339 | ckpts = [args.ft_path]
340 | else:
341 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
342 | print('Found ckpts', ckpts)
343 | if len(ckpts) > 0 and not args.no_reload:
344 | ckpt_path = ckpts[-1]
345 | print('Reloading from', ckpt_path)
346 | ckpt = torch.load(ckpt_path)
347 |
348 | start = ckpt['global_step'] + 1
349 | # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
350 | model_d.load_state_dict(ckpt['network_fn_d_state_dict'])
351 | model_s.load_state_dict(ckpt['network_fn_s_state_dict'])
352 | print('Resetting step to', start)
353 |
354 | if model_fine is not None:
355 | raise NotImplementedError
356 |
357 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
358 |
359 |
360 | # Ray helpers
361 | def get_rays(H, W, focal, c2w):
362 | """Get ray origins, directions from a pinhole camera."""
363 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
364 | i = i.t()
365 | j = j.t()
366 | dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
367 | # Rotate ray directions from camera frame to the world frame
368 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
369 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
370 | rays_o = c2w[:3, -1].expand(rays_d.shape)
371 | return rays_o, rays_d
372 |
373 |
374 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
375 | """Normalized device coordinate rays.
376 | Space such that the canvas is a cube with sides [-1, 1] in each axis.
377 | Args:
378 | H: int. Height in pixels.
379 | W: int. Width in pixels.
380 | focal: float. Focal length of pinhole camera.
381 | near: float or array of shape[batch_size]. Near depth bound for the scene.
382 | rays_o: array of shape [batch_size, 3]. Camera origin.
383 | rays_d: array of shape [batch_size, 3]. Ray direction.
384 | Returns:
385 | rays_o: array of shape [batch_size, 3]. Camera origin in NDC.
386 | rays_d: array of shape [batch_size, 3]. Ray direction in NDC.
387 | """
388 | # Shift ray origins to near plane
389 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
390 | rays_o = rays_o + t[..., None] * rays_d
391 |
392 | # Projection
393 | o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]
394 | o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]
395 | o2 = 1. + 2. * near / rays_o[..., 2]
396 |
397 | d0 = -1./(W/(2.*focal)) * \
398 | (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
399 | d1 = -1./(H/(2.*focal)) * \
400 | (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
401 | d2 = -2. * near / rays_o[..., 2]
402 |
403 | rays_o = torch.stack([o0, o1, o2], -1)
404 | rays_d = torch.stack([d0, d1, d2], -1)
405 |
406 | return rays_o, rays_d
407 |
408 |
409 | def get_grid(H, W, num_img, flows_f, flow_masks_f, flows_b, flow_masks_b):
410 |
411 | # |--------------------| |--------------------|
412 | # | j | | v |
413 | # | i * | | u * |
414 | # | | | |
415 | # |--------------------| |--------------------|
416 |
417 | i, j = np.meshgrid(np.arange(W, dtype=np.float32),
418 | np.arange(H, dtype=np.float32), indexing='xy')
419 |
420 | grid = np.empty((0, H, W, 8), np.float32)
421 | for idx in range(num_img):
422 | grid = np.concatenate((grid, np.stack([i,
423 | j,
424 | flows_f[idx, :, :, 0],
425 | flows_f[idx, :, :, 1],
426 | flow_masks_f[idx, :, :],
427 | flows_b[idx, :, :, 0],
428 | flows_b[idx, :, :, 1],
429 | flow_masks_b[idx, :, :]], -1)[None, ...]))
430 | return grid
431 |
432 |
433 | def NDC2world(pts, H, W, f):
434 |
435 | # NDC coordinate to world coordinate
436 | pts_z = 2 / (torch.clamp(pts[..., 2:], min=-1., max=1-1e-3) - 1)
437 | pts_x = - pts[..., 0:1] * pts_z * W / 2 / f
438 | pts_y = - pts[..., 1:2] * pts_z * H / 2 / f
439 | pts_world = torch.cat([pts_x, pts_y, pts_z], -1)
440 |
441 | return pts_world
442 |
443 |
444 | def render_3d_point(H, W, f, pose, weights, pts):
445 | """Render 3D position along each ray and project it to the image plane.
446 | """
447 |
448 | c2w = pose
449 | w2c = c2w[:3, :3].transpose(0, 1) # same as np.linalg.inv(c2w[:3, :3])
450 |
451 | # Rendered 3D position in NDC coordinate
452 | pts_map_NDC = torch.sum(weights[..., None] * pts, -2)
453 |
454 | # NDC coordinate to world coordinate
455 | pts_map_world = NDC2world(pts_map_NDC, H, W, f)
456 |
457 | # World coordinate to camera coordinate
458 | # Translate
459 | pts_map_world = pts_map_world - c2w[:, 3]
460 | # Rotate
461 | pts_map_cam = torch.sum(pts_map_world[..., None, :] * w2c[:3, :3], -1)
462 |
463 | # Camera coordinate to 2D image coordinate
464 | pts_plane = torch.cat([pts_map_cam[..., 0:1] / (- pts_map_cam[..., 2:]) * f + W * .5,
465 | - pts_map_cam[..., 1:2] / (- pts_map_cam[..., 2:]) * f + H * .5],
466 | -1)
467 |
468 | return pts_plane
469 |
470 |
471 | def induce_flow(H, W, focal, pose_neighbor, weights, pts_3d_neighbor, pts_2d):
472 |
473 | # Render 3D position along each ray and project it to the neighbor frame's image plane.
474 | pts_2d_neighbor = render_3d_point(H, W, focal,
475 | pose_neighbor,
476 | weights,
477 | pts_3d_neighbor)
478 | induced_flow = pts_2d_neighbor - pts_2d
479 |
480 | return induced_flow
481 |
482 |
483 | def compute_depth_loss(dyn_depth, gt_depth):
484 |
485 | t_d = torch.median(dyn_depth)
486 | s_d = torch.mean(torch.abs(dyn_depth - t_d))
487 | dyn_depth_norm = (dyn_depth - t_d) / s_d
488 |
489 | t_gt = torch.median(gt_depth)
490 | s_gt = torch.mean(torch.abs(gt_depth - t_gt))
491 | gt_depth_norm = (gt_depth - t_gt) / s_gt
492 |
493 | return torch.mean((dyn_depth_norm - gt_depth_norm) ** 2)
494 |
495 |
496 | def normalize_depth(depth):
497 | return torch.clamp(depth / percentile(depth, 97), 0., 1.)
498 |
499 |
500 | def percentile(t, q):
501 | """
502 | Return the ``q``-th percentile of the flattened input tensor's data.
503 |
504 | CAUTION:
505 | * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
506 | * Values are not interpolated, which corresponds to
507 | ``numpy.percentile(..., interpolation="nearest")``.
508 |
509 | :param t: Input tensor.
510 | :param q: Percentile to compute, which must be between 0 and 100 inclusive.
511 | :return: Resulting value (scalar).
512 | """
513 |
514 | k = 1 + round(.01 * float(q) * (t.numel() - 1))
515 | result = t.view(-1).kthvalue(k).values.item()
516 | return result
517 |
518 |
519 | def save_res(moviebase, ret, fps=None):
520 |
521 |     if fps is None:
522 | if len(ret['rgbs']) < 25:
523 | fps = 4
524 | else:
525 | fps = 24
526 |
527 | for k in ret:
528 | if 'rgbs' in k:
529 | imageio.mimwrite(moviebase + k + '.mp4',
530 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
531 | # imageio.mimsave(moviebase + k + '.gif',
532 | # to8b(ret[k]), format='gif', fps=fps)
533 | elif 'depths' in k:
534 | imageio.mimwrite(moviebase + k + '.mp4',
535 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
536 | # imageio.mimsave(moviebase + k + '.gif',
537 | # to8b(ret[k]), format='gif', fps=fps)
538 | elif 'disps' in k:
539 | imageio.mimwrite(moviebase + k + '.mp4',
540 | to8b(ret[k] / np.max(ret[k])), fps=fps, quality=8, macro_block_size=1)
541 | # imageio.mimsave(moviebase + k + '.gif',
542 | # to8b(ret[k] / np.max(ret[k])), format='gif', fps=fps)
543 | elif 'sceneflow_' in k:
544 | imageio.mimwrite(moviebase + k + '.mp4',
545 | to8b(norm_sf(ret[k])), fps=fps, quality=8, macro_block_size=1)
546 | # imageio.mimsave(moviebase + k + '.gif',
547 | # to8b(norm_sf(ret[k])), format='gif', fps=fps)
548 | elif 'flows' in k:
549 | imageio.mimwrite(moviebase + k + '.mp4',
550 | ret[k], fps=fps, quality=8, macro_block_size=1)
551 | # imageio.mimsave(moviebase + k + '.gif',
552 | # ret[k], format='gif', fps=fps)
553 | elif 'dynamicness' in k:
554 | imageio.mimwrite(moviebase + k + '.mp4',
555 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
556 | # imageio.mimsave(moviebase + k + '.gif',
557 | # to8b(ret[k]), format='gif', fps=fps)
558 | elif 'disocclusions' in k:
559 | imageio.mimwrite(moviebase + k + '.mp4',
560 | to8b(ret[k][..., 0]), fps=fps, quality=8, macro_block_size=1)
561 | # imageio.mimsave(moviebase + k + '.gif',
562 | # to8b(ret[k][..., 0]), format='gif', fps=fps)
563 | elif 'blending' in k:
564 | blending = ret[k][..., None]
565 | blending = np.moveaxis(blending, [0, 1, 2, 3], [1, 2, 0, 3])
566 | imageio.mimwrite(moviebase + k + '.mp4',
567 | to8b(blending), fps=fps, quality=8, macro_block_size=1)
568 | # imageio.mimsave(moviebase + k + '.gif',
569 | # to8b(blending), format='gif', fps=fps)
570 | elif 'weights' in k:
571 | imageio.mimwrite(moviebase + k + '.mp4',
572 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
573 | else:
574 | raise NotImplementedError
575 |
576 |
577 | def norm_sf_channel(sf_ch):
578 |
579 | # Make sure zero scene flow is not shifted
580 | sf_ch[sf_ch >= 0] = sf_ch[sf_ch >= 0] / sf_ch.max() / 2
581 | sf_ch[sf_ch < 0] = sf_ch[sf_ch < 0] / np.abs(sf_ch.min()) / 2
582 | sf_ch = sf_ch + 0.5
583 | return sf_ch
584 |
585 |
586 | def norm_sf(sf):
587 |
588 | sf = np.concatenate((norm_sf_channel(sf[..., 0:1]),
589 | norm_sf_channel(sf[..., 1:2]),
590 | norm_sf_channel(sf[..., 2:3])), -1)
591 | sf = np.moveaxis(sf, [0, 1, 2, 3], [1, 2, 0, 3])
592 | return sf
593 |
594 |
595 | # Spatial smoothness (adapted from NSFF)
596 | def compute_sf_smooth_s_loss(pts1, pts2, H, W, f):
597 |
598 | N_samples = pts1.shape[1]
599 |
600 | # NDC coordinate to world coordinate
601 | pts1_world = NDC2world(pts1[..., :int(N_samples * 0.95), :], H, W, f)
602 | pts2_world = NDC2world(pts2[..., :int(N_samples * 0.95), :], H, W, f)
603 |
604 | # scene flow in world coordinate
605 | scene_flow_world = pts1_world - pts2_world
606 |
607 | return L1(scene_flow_world[..., :-1, :] - scene_flow_world[..., 1:, :])
608 |
609 |
610 | # Temporal smoothness
611 | def compute_sf_smooth_loss(pts, pts_f, pts_b, H, W, f):
612 |
613 | N_samples = pts.shape[1]
614 |
615 | pts_world = NDC2world(pts[..., :int(N_samples * 0.9), :], H, W, f)
616 | pts_f_world = NDC2world(pts_f[..., :int(N_samples * 0.9), :], H, W, f)
617 | pts_b_world = NDC2world(pts_b[..., :int(N_samples * 0.9), :], H, W, f)
618 |
619 | # scene flow in world coordinate
620 | sceneflow_f = pts_f_world - pts_world
621 | sceneflow_b = pts_b_world - pts_world
622 |
623 | # For a 3D point, its forward and backward sceneflow should be opposite.
624 | return L2(sceneflow_f + sceneflow_b)
625 |
--------------------------------------------------------------------------------
/utils/RAFT/__init__.py:
--------------------------------------------------------------------------------
1 | # from .demo import RAFT_infer
2 | from .raft import RAFT
3 |
--------------------------------------------------------------------------------
/utils/RAFT/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class CorrLayer(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, fmap1, fmap2, coords, r):
66 | fmap1 = fmap1.contiguous()
67 | fmap2 = fmap2.contiguous()
68 | coords = coords.contiguous()
69 | ctx.save_for_backward(fmap1, fmap2, coords)
70 | ctx.r = r
71 |         corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, ctx.r)
72 | return corr
73 |
74 | @staticmethod
75 | def backward(ctx, grad_corr):
76 | fmap1, fmap2, coords = ctx.saved_tensors
77 | grad_corr = grad_corr.contiguous()
78 | fmap1_grad, fmap2_grad, coords_grad = \
79 |             alt_cuda_corr.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
80 | return fmap1_grad, fmap2_grad, coords_grad, None
81 |
82 |
83 | class AlternateCorrBlock:
84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
85 | self.num_levels = num_levels
86 | self.radius = radius
87 |
88 | self.pyramid = [(fmap1, fmap2)]
89 | for i in range(self.num_levels):
90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
92 | self.pyramid.append((fmap1, fmap2))
93 |
94 | def __call__(self, coords):
95 |
96 | coords = coords.permute(0, 2, 3, 1)
97 | B, H, W, _ = coords.shape
98 |
99 | corr_list = []
100 | for i in range(self.num_levels):
101 | r = self.radius
102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
104 |
105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
106 |             corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
107 | corr_list.append(corr.squeeze(1))
108 |
109 | corr = torch.stack(corr_list, dim=1)
110 | corr = corr.reshape(B, -1, H, W)
111 | return corr / 16.0
112 |
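A hedged usage sketch for `CorrBlock` (shapes are illustrative; assumes the repo root is on `PYTHONPATH`): querying the 4-level correlation pyramid at zero flow returns `num_levels * (2*radius+1)**2` channels per pixel.
```
import torch
from utils.RAFT.corr import CorrBlock

B, D, H, W = 1, 256, 46, 62                       # 1/8-resolution feature maps
fmap1, fmap2 = torch.randn(B, D, H, W), torch.randn(B, D, H, W)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

# Identity coordinates (zero flow), channel order (x, y) as in coords_grid.
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W))
coords = torch.stack([xs, ys], dim=0).float()[None]   # (B, 2, H, W)

feat = corr_fn(coords)
print(feat.shape)   # torch.Size([1, 324, 46, 62]) = 4 levels * 9 * 9 lookups
```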
--------------------------------------------------------------------------------
/utils/RAFT/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 |     """ Create the data loader for the corresponding training set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
/utils/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os
4 | import cv2
5 | import glob
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 | from .raft import RAFT
11 | from .utils import flow_viz
12 | from .utils.utils import InputPadder
13 |
14 |
15 |
16 | DEVICE = 'cuda'
17 |
18 | def load_image(imfile):
19 | img = np.array(Image.open(imfile)).astype(np.uint8)
20 | img = torch.from_numpy(img).permute(2, 0, 1).float()
21 | return img
22 |
23 |
24 | def load_image_list(image_files):
25 | images = []
26 | for imfile in sorted(image_files):
27 | images.append(load_image(imfile))
28 |
29 | images = torch.stack(images, dim=0)
30 | images = images.to(DEVICE)
31 |
32 | padder = InputPadder(images.shape)
33 | return padder.pad(images)[0]
34 |
35 |
36 | def viz(img, flo):
37 | img = img[0].permute(1,2,0).cpu().numpy()
38 | flo = flo[0].permute(1,2,0).cpu().numpy()
39 |
40 | # map flow to rgb image
41 | flo = flow_viz.flow_to_image(flo)
42 | # img_flo = np.concatenate([img, flo], axis=0)
43 | img_flo = flo
44 |
45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
47 | # cv2.waitKey()
48 |
49 |
50 | def demo(args):
51 | model = torch.nn.DataParallel(RAFT(args))
52 | model.load_state_dict(torch.load(args.model))
53 |
54 | model = model.module
55 | model.to(DEVICE)
56 | model.eval()
57 |
58 | with torch.no_grad():
59 | images = glob.glob(os.path.join(args.path, '*.png')) + \
60 | glob.glob(os.path.join(args.path, '*.jpg'))
61 |
62 | images = load_image_list(images)
63 | for i in range(images.shape[0]-1):
64 | image1 = images[i,None]
65 | image2 = images[i+1,None]
66 |
67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
68 | viz(image1, flow_up)
69 |
70 |
71 | def RAFT_infer(args):
72 | model = torch.nn.DataParallel(RAFT(args))
73 | model.load_state_dict(torch.load(args.model))
74 |
75 | model = model.module
76 | model.to(DEVICE)
77 | model.eval()
78 |
79 | return model
80 |
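A hedged inference sketch built on `RAFT_infer` (requires a GPU; the checkpoint and frame paths are placeholders, not files shipped with this repo):
```
import argparse
import torch
from utils.RAFT.demo import RAFT_infer, load_image
from utils.RAFT.utils.utils import InputPadder

# Placeholder paths; download RAFT weights separately.
args = argparse.Namespace(small=False, mixed_precision=False,
                          model='path/to/raft-things.pth')
model = RAFT_infer(args)                       # weights loaded, eval mode, on CUDA

img1 = load_image('frame_000.png')[None].cuda()
img2 = load_image('frame_001.png')[None].cuda()
padder = InputPadder(img1.shape)               # pad so H and W are divisible by 8
img1, img2 = padder.pad(img1, img2)

with torch.no_grad():
    _, flow_up = model(img1, img2, iters=20, test_mode=True)
flow = padder.unpad(flow_up)                   # (1, 2, H, W) flow in pixels
```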
--------------------------------------------------------------------------------
/utils/RAFT/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
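A quick shape check for the feature encoder (a sketch; input dimensions are arbitrary): three stride-2 stages reduce the image to 1/8 resolution, which is the grid the correlation volume and flow field live on.
```
import torch
from utils.RAFT.extractor import BasicEncoder

enc = BasicEncoder(output_dim=256, norm_fn='instance')
x = torch.randn(1, 3, 376, 1240)       # arbitrary image-sized input
print(enc(x).shape)                    # torch.Size([1, 256, 47, 155]) -- 1/8 resolution
```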
--------------------------------------------------------------------------------
/utils/RAFT/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 |         if not hasattr(args, 'dropout'):
42 | args.dropout = 0
43 |
44 |         if not hasattr(args, 'alternate_corr'):
45 | args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 |
59 | def freeze_bn(self):
60 | for m in self.modules():
61 | if isinstance(m, nn.BatchNorm2d):
62 | m.eval()
63 |
64 | def initialize_flow(self, img):
65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
66 | N, C, H, W = img.shape
67 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
68 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
69 |
70 | # optical flow computed as difference: flow = coords1 - coords0
71 | return coords0, coords1
72 |
73 | def upsample_flow(self, flow, mask):
74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
75 | N, _, H, W = flow.shape
76 | mask = mask.view(N, 1, 9, 8, 8, H, W)
77 | mask = torch.softmax(mask, dim=2)
78 |
79 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
81 |
82 | up_flow = torch.sum(mask * up_flow, dim=2)
83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
84 | return up_flow.reshape(N, 2, 8*H, 8*W)
85 |
86 |
87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
88 | """ Estimate optical flow between pair of frames """
89 |
90 | image1 = 2 * (image1 / 255.0) - 1.0
91 | image2 = 2 * (image2 / 255.0) - 1.0
92 |
93 | image1 = image1.contiguous()
94 | image2 = image2.contiguous()
95 |
96 | hdim = self.hidden_dim
97 | cdim = self.context_dim
98 |
99 | # run the feature network
100 | with autocast(enabled=self.args.mixed_precision):
101 | fmap1, fmap2 = self.fnet([image1, image2])
102 |
103 | fmap1 = fmap1.float()
104 | fmap2 = fmap2.float()
105 | if self.args.alternate_corr:
106 |             corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
107 | else:
108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
109 |
110 | # run the context network
111 | with autocast(enabled=self.args.mixed_precision):
112 | cnet = self.cnet(image1)
113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
114 | net = torch.tanh(net)
115 | inp = torch.relu(inp)
116 |
117 | coords0, coords1 = self.initialize_flow(image1)
118 |
119 | if flow_init is not None:
120 | coords1 = coords1 + flow_init
121 |
122 | flow_predictions = []
123 | for itr in range(iters):
124 | coords1 = coords1.detach()
125 | corr = corr_fn(coords1) # index correlation volume
126 |
127 | flow = coords1 - coords0
128 | with autocast(enabled=self.args.mixed_precision):
129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
130 |
131 | # F(t+1) = F(t) + \Delta(t)
132 | coords1 = coords1 + delta_flow
133 |
134 | # upsample predictions
135 | if up_mask is None:
136 | flow_up = upflow8(coords1 - coords0)
137 | else:
138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
139 |
140 | flow_predictions.append(flow_up)
141 |
142 | if test_mode:
143 | return coords1 - coords0, flow_up
144 |
145 | return flow_predictions
146 |
--------------------------------------------------------------------------------
/utils/RAFT/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
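A shape sketch for the full-size update operator (dummy zero tensors, default `corr_levels=4`, `corr_radius=4`): one GRU step returns the new hidden state, a 64*9-channel convex-upsampling mask, and a flow increment.
```
import argparse
import torch
from utils.RAFT.update import BasicUpdateBlock

args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

N, H, W = 1, 46, 62
net  = torch.zeros(N, 128, H, W)            # hidden state
inp  = torch.zeros(N, 128, H, W)            # context features
corr = torch.zeros(N, 4 * 9 * 9, H, W)      # sampled correlation volume
flow = torch.zeros(N, 2, H, W)              # current flow estimate

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)      # (1, 576, 46, 62) (1, 2, 46, 62)
```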
--------------------------------------------------------------------------------
/utils/RAFT/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_viz import flow_to_image
2 | from .frame_utils import writeFlow
3 |
--------------------------------------------------------------------------------
/utils/RAFT/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
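A hedged usage sketch for the dense-flow augmentor (random arrays stand in for real frames): photometric jitter, random erasing, scaling/flipping, and a final crop to `crop_size`.
```
import numpy as np
from utils.RAFT.utils.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 496), min_scale=-0.2, max_scale=0.5)
img1 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)
flow = np.random.randn(436, 1024, 2).astype(np.float32)

img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, flow.shape)    # (368, 496, 3) (368, 496, 2)
```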
--------------------------------------------------------------------------------
/utils/RAFT/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional (u, v) flow field of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
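A hedged example of the color-wheel visualization on a synthetic flow field that points away from the image center:
```
import numpy as np
from utils.RAFT.utils.flow_viz import flow_to_image

H, W = 128, 192
ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
flow = np.stack([xs - W / 2, ys - H / 2], axis=-1).astype(np.float32)
rgb = flow_to_image(flow)
print(rgb.shape, rgb.dtype)      # (128, 192, 3) uint8
```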
--------------------------------------------------------------------------------
/utils/RAFT/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
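A round-trip sketch for the Middlebury `.flo` helpers (writes a temporary file; exact recovery holds on little-endian machines):
```
import numpy as np
from utils.RAFT.utils.frame_utils import writeFlow, readFlow

flow = np.random.randn(120, 160, 2).astype(np.float32)
writeFlow('/tmp/example.flo', flow)
print(np.allclose(flow, readFlow('/tmp/example.flo')))   # True
```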
--------------------------------------------------------------------------------
/utils/RAFT/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
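A shape sketch for the coordinate-grid and flow-upsampling helpers used by RAFT's iterative updates:
```
import torch
from utils.RAFT.utils.utils import coords_grid, upflow8

coords = coords_grid(2, 46, 62)                 # (2, 2, 46, 62), channels ordered (x, y)
print(coords[:, 0].max(), coords[:, 1].max())   # tensor(61.) tensor(45.)

flow_small = torch.randn(2, 2, 46, 62)
print(upflow8(flow_small).shape)                # torch.Size([2, 2, 368, 496])
```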
--------------------------------------------------------------------------------
/utils/colmap_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
31 |
32 | import os
33 | import sys
34 | import collections
35 | import numpy as np
36 | import struct
37 |
38 |
39 | CameraModel = collections.namedtuple(
40 | "CameraModel", ["model_id", "model_name", "num_params"])
41 | Camera = collections.namedtuple(
42 | "Camera", ["id", "model", "width", "height", "params"])
43 | BaseImage = collections.namedtuple(
44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
45 | Point3D = collections.namedtuple(
46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
47 |
48 | class Image(BaseImage):
49 | def qvec2rotmat(self):
50 | return qvec2rotmat(self.qvec)
51 |
52 |
53 | CAMERA_MODELS = {
54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
61 | CameraModel(model_id=7, model_name="FOV", num_params=5),
62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
65 | }
66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
67 | for camera_model in CAMERA_MODELS])
68 |
69 |
70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
71 | """Read and unpack the next bytes from a binary file.
72 | :param fid:
73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
75 | :param endian_character: Any of {@, =, <, >, !}
76 | :return: Tuple of read and unpacked values.
77 | """
78 | data = fid.read(num_bytes)
79 | return struct.unpack(endian_character + format_char_sequence, data)
80 |
81 |
82 | def read_cameras_text(path):
83 | """
84 | see: src/base/reconstruction.cc
85 | void Reconstruction::WriteCamerasText(const std::string& path)
86 | void Reconstruction::ReadCamerasText(const std::string& path)
87 | """
88 | cameras = {}
89 | with open(path, "r") as fid:
90 | while True:
91 | line = fid.readline()
92 | if not line:
93 | break
94 | line = line.strip()
95 | if len(line) > 0 and line[0] != "#":
96 | elems = line.split()
97 | camera_id = int(elems[0])
98 | model = elems[1]
99 | width = int(elems[2])
100 | height = int(elems[3])
101 | params = np.array(tuple(map(float, elems[4:])))
102 | cameras[camera_id] = Camera(id=camera_id, model=model,
103 | width=width, height=height,
104 | params=params)
105 | return cameras
106 |
107 |
108 | def read_cameras_binary(path_to_model_file):
109 | """
110 | see: src/base/reconstruction.cc
111 | void Reconstruction::WriteCamerasBinary(const std::string& path)
112 | void Reconstruction::ReadCamerasBinary(const std::string& path)
113 | """
114 | cameras = {}
115 | with open(path_to_model_file, "rb") as fid:
116 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
117 | for camera_line_index in range(num_cameras):
118 | camera_properties = read_next_bytes(
119 | fid, num_bytes=24, format_char_sequence="iiQQ")
120 | camera_id = camera_properties[0]
121 | model_id = camera_properties[1]
122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
123 | width = camera_properties[2]
124 | height = camera_properties[3]
125 | num_params = CAMERA_MODEL_IDS[model_id].num_params
126 | params = read_next_bytes(fid, num_bytes=8*num_params,
127 | format_char_sequence="d"*num_params)
128 | cameras[camera_id] = Camera(id=camera_id,
129 | model=model_name,
130 | width=width,
131 | height=height,
132 | params=np.array(params))
133 | assert len(cameras) == num_cameras
134 | return cameras
135 |
136 |
137 | def read_images_text(path):
138 | """
139 | see: src/base/reconstruction.cc
140 | void Reconstruction::ReadImagesText(const std::string& path)
141 | void Reconstruction::WriteImagesText(const std::string& path)
142 | """
143 | images = {}
144 | with open(path, "r") as fid:
145 | while True:
146 | line = fid.readline()
147 | if not line:
148 | break
149 | line = line.strip()
150 | if len(line) > 0 and line[0] != "#":
151 | elems = line.split()
152 | image_id = int(elems[0])
153 | qvec = np.array(tuple(map(float, elems[1:5])))
154 | tvec = np.array(tuple(map(float, elems[5:8])))
155 | camera_id = int(elems[8])
156 | image_name = elems[9]
157 | elems = fid.readline().split()
158 | xys = np.column_stack([tuple(map(float, elems[0::3])),
159 | tuple(map(float, elems[1::3]))])
160 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
161 | images[image_id] = Image(
162 | id=image_id, qvec=qvec, tvec=tvec,
163 | camera_id=camera_id, name=image_name,
164 | xys=xys, point3D_ids=point3D_ids)
165 | return images
166 |
167 |
168 | def read_images_binary(path_to_model_file):
169 | """
170 | see: src/base/reconstruction.cc
171 | void Reconstruction::ReadImagesBinary(const std::string& path)
172 | void Reconstruction::WriteImagesBinary(const std::string& path)
173 | """
174 | images = {}
175 | with open(path_to_model_file, "rb") as fid:
176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
177 | for image_index in range(num_reg_images):
178 | binary_image_properties = read_next_bytes(
179 | fid, num_bytes=64, format_char_sequence="idddddddi")
180 | image_id = binary_image_properties[0]
181 | qvec = np.array(binary_image_properties[1:5])
182 | tvec = np.array(binary_image_properties[5:8])
183 | camera_id = binary_image_properties[8]
184 | image_name = ""
185 | current_char = read_next_bytes(fid, 1, "c")[0]
186 | while current_char != b"\x00": # look for the ASCII 0 entry
187 | image_name += current_char.decode("utf-8")
188 | current_char = read_next_bytes(fid, 1, "c")[0]
189 | num_points2D = read_next_bytes(fid, num_bytes=8,
190 | format_char_sequence="Q")[0]
191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
192 | format_char_sequence="ddq"*num_points2D)
193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
194 | tuple(map(float, x_y_id_s[1::3]))])
195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
196 | images[image_id] = Image(
197 | id=image_id, qvec=qvec, tvec=tvec,
198 | camera_id=camera_id, name=image_name,
199 | xys=xys, point3D_ids=point3D_ids)
200 | return images
201 |
202 |
203 | def read_points3D_text(path):
204 | """
205 | see: src/base/reconstruction.cc
206 | void Reconstruction::ReadPoints3DText(const std::string& path)
207 | void Reconstruction::WritePoints3DText(const std::string& path)
208 | """
209 | points3D = {}
210 | with open(path, "r") as fid:
211 | while True:
212 | line = fid.readline()
213 | if not line:
214 | break
215 | line = line.strip()
216 | if len(line) > 0 and line[0] != "#":
217 | elems = line.split()
218 | point3D_id = int(elems[0])
219 | xyz = np.array(tuple(map(float, elems[1:4])))
220 | rgb = np.array(tuple(map(int, elems[4:7])))
221 | error = float(elems[7])
222 | image_ids = np.array(tuple(map(int, elems[8::2])))
223 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
225 | error=error, image_ids=image_ids,
226 | point2D_idxs=point2D_idxs)
227 | return points3D
228 |
229 |
230 | def read_points3d_binary(path_to_model_file):
231 | """
232 | see: src/base/reconstruction.cc
233 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
234 | void Reconstruction::WritePoints3DBinary(const std::string& path)
235 | """
236 | points3D = {}
237 | with open(path_to_model_file, "rb") as fid:
238 | num_points = read_next_bytes(fid, 8, "Q")[0]
239 | for point_line_index in range(num_points):
240 | binary_point_line_properties = read_next_bytes(
241 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
242 | point3D_id = binary_point_line_properties[0]
243 | xyz = np.array(binary_point_line_properties[1:4])
244 | rgb = np.array(binary_point_line_properties[4:7])
245 | error = np.array(binary_point_line_properties[7])
246 | track_length = read_next_bytes(
247 | fid, num_bytes=8, format_char_sequence="Q")[0]
248 | track_elems = read_next_bytes(
249 | fid, num_bytes=8*track_length,
250 | format_char_sequence="ii"*track_length)
251 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
253 | points3D[point3D_id] = Point3D(
254 | id=point3D_id, xyz=xyz, rgb=rgb,
255 | error=error, image_ids=image_ids,
256 | point2D_idxs=point2D_idxs)
257 | return points3D
258 |
259 |
260 | def read_model(path, ext):
261 | if ext == ".txt":
262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
263 | images = read_images_text(os.path.join(path, "images" + ext))
264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
265 | else:
266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
267 | images = read_images_binary(os.path.join(path, "images" + ext))
268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
269 | return cameras, images, points3D
270 |
271 |
272 | def qvec2rotmat(qvec):
273 | return np.array([
274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
283 |
284 |
285 | def rotmat2qvec(R):
286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
287 | K = np.array([
288 | [Rxx - Ryy - Rzz, 0, 0, 0],
289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
292 | eigvals, eigvecs = np.linalg.eigh(K)
293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
294 | if qvec[0] < 0:
295 | qvec *= -1
296 | return qvec
297 |
298 |
299 | def main():
300 | if len(sys.argv) != 3:
301 | print("Usage: python read_model.py path/to/model/folder [.txt,.bin]")
302 | return
303 |
304 | cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2])
305 |
306 | print("num_cameras:", len(cameras))
307 | print("num_images:", len(images))
308 | print("num_points3D:", len(points3D))
309 |
310 |
311 | if __name__ == "__main__":
312 | main()
313 |
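A minimal usage sketch (the scene path is a placeholder) showing how the readers above combine with `qvec2rotmat` to build 4x4 world-to-camera matrices, which is essentially what `generate_pose.py` and `generate_motion_mask.py` do further down:

```
# Illustrative sketch only: load a COLMAP reconstruction and build
# world-to-camera extrinsics per registered image ("path/to/scene" is a placeholder).
import numpy as np
from colmap_utils import read_model, qvec2rotmat

cameras, images, points3D = read_model("path/to/scene/sparse/0", ext=".bin")

w2c = {}
for image_id, im in images.items():
    R = qvec2rotmat(im.qvec)              # 3x3 rotation from the stored quaternion
    t = im.tvec.reshape(3, 1)             # 3x1 translation
    bottom = np.array([[0.0, 0.0, 0.0, 1.0]])
    w2c[im.name] = np.concatenate([np.concatenate([R, t], axis=1), bottom], axis=0)
```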
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import lpips
4 | import torch
5 | import numpy as np
6 | from skimage.metrics import structural_similarity
7 |
8 |
9 | def im2tensor(img):
10 | return torch.Tensor(img.transpose(2, 0, 1) / 127.5 - 1.0)[None, ...]
11 |
12 |
13 | def create_dir(dir):
14 | if not os.path.exists(dir):
15 | os.makedirs(dir)
16 |
17 |
18 | def readimage(data_dir, sequence, time, method):
19 | img = cv2.imread(os.path.join(data_dir, method, sequence, 'v000_t' + str(time).zfill(3) + '.png'))
20 | return img
21 |
22 |
23 | def calculate_metrics(data_dir, sequence, methods, lpips_loss):
24 |
25 | PSNRs = np.zeros((len(methods)))
26 | SSIMs = np.zeros((len(methods)))
27 | LPIPSs = np.zeros((len(methods)))
28 |
29 | nFrame = 0
30 |
31 | # Yoon's results do not include v000_t000 and v000_t011. Omit these two
32 | # frames if evaluating Yoon's method.
33 | if 'Yoon' in methods:
34 | time_start = 1
35 | time_end = 11
36 | else:
37 | time_start = 0
38 | time_end = 12
39 |
40 | for time in range(time_start, time_end): # Fix view v0, change time
41 |
42 | nFrame += 1
43 |
44 | img_true = readimage(data_dir, sequence, time, 'gt')
45 |
46 | for method_idx, method in enumerate(methods):
47 |
48 | if 'Yoon' in methods and sequence == 'Truck' and time == 10:
49 | break
50 |
51 | img = readimage(data_dir, sequence, time, method)
52 | PSNR = cv2.PSNR(img_true, img)
53 | SSIM = structural_similarity(img_true, img, multichannel=True)
54 | LPIPS = lpips_loss.forward(im2tensor(img_true), im2tensor(img)).item()
55 |
56 | PSNRs[method_idx] += PSNR
57 | SSIMs[method_idx] += SSIM
58 | LPIPSs[method_idx] += LPIPS
59 |
60 | PSNRs = PSNRs / nFrame
61 | SSIMs = SSIMs / nFrame
62 | LPIPSs = LPIPSs / nFrame
63 |
64 | return PSNRs, SSIMs, LPIPSs
65 |
66 |
67 | if __name__ == '__main__':
68 |
69 | lpips_loss = lpips.LPIPS(net='alex') # best forward scores
70 | data_dir = '../results'
71 | sequences = ['Balloon1', 'Balloon2', 'Jumping', 'Playground', 'Skating', 'Truck', 'Umbrella']
72 | # methods = ['NeRF', 'NeRF_t', 'Yoon', 'NR', 'NSFF', 'Ours']
73 | methods = ['NeRF', 'NeRF_t', 'NR', 'NSFF', 'Ours']
74 |
75 | PSNRs_total = np.zeros((len(methods)))
76 | SSIMs_total = np.zeros((len(methods)))
77 | LPIPSs_total = np.zeros((len(methods)))
78 | for sequence in sequences:
79 | print(sequence)
80 | PSNRs, SSIMs, LPIPSs = calculate_metrics(data_dir, sequence, methods, lpips_loss)
81 | for method_idx, method in enumerate(methods):
82 | print(method.ljust(7) + '%.2f'%(PSNRs[method_idx]) + ' / %.4f'%(SSIMs[method_idx]) + ' / %.3f'%(LPIPSs[method_idx]))
83 |
84 | PSNRs_total += PSNRs
85 | SSIMs_total += SSIMs
86 | LPIPSs_total += LPIPSs
87 |
88 | PSNRs_total = PSNRs_total / len(sequences)
89 | SSIMs_total = SSIMs_total / len(sequences)
90 | LPIPSs_total = LPIPSs_total / len(sequences)
91 | print('Avg.')
92 | for method_idx, method in enumerate(methods):
93 | print(method.ljust(7) + '%.2f'%(PSNRs_total[method_idx]) + ' / %.4f'%(SSIMs_total[method_idx]) + ' / %.3f'%(LPIPSs_total[method_idx]))
94 |
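Note that `im2tensor` rescales 8-bit images from [0, 255] into the [-1, 1] range that LPIPS expects; both inputs keep OpenCV's BGR channel order, so the comparison stays consistent. A small illustrative check of that range:

```
# Illustrative check only (not part of the pipeline); im2tensor is the helper defined above.
import numpy as np
from evaluation import im2tensor

white = np.full((4, 4, 3), 255, dtype=np.uint8)
black = np.zeros_like(white)
assert im2tensor(white).max().item() == 1.0    # 255 -> 1.0
assert im2tensor(black).min().item() == -1.0   # 0 -> -1.0
```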
--------------------------------------------------------------------------------
/utils/flow_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | from os.path import *
6 | UNKNOWN_FLOW_THRESH = 1e7
7 |
8 | def flow_to_image(flow, global_max=None):
9 | """
10 | Convert flow into middlebury color code image
11 | :param flow: optical flow map
12 | :return: optical flow image in middlebury color
13 | """
14 | u = flow[:, :, 0]
15 | v = flow[:, :, 1]
16 |
17 | maxu = -999.
18 | maxv = -999.
19 | minu = 999.
20 | minv = 999.
21 |
22 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
23 | u[idxUnknow] = 0
24 | v[idxUnknow] = 0
25 |
26 | maxu = max(maxu, np.max(u))
27 | minu = min(minu, np.min(u))
28 |
29 | maxv = max(maxv, np.max(v))
30 | minv = min(minv, np.min(v))
31 |
32 | rad = np.sqrt(u ** 2 + v ** 2)
33 |
34 |     if global_max is None:
35 | maxrad = max(-1, np.max(rad))
36 | else:
37 | maxrad = global_max
38 |
39 | u = u/(maxrad + np.finfo(float).eps)
40 | v = v/(maxrad + np.finfo(float).eps)
41 |
42 | img = compute_color(u, v)
43 |
44 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
45 | img[idx] = 0
46 |
47 | return np.uint8(img)
48 |
49 |
50 | def compute_color(u, v):
51 | """
52 | compute optical flow color map
53 | :param u: optical flow horizontal map
54 | :param v: optical flow vertical map
55 | :return: optical flow in color code
56 | """
57 | [h, w] = u.shape
58 | img = np.zeros([h, w, 3])
59 | nanIdx = np.isnan(u) | np.isnan(v)
60 | u[nanIdx] = 0
61 | v[nanIdx] = 0
62 |
63 | colorwheel = make_color_wheel()
64 | ncols = np.size(colorwheel, 0)
65 |
66 | rad = np.sqrt(u**2+v**2)
67 |
68 | a = np.arctan2(-v, -u) / np.pi
69 |
70 | fk = (a+1) / 2 * (ncols - 1) + 1
71 |
72 | k0 = np.floor(fk).astype(int)
73 |
74 | k1 = k0 + 1
75 | k1[k1 == ncols+1] = 1
76 | f = fk - k0
77 |
78 | for i in range(0, np.size(colorwheel,1)):
79 | tmp = colorwheel[:, i]
80 | col0 = tmp[k0-1] / 255
81 | col1 = tmp[k1-1] / 255
82 | col = (1-f) * col0 + f * col1
83 |
84 | idx = rad <= 1
85 | col[idx] = 1-rad[idx]*(1-col[idx])
86 | notidx = np.logical_not(idx)
87 |
88 | col[notidx] *= 0.75
89 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
90 |
91 | return img
92 |
93 |
94 | def make_color_wheel():
95 | """
96 | Generate color wheel according Middlebury color code
97 | :return: Color wheel
98 | """
99 | RY = 15
100 | YG = 6
101 | GC = 4
102 | CB = 11
103 | BM = 13
104 | MR = 6
105 |
106 | ncols = RY + YG + GC + CB + BM + MR
107 |
108 | colorwheel = np.zeros([ncols, 3])
109 |
110 | col = 0
111 |
112 | # RY
113 | colorwheel[0:RY, 0] = 255
114 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
115 | col += RY
116 |
117 | # YG
118 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
119 | colorwheel[col:col+YG, 1] = 255
120 | col += YG
121 |
122 | # GC
123 | colorwheel[col:col+GC, 1] = 255
124 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
125 | col += GC
126 |
127 | # CB
128 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
129 | colorwheel[col:col+CB, 2] = 255
130 | col += CB
131 |
132 | # BM
133 | colorwheel[col:col+BM, 2] = 255
134 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
135 |     col += BM
136 |
137 | # MR
138 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
139 | colorwheel[col:col+MR, 0] = 255
140 |
141 | return colorwheel
142 |
143 |
144 | def resize_flow(flow, H_new, W_new):
145 | H_old, W_old = flow.shape[0:2]
146 | flow_resized = cv2.resize(flow, (W_new, H_new), interpolation=cv2.INTER_LINEAR)
147 |     flow_resized[:, :, 0] *= W_new / W_old  # channel 0 is horizontal (x) flow
148 |     flow_resized[:, :, 1] *= H_new / H_old  # channel 1 is vertical (y) flow
149 | return flow_resized
150 |
151 |
152 |
153 | def warp_flow(img, flow):
154 | h, w = flow.shape[:2]
155 | flow_new = flow.copy()
156 | flow_new[:,:,0] += np.arange(w)
157 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
158 |
159 | res = cv2.remap(img, flow_new, None,
160 | cv2.INTER_CUBIC,
161 | borderMode=cv2.BORDER_CONSTANT)
162 | return res
163 |
164 |
165 | def consistCheck(flowB, flowF):
166 |
167 | # |--------------------| |--------------------|
168 | # | y | | v |
169 | # | x * | | u * |
170 | # | | | |
171 | # |--------------------| |--------------------|
172 |
173 | # sub: numPix * [y x t]
174 |
175 | imgH, imgW, _ = flowF.shape
176 |
177 | (fy, fx) = np.mgrid[0 : imgH, 0 : imgW].astype(np.float32)
178 | fxx = fx + flowB[:, :, 0] # horizontal
179 | fyy = fy + flowB[:, :, 1] # vertical
180 |
181 | u = (fxx + cv2.remap(flowF[:, :, 0], fxx, fyy, cv2.INTER_LINEAR) - fx)
182 | v = (fyy + cv2.remap(flowF[:, :, 1], fxx, fyy, cv2.INTER_LINEAR) - fy)
183 | BFdiff = (u ** 2 + v ** 2) ** 0.5
184 |
185 | return BFdiff, np.stack((u, v), axis=2)
186 |
187 |
188 | def read_optical_flow(basedir, img_i_name, read_fwd):
189 | flow_dir = os.path.join(basedir, 'flow')
190 |
191 | fwd_flow_path = os.path.join(flow_dir, '%s_fwd.npz'%img_i_name[:-4])
192 | bwd_flow_path = os.path.join(flow_dir, '%s_bwd.npz'%img_i_name[:-4])
193 |
194 | if read_fwd:
195 | fwd_data = np.load(fwd_flow_path)
196 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
197 | return fwd_flow, fwd_mask
198 | else:
199 | bwd_data = np.load(bwd_flow_path)
200 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
201 | return bwd_flow, bwd_mask
202 |
203 |
204 | def compute_epipolar_distance(T_21, K, p_1, p_2):
205 | R_21 = T_21[:3, :3]
206 | t_21 = T_21[:3, 3]
207 |
208 | E_mat = np.dot(skew(t_21), R_21)
209 | # compute bearing vector
210 | inv_K = np.linalg.inv(K)
211 |
212 | F_mat = np.dot(np.dot(inv_K.T, E_mat), inv_K)
213 |
214 | l_2 = np.dot(F_mat, p_1)
215 | algebric_e_distance = np.sum(p_2 * l_2, axis=0)
216 | n_term = np.sqrt(l_2[0, :]**2 + l_2[1, :]**2) + 1e-8
217 | geometric_e_distance = algebric_e_distance/n_term
218 | geometric_e_distance = np.abs(geometric_e_distance)
219 |
220 | return geometric_e_distance
221 |
222 |
223 | def skew(x):
224 | return np.array([[0, -x[2], x[1]],
225 | [x[2], 0, -x[0]],
226 | [-x[1], x[0], 0]])
227 |
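For reference, `compute_epipolar_distance` implements the standard point-to-epipolar-line distance. Given the relative pose $T_{21} = [R_{21} \mid t_{21}]$ and intrinsics $K$, it forms

$$E = [t_{21}]_\times R_{21}, \qquad F = K^{-\top} E K^{-1}, \qquad d(p_1, p_2) = \frac{\lvert p_2^\top F p_1 \rvert}{\sqrt{(F p_1)_x^2 + (F p_1)_y^2} + 10^{-8}},$$

where $p_1$ and $p_2$ are homogeneous pixel coordinates and `skew` builds the cross-product matrix $[t]_\times$. A large $d$ means the correspondence violates the static-scene epipolar constraint, which is how `generate_motion_mask.py` later flags moving pixels.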
--------------------------------------------------------------------------------
/utils/generate_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | import glob
5 | import torch
6 | import torchvision
7 | import skimage.morphology
8 | import argparse
9 |
10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | def create_dir(dir):
14 | if not os.path.exists(dir):
15 | os.makedirs(dir)
16 |
17 |
18 | def multi_view_multi_time(args):
19 | """
20 | Generating multi view multi time data
21 | """
22 |
23 | Maskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
24 | threshold = 0.5
25 |
26 | videoname, ext = os.path.splitext(os.path.basename(args.videopath))
27 |
28 | imgs = []
29 | reader = imageio.get_reader(args.videopath)
30 | for i, im in enumerate(reader):
31 | imgs.append(im)
32 |
33 | imgs = np.array(imgs)
34 | num_frames, H, W, _ = imgs.shape
35 | imgs = imgs[::int(np.ceil(num_frames / 100))]
36 |
37 | create_dir(os.path.join(args.data_dir, videoname, 'images'))
38 | create_dir(os.path.join(args.data_dir, videoname, 'images_colmap'))
39 | create_dir(os.path.join(args.data_dir, videoname, 'background_mask'))
40 |
41 | for idx, img in enumerate(imgs):
42 | print(idx)
43 | imageio.imwrite(os.path.join(args.data_dir, videoname, 'images', str(idx).zfill(3) + '.png'), img)
44 | imageio.imwrite(os.path.join(args.data_dir, videoname, 'images_colmap', str(idx).zfill(3) + '.jpg'), img)
45 |
46 | # Get coarse background mask
47 | img = torchvision.transforms.functional.to_tensor(img).to(device)
48 | background_mask = torch.FloatTensor(H, W).fill_(1.0).to(device)
49 | objPredictions = Maskrcnn([img])[0]
50 |
51 | for intMask in range(len(objPredictions['masks'])):
52 | if objPredictions['scores'][intMask].item() > threshold:
53 | if objPredictions['labels'][intMask].item() == 1: # person
54 | background_mask[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
55 |
56 | background_mask_np = ((background_mask.cpu().numpy() > 0.1) * 255).astype(np.uint8)
57 | imageio.imwrite(os.path.join(args.data_dir, videoname, 'background_mask', str(idx).zfill(3) + '.jpg.png'), background_mask_np)
58 |
59 |
60 | if __name__ == '__main__':
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument("--videopath", type=str,
63 | help='video path')
64 | parser.add_argument("--data_dir", type=str, default='../data/',
65 | help='where to store data')
66 |
67 | args = parser.parse_args()
68 |
69 | multi_view_multi_time(args)
70 |
--------------------------------------------------------------------------------
/utils/generate_depth.py:
--------------------------------------------------------------------------------
1 | """Compute depth maps for images in the input folder.
2 | """
3 | import os
4 | import cv2
5 | import glob
6 | import torch
7 | import argparse
8 | import numpy as np
9 |
10 | from torchvision.transforms import Compose
11 | from midas.midas_net import MidasNet
12 | from midas.transforms import Resize, NormalizeImage, PrepareForNet
13 |
14 |
15 | def create_dir(dir):
16 | if not os.path.exists(dir):
17 | os.makedirs(dir)
18 |
19 |
20 | def read_image(path):
21 | """Read image and output RGB image (0-1).
22 |
23 | Args:
24 | path (str): path to file
25 |
26 | Returns:
27 | array: RGB image (0-1)
28 | """
29 | img = cv2.imread(path)
30 |
31 | if img.ndim == 2:
32 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
33 |
34 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
35 |
36 | return img
37 |
38 |
39 | def run(input_path, output_path, output_img_path, model_path):
40 | """Run MonoDepthNN to compute depth maps.
41 | Args:
42 | input_path (str): path to input folder
43 | output_path (str): path to output folder
44 | model_path (str): path to saved model
45 | """
46 | print("initialize")
47 |
48 | # select device
49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50 | print("device: %s" % device)
51 |
52 | # load network
53 | model = MidasNet(model_path, non_negative=True)
54 | sh = cv2.imread(sorted(glob.glob(os.path.join(input_path, "*")))[0]).shape
55 | net_w, net_h = sh[1], sh[0]
56 |
57 | resize_mode="upper_bound"
58 |
59 | transform = Compose(
60 | [
61 | Resize(
62 | net_w,
63 | net_h,
64 | resize_target=None,
65 | keep_aspect_ratio=True,
66 | ensure_multiple_of=32,
67 | resize_method=resize_mode,
68 | image_interpolation_method=cv2.INTER_CUBIC,
69 | ),
70 | NormalizeImage(mean=[0.485, 0.456, 0.406],
71 | std=[0.229, 0.224, 0.225]),
72 | PrepareForNet(),
73 | ]
74 | )
75 |
76 | model.eval()
77 | model.to(device)
78 |
79 | # get input
80 | img_names = sorted(glob.glob(os.path.join(input_path, "*")))
81 | num_images = len(img_names)
82 |
83 | # create output folder
84 | os.makedirs(output_path, exist_ok=True)
85 |
86 | print("start processing")
87 |
88 | for ind, img_name in enumerate(img_names):
89 |
90 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
91 |
92 | # input
93 | img = read_image(img_name)
94 | img_input = transform({"image": img})["image"]
95 |
96 | # compute
97 | with torch.no_grad():
98 | sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
99 | prediction = model.forward(sample)
100 | prediction = (
101 | torch.nn.functional.interpolate(
102 | prediction.unsqueeze(1),
103 | size=[net_h, net_w],
104 | mode="bicubic",
105 | align_corners=False,
106 | )
107 | .squeeze()
108 | .cpu()
109 | .numpy()
110 | )
111 |
112 | # output
113 | filename = os.path.join(
114 | output_path, os.path.splitext(os.path.basename(img_name))[0]
115 | )
116 |
117 | print(filename + '.npy')
118 | np.save(filename + '.npy', prediction.astype(np.float32))
119 |
120 | depth_min = prediction.min()
121 | depth_max = prediction.max()
122 |
123 | max_val = (2**(8*2))-1
124 |
125 | if depth_max - depth_min > np.finfo("float").eps:
126 | out = max_val * (prediction - depth_min) / (depth_max - depth_min)
127 | else:
128 |             out = np.zeros(prediction.shape, dtype=prediction.dtype)
129 |
130 | cv2.imwrite(os.path.join(output_img_path, os.path.splitext(os.path.basename(img_name))[0] + '.png'), out.astype("uint16"))
131 |
132 |
133 | if __name__ == "__main__":
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument("--dataset_path", type=str, help='Dataset path')
136 | parser.add_argument('--model', help="restore midas checkpoint")
137 | args = parser.parse_args()
138 |
139 | input_path = os.path.join(args.dataset_path, 'images')
140 | output_path = os.path.join(args.dataset_path, 'disp')
141 | output_img_path = os.path.join(args.dataset_path, 'disp_png')
142 | create_dir(output_path)
143 | create_dir(output_img_path)
144 |
145 | # set torch options
146 | torch.backends.cudnn.enabled = True
147 | torch.backends.cudnn.benchmark = True
148 |
149 | # compute depth maps
150 | run(input_path, output_path, output_img_path, args.model)
151 |
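The `.npy` files hold MiDaS's raw (relative, inverse-depth) predictions, while the 16-bit PNGs are the same maps min-max normalized to [0, 65535] purely for visualization. A minimal read-back sketch (paths are placeholders):

```
# Illustrative only; "scene" stands in for a dataset directory.
import cv2
import numpy as np

disp = np.load("scene/disp/000.npy")                                   # raw MiDaS prediction (float32)
disp_png = cv2.imread("scene/disp_png/000.png", cv2.IMREAD_UNCHANGED)  # uint16 visualization
disp_norm = disp_png.astype(np.float32) / 65535.0                      # min-max normalized copy in [0, 1]
```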
--------------------------------------------------------------------------------
/utils/generate_flow.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import cv2
4 | import glob
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 |
9 | from RAFT.raft import RAFT
10 | from RAFT.utils import flow_viz
11 | from RAFT.utils.utils import InputPadder
12 |
13 | from flow_utils import *
14 |
15 | DEVICE = 'cuda'
16 |
17 |
18 | def create_dir(dir):
19 | if not os.path.exists(dir):
20 | os.makedirs(dir)
21 |
22 |
23 | def load_image(imfile):
24 | img = np.array(Image.open(imfile)).astype(np.uint8)
25 | img = torch.from_numpy(img).permute(2, 0, 1).float()
26 | return img[None].to(DEVICE)
27 |
28 |
29 | def warp_flow(img, flow):
30 | h, w = flow.shape[:2]
31 | flow_new = flow.copy()
32 | flow_new[:,:,0] += np.arange(w)
33 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
34 |
35 | res = cv2.remap(img, flow_new, None, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
36 | return res
37 |
38 |
39 | def compute_fwdbwd_mask(fwd_flow, bwd_flow):
40 | alpha_1 = 0.5
41 | alpha_2 = 0.5
42 |
43 | bwd2fwd_flow = warp_flow(bwd_flow, fwd_flow)
44 | fwd_lr_error = np.linalg.norm(fwd_flow + bwd2fwd_flow, axis=-1)
45 | fwd_mask = fwd_lr_error < alpha_1 * (np.linalg.norm(fwd_flow, axis=-1) \
46 | + np.linalg.norm(bwd2fwd_flow, axis=-1)) + alpha_2
47 |
48 | fwd2bwd_flow = warp_flow(fwd_flow, bwd_flow)
49 | bwd_lr_error = np.linalg.norm(bwd_flow + fwd2bwd_flow, axis=-1)
50 |
51 | bwd_mask = bwd_lr_error < alpha_1 * (np.linalg.norm(bwd_flow, axis=-1) \
52 | + np.linalg.norm(fwd2bwd_flow, axis=-1)) + alpha_2
53 |
54 | return fwd_mask, bwd_mask
55 |
56 | def run(args, input_path, output_path, output_img_path):
57 | model = torch.nn.DataParallel(RAFT(args))
58 | model.load_state_dict(torch.load(args.model))
59 |
60 | model = model.module
61 | model.to(DEVICE)
62 | model.eval()
63 |
64 | with torch.no_grad():
65 | images = glob.glob(os.path.join(input_path, '*.png')) + \
66 | glob.glob(os.path.join(input_path, '*.jpg'))
67 |
68 | images = sorted(images)
69 | for i in range(len(images) - 1):
70 | print(i)
71 | image1 = load_image(images[i])
72 | image2 = load_image(images[i + 1])
73 |
74 | padder = InputPadder(image1.shape)
75 | image1, image2 = padder.pad(image1, image2)
76 |
77 | _, flow_fwd = model(image1, image2, iters=20, test_mode=True)
78 | _, flow_bwd = model(image2, image1, iters=20, test_mode=True)
79 |
80 | flow_fwd = padder.unpad(flow_fwd[0]).cpu().numpy().transpose(1, 2, 0)
81 | flow_bwd = padder.unpad(flow_bwd[0]).cpu().numpy().transpose(1, 2, 0)
82 |
83 | mask_fwd, mask_bwd = compute_fwdbwd_mask(flow_fwd, flow_bwd)
84 |
85 | # Save flow
86 | np.savez(os.path.join(output_path, '%03d_fwd.npz'%i), flow=flow_fwd, mask=mask_fwd)
87 | np.savez(os.path.join(output_path, '%03d_bwd.npz'%(i + 1)), flow=flow_bwd, mask=mask_bwd)
88 |
89 | # Save flow_img
90 | Image.fromarray(flow_viz.flow_to_image(flow_fwd)).save(os.path.join(output_img_path, '%03d_fwd.png'%i))
91 | Image.fromarray(flow_viz.flow_to_image(flow_bwd)).save(os.path.join(output_img_path, '%03d_bwd.png'%(i + 1)))
92 |
93 | Image.fromarray(mask_fwd).save(os.path.join(output_img_path, '%03d_fwd_mask.png'%i))
94 | Image.fromarray(mask_bwd).save(os.path.join(output_img_path, '%03d_bwd_mask.png'%(i + 1)))
95 |
96 |
97 | if __name__ == '__main__':
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--dataset_path", type=str, help='Dataset path')
100 | parser.add_argument('--model', help="restore RAFT checkpoint")
101 | parser.add_argument('--small', action='store_true', help='use small model')
102 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
103 | args = parser.parse_args()
104 |
105 | input_path = os.path.join(args.dataset_path, 'images')
106 | output_path = os.path.join(args.dataset_path, 'flow')
107 | output_img_path = os.path.join(args.dataset_path, 'flow_png')
108 | create_dir(output_path)
109 | create_dir(output_img_path)
110 |
111 | run(args, input_path, output_path, output_img_path)
112 |
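For reference, `compute_fwdbwd_mask` keeps the forward flow at a pixel $p$ only if the forward-backward consistency check

$$\big\| f_{\text{fwd}}(p) + f_{\text{bwd}}\big(p + f_{\text{fwd}}(p)\big) \big\| \;<\; \alpha_1 \Big( \big\| f_{\text{fwd}}(p) \big\| + \big\| f_{\text{bwd}}\big(p + f_{\text{fwd}}(p)\big) \big\| \Big) + \alpha_2$$

passes, with $\alpha_1 = \alpha_2 = 0.5$ as set above; the backward mask is defined symmetrically. Pixels that fail the check (typically occlusions or unreliable flow) are recorded in the saved masks so downstream code can ignore them.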
--------------------------------------------------------------------------------
/utils/generate_motion_mask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import PIL
4 | import glob
5 | import torch
6 | import argparse
7 | import numpy as np
8 |
9 | from colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
10 |
11 | import skimage.morphology
12 | import torchvision
13 | from flow_utils import read_optical_flow, compute_epipolar_distance, skew
14 |
15 |
16 |
17 | def create_dir(dir):
18 | if not os.path.exists(dir):
19 | os.makedirs(dir)
20 |
21 |
22 | def extract_poses(im):
23 | R = im.qvec2rotmat()
24 | t = im.tvec.reshape([3,1])
25 | bottom = np.array([0,0,0,1.]).reshape([1,4])
26 |
27 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
28 |
29 | return m
30 |
31 |
32 | def load_colmap_data(realdir):
33 |
34 | camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')
35 | camdata = read_cameras_binary(camerasfile)
36 |
37 | list_of_keys = list(camdata.keys())
38 | cam = camdata[list_of_keys[0]]
39 |     print('Cameras', len(camdata))
40 |
41 | h, w, f = cam.height, cam.width, cam.params[0]
42 | # w, h, f = factor * w, factor * h, factor * f
43 | hwf = np.array([h,w,f]).reshape([3,1])
44 |
45 | imagesfile = os.path.join(realdir, 'sparse/0/images.bin')
46 | imdata = read_images_binary(imagesfile)
47 |
48 | w2c_mats = []
49 | # bottom = np.array([0,0,0,1.]).reshape([1,4])
50 |
51 | names = [imdata[k].name for k in imdata]
52 | img_keys = [k for k in imdata]
53 |
54 | print( 'Images #', len(names))
55 | perm = np.argsort(names)
56 |
57 | return imdata, perm, img_keys, hwf
58 |
59 |
60 | def run_maskrcnn(model, img_path, intWidth=1024, intHeight=576):
61 |
62 | # intHeight = 576
63 | # intWidth = 1024
64 |
65 | threshold = 0.5
66 |
67 | o_image = PIL.Image.open(img_path)
68 | image = o_image.resize((intWidth, intHeight), PIL.Image.ANTIALIAS)
69 |
70 | image_tensor = torchvision.transforms.functional.to_tensor(image).cuda()
71 |
72 | tenHumans = torch.FloatTensor(intHeight, intWidth).fill_(1.0).cuda()
73 |
74 | objPredictions = model([image_tensor])[0]
75 |
76 | for intMask in range(objPredictions['masks'].size(0)):
77 | if objPredictions['scores'][intMask].item() > threshold:
78 | if objPredictions['labels'][intMask].item() == 1: # person
79 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
80 |
81 | if objPredictions['labels'][intMask].item() == 4: # motorcycle
82 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
83 |
84 | if objPredictions['labels'][intMask].item() == 2: # bicycle
85 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
86 |
87 | if objPredictions['labels'][intMask].item() == 8: # truck
88 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
89 |
90 | if objPredictions['labels'][intMask].item() == 28: # umbrella
91 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
92 |
93 | if objPredictions['labels'][intMask].item() == 17: # cat
94 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
95 |
96 | if objPredictions['labels'][intMask].item() == 18: # dog
97 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
98 |
99 | if objPredictions['labels'][intMask].item() == 36: # snowboard
100 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
101 |
102 | if objPredictions['labels'][intMask].item() == 41: # skateboard
103 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
104 |
105 | npyMask = skimage.morphology.erosion(tenHumans.cpu().numpy(),
106 | skimage.morphology.disk(1))
107 | npyMask = ((npyMask < 1e-3) * 255.0).clip(0.0, 255.0).astype(np.uint8)
108 | return npyMask
109 |
110 |
111 | def motion_segmentation(basedir, threshold,
112 | input_semantic_w=1024,
113 | input_semantic_h=576):
114 |
115 | points3dfile = os.path.join(basedir, 'sparse/0/points3D.bin')
116 | pts3d = read_points3d_binary(points3dfile)
117 |
118 | img_dir = glob.glob(basedir + '/images_colmap')[0]
119 | img0 = glob.glob(glob.glob(img_dir)[0] + '/*jpg')[0]
120 | shape_0 = cv2.imread(img0).shape
121 |
122 | resized_height, resized_width = shape_0[0], shape_0[1]
123 |
124 | imdata, perm, img_keys, hwf = load_colmap_data(basedir)
125 | scale_x, scale_y = resized_width / float(hwf[1]), resized_height / float(hwf[0])
126 |
127 | K = np.eye(3)
128 | K[0, 0] = hwf[2]
129 | K[0, 2] = hwf[1] / 2.
130 | K[1, 1] = hwf[2]
131 | K[1, 2] = hwf[0] / 2.
132 |
133 | xx = range(0, resized_width)
134 | yy = range(0, resized_height)
135 | xv, yv = np.meshgrid(xx, yy)
136 | p_ref = np.float32(np.stack((xv, yv), axis=-1))
137 | p_ref_h = np.reshape(p_ref, (-1, 2))
138 | p_ref_h = np.concatenate((p_ref_h, np.ones((p_ref_h.shape[0], 1))), axis=-1).T
139 |
140 | num_frames = len(perm)
141 |
142 | if os.path.isdir(os.path.join(basedir, 'images_colmap')):
143 | num_colmap_frames = len(glob.glob(os.path.join(basedir, 'images_colmap', '*.jpg')))
144 | num_data_frames = len(glob.glob(os.path.join(basedir, 'images', '*.png')))
145 |
146 | if num_colmap_frames != num_data_frames:
147 | num_frames = num_data_frames
148 |
149 |
150 | save_mask_dir = os.path.join(basedir, 'motion_segmentation')
151 | create_dir(save_mask_dir)
152 |
153 | for i in range(0, num_frames):
154 | im_prev = imdata[img_keys[perm[max(0, i - 1)]]]
155 | im_ref = imdata[img_keys[perm[i]]]
156 | im_post = imdata[img_keys[perm[min(num_frames -1, i + 1)]]]
157 |
158 | print(im_prev.name, im_ref.name, im_post.name)
159 |
160 | T_prev_G = extract_poses(im_prev)
161 | T_ref_G = extract_poses(im_ref)
162 | T_post_G = extract_poses(im_post)
163 |
164 | T_ref2prev = np.dot(T_prev_G, np.linalg.inv(T_ref_G))
165 | T_ref2post = np.dot(T_post_G, np.linalg.inv(T_ref_G))
166 | # load optical flow
167 |
168 | if i == 0:
169 | fwd_flow, _ = read_optical_flow(basedir,
170 | im_ref.name,
171 | read_fwd=True)
172 | bwd_flow = np.zeros_like(fwd_flow)
173 | elif i == num_frames - 1:
174 | bwd_flow, _ = read_optical_flow(basedir,
175 | im_ref.name,
176 | read_fwd=False)
177 | fwd_flow = np.zeros_like(bwd_flow)
178 | else:
179 | fwd_flow, _ = read_optical_flow(basedir,
180 | im_ref.name,
181 | read_fwd=True)
182 | bwd_flow, _ = read_optical_flow(basedir,
183 | im_ref.name,
184 | read_fwd=False)
185 |
186 | p_post = p_ref + fwd_flow
187 | p_post_h = np.reshape(p_post, (-1, 2))
188 | p_post_h = np.concatenate((p_post_h, np.ones((p_post_h.shape[0], 1))), axis=-1).T
189 |
190 | fwd_e_dist = compute_epipolar_distance(T_ref2post, K,
191 | p_ref_h, p_post_h)
192 | fwd_e_dist = np.reshape(fwd_e_dist, (fwd_flow.shape[0], fwd_flow.shape[1]))
193 |
194 | p_prev = p_ref + bwd_flow
195 | p_prev_h = np.reshape(p_prev, (-1, 2))
196 | p_prev_h = np.concatenate((p_prev_h, np.ones((p_prev_h.shape[0], 1))), axis=-1).T
197 |
198 | bwd_e_dist = compute_epipolar_distance(T_ref2prev, K,
199 | p_ref_h, p_prev_h)
200 | bwd_e_dist = np.reshape(bwd_e_dist, (bwd_flow.shape[0], bwd_flow.shape[1]))
201 |
202 | e_dist = np.maximum(bwd_e_dist, fwd_e_dist)
203 |
204 | motion_mask = skimage.morphology.binary_opening(e_dist > threshold, skimage.morphology.disk(1))
205 |
206 | cv2.imwrite(os.path.join(save_mask_dir, im_ref.name.replace('.jpg', '.png')), np.uint8(255 * (0. + motion_mask)))
207 |
208 | # RUN SEMANTIC SEGMENTATION
209 | img_dir = os.path.join(basedir, 'images')
210 | img_path_list = sorted(glob.glob(os.path.join(img_dir, '*.jpg'))) \
211 | + sorted(glob.glob(os.path.join(img_dir, '*.png')))
212 | semantic_mask_dir = os.path.join(basedir, 'semantic_mask')
213 | netMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
214 | create_dir(semantic_mask_dir)
215 |
216 |
217 | for i in range(0, len(img_path_list)):
218 | img_path = img_path_list[i]
219 | img_name = img_path.split('/')[-1]
220 | semantic_mask = run_maskrcnn(netMaskrcnn, img_path,
221 | input_semantic_w,
222 | input_semantic_h)
223 | cv2.imwrite(os.path.join(semantic_mask_dir,
224 | img_name.replace('.jpg', '.png')),
225 | semantic_mask)
226 |
227 | # combine them
228 | save_mask_dir = os.path.join(basedir, 'motion_masks')
229 | create_dir(save_mask_dir)
230 |
231 | mask_dir = os.path.join(basedir, 'motion_segmentation')
232 | mask_path_list = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
233 |
234 | semantic_dir = os.path.join(basedir, 'semantic_mask')
235 |
236 | for mask_path in mask_path_list:
237 | print(mask_path)
238 |
239 | motion_mask = cv2.imread(mask_path)
240 | motion_mask = cv2.resize(motion_mask, (resized_width, resized_height),
241 | interpolation=cv2.INTER_NEAREST)
242 | motion_mask = motion_mask[:, :, 0] > 0.1
243 |
244 | # combine from motion segmentation
245 | semantic_mask = cv2.imread(os.path.join(semantic_dir, mask_path.split('/')[-1]))
246 | semantic_mask = cv2.resize(semantic_mask, (resized_width, resized_height),
247 | interpolation=cv2.INTER_NEAREST)
248 | semantic_mask = semantic_mask[:, :, 0] > 0.1
249 | motion_mask = semantic_mask | motion_mask
250 |
251 | motion_mask = skimage.morphology.dilation(motion_mask, skimage.morphology.disk(2))
252 | cv2.imwrite(os.path.join(save_mask_dir, '%s'%mask_path.split('/')[-1]),
253 | np.uint8(np.clip((motion_mask), 0, 1) * 255) )
254 |
255 | # delete old mask dir
256 | os.system('rm -r %s'%mask_dir)
257 | os.system('rm -r %s'%semantic_dir)
258 |
259 |
260 | if __name__ == '__main__':
261 | parser = argparse.ArgumentParser()
262 | parser.add_argument("--dataset_path", type=str, help='Dataset path')
263 | parser.add_argument("--epi_threshold", type=float,
264 | default=1.0,
265 | help='epipolar distance threshold for physical motion segmentation')
266 |
267 | parser.add_argument("--input_flow_w", type=int,
268 | default=768,
269 | help='input image width for optical flow, \
270 | the height will be computed based on original aspect ratio ')
271 |
272 | parser.add_argument("--input_semantic_w", type=int,
273 | default=1024,
274 | help='input image width for semantic segmentation')
275 |
276 | parser.add_argument("--input_semantic_h", type=int,
277 | default=576,
278 | help='input image height for semantic segmentation')
279 | args = parser.parse_args()
280 |
281 | motion_segmentation(args.dataset_path, args.epi_threshold,
282 | args.input_semantic_w,
283 | args.input_semantic_h)
284 |
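Putting the pieces together, the final mask written to `motion_masks` is (up to the small morphological opening applied to the epipolar term)

$$M = \text{dilate}\Big( M_{\text{semantic}} \;\cup\; \big[\, \max(d_{\text{fwd}}, d_{\text{bwd}}) > \tau \,\big] \Big),$$

where $d_{\text{fwd}}$ and $d_{\text{bwd}}$ are the epipolar distances computed against the next and previous frames and $\tau$ is `--epi_threshold` (default 1.0).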
--------------------------------------------------------------------------------
/utils/generate_pose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import numpy as np
5 | from colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
6 |
7 |
8 | def load_colmap_data(realdir):
9 |
10 | camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')
11 | camdata = read_cameras_binary(camerasfile)
12 |
13 | list_of_keys = list(camdata.keys())
14 | cam = camdata[list_of_keys[0]]
15 |     print('Cameras', len(camdata))
16 |
17 | h, w, f = cam.height, cam.width, cam.params[0]
18 | # w, h, f = factor * w, factor * h, factor * f
19 | hwf = np.array([h,w,f]).reshape([3,1])
20 |
21 | imagesfile = os.path.join(realdir, 'sparse/0/images.bin')
22 | imdata = read_images_binary(imagesfile)
23 |
24 | w2c_mats = []
25 | bottom = np.array([0,0,0,1.]).reshape([1,4])
26 |
27 | names = [imdata[k].name for k in imdata]
28 | img_keys = [k for k in imdata]
29 |
30 | print('Images #', len(names))
31 | perm = np.argsort(names)
32 |
33 | points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin')
34 | pts3d = read_points3d_binary(points3dfile)
35 |
36 | bounds_mats = []
37 |
38 | for i in perm[0:len(img_keys)]:
39 |
40 | im = imdata[img_keys[i]]
41 | print(im.name)
42 | R = im.qvec2rotmat()
43 | t = im.tvec.reshape([3,1])
44 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
45 | w2c_mats.append(m)
46 |
47 | pts_3d_idx = im.point3D_ids
48 | pts_3d_vis_idx = pts_3d_idx[pts_3d_idx >= 0]
49 |
50 | #
51 | depth_list = []
52 | for k in range(len(pts_3d_vis_idx)):
53 | point_info = pts3d[pts_3d_vis_idx[k]]
54 |
55 | P_g = point_info.xyz
56 | P_c = np.dot(R, P_g.reshape(3, 1)) + t.reshape(3, 1)
57 | depth_list.append(P_c[2])
58 |
59 | zs = np.array(depth_list)
60 | close_depth, inf_depth = np.percentile(zs, 5), np.percentile(zs, 95)
61 | bounds = np.array([close_depth, inf_depth])
62 | bounds_mats.append(bounds)
63 |
64 | w2c_mats = np.stack(w2c_mats, 0)
65 | c2w_mats = np.linalg.inv(w2c_mats)
66 |
67 | poses = c2w_mats[:, :3, :4].transpose([1,2,0])
68 | poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis],
69 | [1,1,poses.shape[-1]])], 1)
70 |
71 | # must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t]
72 | poses = np.concatenate([poses[:, 1:2, :],
73 | poses[:, 0:1, :],
74 | -poses[:, 2:3, :],
75 | poses[:, 3:4, :],
76 | poses[:, 4:5, :]], 1)
77 |
78 | save_arr = []
79 |
80 | for i in range((poses.shape[2])):
81 | save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0))
82 |
83 | save_arr = np.array(save_arr)
84 | print(save_arr.shape)
85 |
86 | # Use all frames to calculate COLMAP camera poses.
87 | if os.path.isdir(os.path.join(realdir, 'images_colmap')):
88 | num_colmap_frames = len(glob.glob(os.path.join(realdir, 'images_colmap', '*.jpg')))
89 | num_data_frames = len(glob.glob(os.path.join(realdir, 'images', '*.png')))
90 |
91 | assert num_colmap_frames == save_arr.shape[0]
92 | np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr[:num_data_frames, :])
93 | else:
94 | np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr)
95 |
96 |
97 | if __name__ == '__main__':
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--dataset_path", type=str,
100 | help='Dataset path')
101 |
102 | args = parser.parse_args()
103 |
104 | load_colmap_data(args.dataset_path)
105 |
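Each row of the saved `poses_bounds.npy` packs 17 numbers: a flattened 3x5 matrix (the 3x4 camera-to-world pose in the LLFF axis convention with the [h, w, f] column appended) followed by the near/far depth bounds. A minimal unpacking sketch (the path is a placeholder):

```
# Illustrative only; "scene/poses_bounds.npy" is a placeholder path.
import numpy as np

arr = np.load("scene/poses_bounds.npy")      # shape (num_frames, 17)
poses = arr[:, :15].reshape(-1, 3, 5)        # 3x4 c2w pose + [h, w, f] column per frame
bounds = arr[:, 15:]                         # [near, far] depth bounds per frame
hwf = poses[0, :, 4]                         # image height, width, focal length
```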
--------------------------------------------------------------------------------
/utils/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 | Args:
8 | path (str): file path
9 | """
10 | parameters = torch.load(path, map_location=torch.device('cpu'))
11 |
12 | if "optimizer" in parameters:
13 | parameters = parameters["model"]
14 |
15 | self.load_state_dict(parameters)
16 |
--------------------------------------------------------------------------------
/utils/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 | ) # ViT-H/16 - 85.0% Top1 (backbone)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
--------------------------------------------------------------------------------
/utils/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
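Note on `MidasNet` (defined in `/utils/midas/midas_net.py` above): a minimal inference sketch follows. The import path is an assumption; with `path=None` the decoder weights stay random, and in this repository a downloaded MiDaS checkpoint would normally be passed as `path`. Building the ResNeXt-101 WSL encoder may fetch weights through `torch.hub`, so network access can be required.

```
import torch
from utils.midas.midas_net import MidasNet  # assumed import path

model = MidasNet(path=None, features=256, non_negative=True)
model.eval()

# Input is a normalized RGB batch whose spatial size is a multiple of 32
# (the encoder downsamples by 32 and the decoder upsamples back to full size).
img = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    disparity = model(img)  # relative inverse depth, shape (B, H, W)

print(disparity.shape)  # torch.Size([1, 384, 384])
```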
/utils/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 | """Rezise the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 | """Normlize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
--------------------------------------------------------------------------------
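Note on `/utils/midas/transforms.py` above: each transform operates on a `dict` sample with an `"image"` key (and optionally `"mask"`, `"disparity"`, `"depth"`). The sketch below shows a typical preprocessing pipeline; the 384x384 target size, ImageNet mean/std, interpolation choice, and the `frame.png` filename are illustrative assumptions rather than values taken from this repository.

```
import cv2
import torch
from torchvision.transforms import Compose
from utils.midas.transforms import Resize, NormalizeImage, PrepareForNet  # assumed import path

transform = Compose([
    Resize(
        384, 384,
        resize_target=False,          # resize only the image
        keep_aspect_ratio=True,
        ensure_multiple_of=32,        # keep H, W divisible by the encoder stride
        resize_method="upper_bound",
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

img = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB) / 255.0  # HWC float RGB in [0, 1]
sample = transform({"image": img})
net_input = torch.from_numpy(sample["image"]).unsqueeze(0)  # (1, 3, H, W) tensor
```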
/utils/midas/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | class Slice(nn.Module):
10 | def __init__(self, start_index=1):
11 | super(Slice, self).__init__()
12 | self.start_index = start_index
13 |
14 | def forward(self, x):
15 | return x[:, self.start_index :]
16 |
17 |
18 | class AddReadout(nn.Module):
19 | def __init__(self, start_index=1):
20 | super(AddReadout, self).__init__()
21 | self.start_index = start_index
22 |
23 | def forward(self, x):
24 | if self.start_index == 2:
25 | readout = (x[:, 0] + x[:, 1]) / 2
26 | else:
27 | readout = x[:, 0]
28 | return x[:, self.start_index :] + readout.unsqueeze(1)
29 |
30 |
31 | class ProjectReadout(nn.Module):
32 | def __init__(self, in_features, start_index=1):
33 | super(ProjectReadout, self).__init__()
34 | self.start_index = start_index
35 |
36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37 |
38 | def forward(self, x):
39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40 | features = torch.cat((x[:, self.start_index :], readout), -1)
41 |
42 | return self.project(features)
43 |
44 |
45 | class Transpose(nn.Module):
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0 = dim0
49 | self.dim1 = dim1
50 |
51 | def forward(self, x):
52 | x = x.transpose(self.dim0, self.dim1)
53 | return x
54 |
55 |
56 | def forward_vit(pretrained, x):
57 | b, c, h, w = x.shape
58 |
59 | glob = pretrained.model.forward_flex(x)
60 |
61 | layer_1 = pretrained.activations["1"]
62 | layer_2 = pretrained.activations["2"]
63 | layer_3 = pretrained.activations["3"]
64 | layer_4 = pretrained.activations["4"]
65 |
66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70 |
71 | unflatten = nn.Sequential(
72 | nn.Unflatten(
73 | 2,
74 | torch.Size(
75 | [
76 | h // pretrained.model.patch_size[1],
77 | w // pretrained.model.patch_size[0],
78 | ]
79 | ),
80 | )
81 | )
82 |
83 | if layer_1.ndim == 3:
84 | layer_1 = unflatten(layer_1)
85 | if layer_2.ndim == 3:
86 | layer_2 = unflatten(layer_2)
87 | if layer_3.ndim == 3:
88 | layer_3 = unflatten(layer_3)
89 | if layer_4.ndim == 3:
90 | layer_4 = unflatten(layer_4)
91 |
92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96 |
97 | return layer_1, layer_2, layer_3, layer_4
98 |
99 |
100 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
101 | posemb_tok, posemb_grid = (
102 | posemb[:, : self.start_index],
103 | posemb[0, self.start_index :],
104 | )
105 |
106 | gs_old = int(math.sqrt(len(posemb_grid)))
107 |
108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111 |
112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113 |
114 | return posemb
115 |
116 |
117 | def forward_flex(self, x):
118 | b, c, h, w = x.shape
119 |
120 | pos_embed = self._resize_pos_embed(
121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122 | )
123 |
124 | B = x.shape[0]
125 |
126 | if hasattr(self.patch_embed, "backbone"):
127 | x = self.patch_embed.backbone(x)
128 | if isinstance(x, (list, tuple)):
129 | x = x[-1] # last feature if backbone outputs list/tuple of features
130 |
131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132 |
133 | if getattr(self, "dist_token", None) is not None:
134 | cls_tokens = self.cls_token.expand(
135 | B, -1, -1
136 | ) # stole cls_tokens impl from Phil Wang, thanks
137 | dist_token = self.dist_token.expand(B, -1, -1)
138 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
139 | else:
140 | cls_tokens = self.cls_token.expand(
141 | B, -1, -1
142 | ) # stole cls_tokens impl from Phil Wang, thanks
143 | x = torch.cat((cls_tokens, x), dim=1)
144 |
145 | x = x + pos_embed
146 | x = self.pos_drop(x)
147 |
148 | for blk in self.blocks:
149 | x = blk(x)
150 |
151 | x = self.norm(x)
152 |
153 | return x
154 |
155 |
156 | activations = {}
157 |
158 |
159 | def get_activation(name):
160 | def hook(model, input, output):
161 | activations[name] = output
162 |
163 | return hook
164 |
165 |
166 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
167 | if use_readout == "ignore":
168 | readout_oper = [Slice(start_index)] * len(features)
169 | elif use_readout == "add":
170 | readout_oper = [AddReadout(start_index)] * len(features)
171 | elif use_readout == "project":
172 | readout_oper = [
173 | ProjectReadout(vit_features, start_index) for out_feat in features
174 | ]
175 | else:
176 |         raise ValueError(
177 |             "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
178 |         )
179 |
180 | return readout_oper
181 |
182 |
183 | def _make_vit_b16_backbone(
184 | model,
185 | features=[96, 192, 384, 768],
186 | size=[384, 384],
187 | hooks=[2, 5, 8, 11],
188 | vit_features=768,
189 | use_readout="ignore",
190 | start_index=1,
191 | ):
192 | pretrained = nn.Module()
193 |
194 | pretrained.model = model
195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199 |
200 | pretrained.activations = activations
201 |
202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203 |
204 |     # Project each hooked token sequence back to a 2D feature map and rescale it to one of four strides.
205 | pretrained.act_postprocess1 = nn.Sequential(
206 | readout_oper[0],
207 | Transpose(1, 2),
208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209 | nn.Conv2d(
210 | in_channels=vit_features,
211 | out_channels=features[0],
212 | kernel_size=1,
213 | stride=1,
214 | padding=0,
215 | ),
216 | nn.ConvTranspose2d(
217 | in_channels=features[0],
218 | out_channels=features[0],
219 | kernel_size=4,
220 | stride=4,
221 | padding=0,
222 | bias=True,
223 | dilation=1,
224 | groups=1,
225 | ),
226 | )
227 |
228 | pretrained.act_postprocess2 = nn.Sequential(
229 | readout_oper[1],
230 | Transpose(1, 2),
231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232 | nn.Conv2d(
233 | in_channels=vit_features,
234 | out_channels=features[1],
235 | kernel_size=1,
236 | stride=1,
237 | padding=0,
238 | ),
239 | nn.ConvTranspose2d(
240 | in_channels=features[1],
241 | out_channels=features[1],
242 | kernel_size=2,
243 | stride=2,
244 | padding=0,
245 | bias=True,
246 | dilation=1,
247 | groups=1,
248 | ),
249 | )
250 |
251 | pretrained.act_postprocess3 = nn.Sequential(
252 | readout_oper[2],
253 | Transpose(1, 2),
254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255 | nn.Conv2d(
256 | in_channels=vit_features,
257 | out_channels=features[2],
258 | kernel_size=1,
259 | stride=1,
260 | padding=0,
261 | ),
262 | )
263 |
264 | pretrained.act_postprocess4 = nn.Sequential(
265 | readout_oper[3],
266 | Transpose(1, 2),
267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268 | nn.Conv2d(
269 | in_channels=vit_features,
270 | out_channels=features[3],
271 | kernel_size=1,
272 | stride=1,
273 | padding=0,
274 | ),
275 | nn.Conv2d(
276 | in_channels=features[3],
277 | out_channels=features[3],
278 | kernel_size=3,
279 | stride=2,
280 | padding=1,
281 | ),
282 | )
283 |
284 | pretrained.model.start_index = start_index
285 | pretrained.model.patch_size = [16, 16]
286 |
287 | # We inject this function into the VisionTransformer instances so that
288 | # we can use it with interpolated position embeddings without modifying the library source.
289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290 | pretrained.model._resize_pos_embed = types.MethodType(
291 | _resize_pos_embed, pretrained.model
292 | )
293 |
294 | return pretrained
295 |
296 |
297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299 |
300 |     hooks = [5, 11, 17, 23] if hooks is None else hooks
301 | return _make_vit_b16_backbone(
302 | model,
303 | features=[256, 512, 1024, 1024],
304 | hooks=hooks,
305 | vit_features=1024,
306 | use_readout=use_readout,
307 | )
308 |
309 |
310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312 |
313 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
314 | return _make_vit_b16_backbone(
315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316 | )
317 |
318 |
319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321 |
322 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
323 | return _make_vit_b16_backbone(
324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325 | )
326 |
327 |
328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329 | model = timm.create_model(
330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331 | )
332 |
333 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
334 | return _make_vit_b16_backbone(
335 | model,
336 | features=[96, 192, 384, 768],
337 | hooks=hooks,
338 | use_readout=use_readout,
339 | start_index=2,
340 | )
341 |
342 |
343 | def _make_vit_b_rn50_backbone(
344 | model,
345 | features=[256, 512, 768, 768],
346 | size=[384, 384],
347 | hooks=[0, 1, 8, 11],
348 | vit_features=768,
349 | use_vit_only=False,
350 | use_readout="ignore",
351 | start_index=1,
352 | ):
353 | pretrained = nn.Module()
354 |
355 | pretrained.model = model
356 |
357 |     if use_vit_only:
358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360 | else:
361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362 | get_activation("1")
363 | )
364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365 | get_activation("2")
366 | )
367 |
368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370 |
371 | pretrained.activations = activations
372 |
373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374 |
375 |     if use_vit_only:
376 | pretrained.act_postprocess1 = nn.Sequential(
377 | readout_oper[0],
378 | Transpose(1, 2),
379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380 | nn.Conv2d(
381 | in_channels=vit_features,
382 | out_channels=features[0],
383 | kernel_size=1,
384 | stride=1,
385 | padding=0,
386 | ),
387 | nn.ConvTranspose2d(
388 | in_channels=features[0],
389 | out_channels=features[0],
390 | kernel_size=4,
391 | stride=4,
392 | padding=0,
393 | bias=True,
394 | dilation=1,
395 | groups=1,
396 | ),
397 | )
398 |
399 | pretrained.act_postprocess2 = nn.Sequential(
400 | readout_oper[1],
401 | Transpose(1, 2),
402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403 | nn.Conv2d(
404 | in_channels=vit_features,
405 | out_channels=features[1],
406 | kernel_size=1,
407 | stride=1,
408 | padding=0,
409 | ),
410 | nn.ConvTranspose2d(
411 | in_channels=features[1],
412 | out_channels=features[1],
413 | kernel_size=2,
414 | stride=2,
415 | padding=0,
416 | bias=True,
417 | dilation=1,
418 | groups=1,
419 | ),
420 | )
421 | else:
422 | pretrained.act_postprocess1 = nn.Sequential(
423 | nn.Identity(), nn.Identity(), nn.Identity()
424 | )
425 | pretrained.act_postprocess2 = nn.Sequential(
426 | nn.Identity(), nn.Identity(), nn.Identity()
427 | )
428 |
429 | pretrained.act_postprocess3 = nn.Sequential(
430 | readout_oper[2],
431 | Transpose(1, 2),
432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433 | nn.Conv2d(
434 | in_channels=vit_features,
435 | out_channels=features[2],
436 | kernel_size=1,
437 | stride=1,
438 | padding=0,
439 | ),
440 | )
441 |
442 | pretrained.act_postprocess4 = nn.Sequential(
443 | readout_oper[3],
444 | Transpose(1, 2),
445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446 | nn.Conv2d(
447 | in_channels=vit_features,
448 | out_channels=features[3],
449 | kernel_size=1,
450 | stride=1,
451 | padding=0,
452 | ),
453 | nn.Conv2d(
454 | in_channels=features[3],
455 | out_channels=features[3],
456 | kernel_size=3,
457 | stride=2,
458 | padding=1,
459 | ),
460 | )
461 |
462 | pretrained.model.start_index = start_index
463 | pretrained.model.patch_size = [16, 16]
464 |
465 | # We inject this function into the VisionTransformer instances so that
466 | # we can use it with interpolated position embeddings without modifying the library source.
467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468 |
469 | # We inject this function into the VisionTransformer instances so that
470 | # we can use it with interpolated position embeddings without modifying the library source.
471 | pretrained.model._resize_pos_embed = types.MethodType(
472 | _resize_pos_embed, pretrained.model
473 | )
474 |
475 | return pretrained
476 |
477 |
478 | def _make_pretrained_vitb_rn50_384(
479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480 | ):
481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482 |
483 |     hooks = [0, 1, 8, 11] if hooks is None else hooks
484 | return _make_vit_b_rn50_backbone(
485 | model,
486 | features=[256, 512, 768, 768],
487 | size=[384, 384],
488 | hooks=hooks,
489 | use_vit_only=use_vit_only,
490 | use_readout=use_readout,
491 | )
492 |
--------------------------------------------------------------------------------
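Note on `/utils/midas/vit.py` above: `_make_pretrained_*` builds a timm ViT with forward hooks on four transformer blocks, and `forward_vit` post-processes the hooked activations into four 2D feature maps. The sketch below assumes the module is importable as `utils.midas.vit` and that the installed `timm` provides `vit_large_patch16_384`; with a 384x384 input and the default hooks `[5, 11, 17, 23]`, the maps come out at strides 4, 8, 16, and 32.

```
import torch
from utils.midas.vit import _make_pretrained_vitl16_384, forward_vit  # assumed import path

# pretrained=False avoids downloading weights; pass True to load the timm checkpoint.
backbone = _make_pretrained_vitl16_384(pretrained=False, use_readout="ignore")

x = torch.randn(1, 3, 384, 384)  # H and W must be multiples of the 16x16 patch size
layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)

print(layer_1.shape)  # torch.Size([1, 256, 96, 96])
print(layer_2.shape)  # torch.Size([1, 512, 48, 48])
print(layer_3.shape)  # torch.Size([1, 1024, 24, 24])
print(layer_4.shape)  # torch.Size([1, 1024, 12, 12])
```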