├── LICENSE
├── README.md
├── configs
│   ├── config.txt
│   ├── config_Balloon1.txt
│   ├── config_Balloon2.txt
│   ├── config_Jumping.txt
│   ├── config_Playground.txt
│   ├── config_Skating.txt
│   ├── config_Truck.txt
│   └── config_Umbrella.txt
├── load_llff.py
├── render_utils.py
├── run_nerf.py
├── run_nerf_helpers.py
└── utils
    ├── RAFT
    │   ├── __init__.py
    │   ├── corr.py
    │   ├── datasets.py
    │   ├── demo.py
    │   ├── extractor.py
    │   ├── raft.py
    │   ├── update.py
    │   └── utils
    │       ├── __init__.py
    │       ├── augmentor.py
    │       ├── flow_viz.py
    │       ├── frame_utils.py
    │       └── utils.py
    ├── colmap_utils.py
    ├── evaluation.py
    ├── flow_utils.py
    ├── generate_data.py
    ├── generate_depth.py
    ├── generate_flow.py
    ├── generate_motion_mask.py
    ├── generate_pose.py
    └── midas
        ├── base_model.py
        ├── blocks.py
        ├── midas_net.py
        ├── transforms.py
        └── vit.py
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2020 Virginia Tech Vision and Learning Lab
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | --------------------------- LICENSE FOR EdgeConnect --------------------------------
25 |
26 | Attribution-NonCommercial 4.0 International
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic View Synthesis from Dynamic Monocular Video
2 |
3 | [arXiv](https://arxiv.org/abs/2105.06468)
4 |
5 | [Project Website](https://free-view-video.github.io/) | [Video](https://youtu.be/j8CUzIR0f8M) | [Paper](https://arxiv.org/abs/2105.06468)
6 |
7 | > **Dynamic View Synthesis from Dynamic Monocular Video**
8 | > [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Johannes Kopf](https://johanneskopf.de/), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/)
9 | > in ICCV 2021
10 |
11 | ## Setup
12 | The code is tested with
13 | * Linux (tested on CentOS Linux release 7.4.1708)
14 | * Anaconda 3
15 | * Python 3.7.11
16 | * CUDA 10.1
17 | * 1 V100 GPU
18 |
19 |
20 | To get started, please create the conda environment `dnerf` by running
21 | ```
22 | conda create --name dnerf python=3.7
23 | conda activate dnerf
24 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
25 | pip install imageio scikit-image configargparse timm lpips
26 | ```
27 | and install [COLMAP](https://colmap.github.io/install.html) manually. Then download MiDaS and RAFT weights
28 | ```
29 | ROOT_PATH=/path/to/the/DynamicNeRF/folder
30 | cd $ROOT_PATH
31 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/weights.zip
32 | unzip weights.zip
33 | rm weights.zip
34 | ```
35 |
36 | ## Dynamic Scene Dataset
37 | The [Dynamic Scene Dataset](https://www-users.cse.umn.edu/~jsyoon/dynamic_synth/) is used to
38 | quantitatively evaluate our method. Please download the pre-processed data by running:
39 | ```
40 | cd $ROOT_PATH
41 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/data.zip
42 | unzip data.zip
43 | rm data.zip
44 | ```
45 |
46 | ### Training
47 | You can train a model from scratch by running:
48 | ```
49 | cd $ROOT_PATH/
50 | python run_nerf.py --config configs/config_Balloon2.txt
51 | ```
52 |
53 | Every 100k iterations, you should get videos like the following examples.
54 |
55 | The novel view-time synthesis results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/novelviewtime`.
56 | 
57 |
58 |
59 | The reconstruction results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset`.
60 | 
61 |
62 | The fix-view-change-time results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_view000`.
63 | 
64 |
65 | The fix-time-change-view results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_time000`.
66 | 
67 |
68 |
69 | ### Rendering from pre-trained models
70 | We also provide pre-trained models. You can download them by running:
71 | ```
72 | cd $ROOT_PATH/
73 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/logs.zip
74 | unzip logs.zip
75 | rm logs.zip
76 | ```
77 |
78 | Then you can render the results directly by running:
79 | ```
80 | python run_nerf.py --config configs/config_Balloon2.txt --render_only --ft_path $ROOT_PATH/logs/Balloon2_H270_DyNeRF_pretrain/300000.tar
81 | ```
82 |
83 | ### Evaluating our method and others
84 | Our goal is to make the evaluation as simple as possible for you. We have collected the fix-view-change-time results of the following methods:
85 |
86 | `NeRF` \
87 | `NeRF + t` \
88 | `Yoon et al.` \
89 | `Non-Rigid NeRF` \
90 | `NSFF` \
91 | `DynamicNeRF (ours)`
92 |
93 | Please download the results by running:
94 | ```
95 | cd $ROOT_PATH/
96 | wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/results.zip
97 | unzip results.zip
98 | rm results.zip
99 | ```
100 |
101 | Then you can calculate the PSNR/SSIM/LPIPS by running:
102 | ```
103 | cd $ROOT_PATH/utils
104 | python evaluation.py
105 | ```
106 |
107 | | PSNR / LPIPS | Jumping | Skating | Truck | Umbrella | Balloon1 | Balloon2 | Playground | Average |
108 | |:-------------|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
109 | | NeRF | 20.99 / 0.305 | 23.67 / 0.311 | 22.73 / 0.229 | 21.29 / 0.440 | 19.82 / 0.205 | 24.37 / 0.098 | 21.07 / 0.165 | 21.99 / 0.250 |
110 | | NeRF + t | 18.04 / 0.455 | 20.32 / 0.512 | 18.33 / 0.382 | 17.69 / 0.728 | 18.54 / 0.275 | 20.69 / 0.216 | 14.68 / 0.421 | 18.33 / 0.427 |
111 | | NR NeRF | 20.09 / 0.287 | 23.95 / 0.227 | 19.33 / 0.446 | 19.63 / 0.421 | 17.39 / 0.348 | 22.41 / 0.213 | 15.06 / 0.317 | 19.69 / 0.323 |
112 | | NSFF | 24.65 / 0.151 | 29.29 / 0.129 | 25.96 / 0.167 | 22.97 / 0.295 | 21.96 / 0.215 | 24.27 / 0.222 | 21.22 / 0.212 | 24.33 / 0.199 |
113 | | Ours | 24.68 / 0.090 | 32.66 / 0.035 | 28.56 / 0.082 | 23.26 / 0.137 | 22.36 / 0.104 | 27.06 / 0.049 | 24.15 / 0.080 | 26.10 / 0.082 |
114 |
115 |
116 | Please note:
117 | 1. The numbers reported in the paper were calculated with the original TensorFlow code; the numbers here come from this improved PyTorch version.
118 | 2. Yoon et al.'s results are missing the first and last frames, so comparisons against them must omit those two frames. To do so, please uncomment line 72 and comment line 73 in `evaluation.py`.
119 | 3. We obtain the results of NSFF and NR NeRF using the official implementations with default parameters.
120 |
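As a rough illustration of what the metrics in the table above involve, the snippet below computes PSNR/SSIM/LPIPS for a single image pair using the `scikit-image` and `lpips` packages installed during setup. It is only a sketch: the reported numbers come from `utils/evaluation.py`, and details such as the LPIPS backbone may differ there.

```
# Illustrative only -- not the actual evaluation script (see utils/evaluation.py).
import numpy as np
import torch
import lpips
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

lpips_fn = lpips.LPIPS(net='alex')  # backbone chosen here only as an example


def to_tensor(x):
    # H x W x 3 in [0, 1]  ->  1 x 3 x H x W in [-1, 1], as LPIPS expects
    return torch.from_numpy(x).permute(2, 0, 1)[None].float() * 2 - 1


def image_metrics(pred, gt):
    """pred, gt: float32 H x W x 3 arrays with values in [0, 1]."""
    psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
    ssim = structural_similarity(gt, pred, data_range=1.0, multichannel=True)
    lpips_val = lpips_fn(to_tensor(pred), to_tensor(gt)).item()
    return psnr, ssim, lpips_val
```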
121 |
122 | ## Train a model on your sequence
123 | 0. Set some paths
124 |
125 | ```
126 | ROOT_PATH=/path/to/the/DynamicNeRF/folder
127 | DATASET_NAME=name_of_the_video_without_extension
128 | DATASET_PATH=$ROOT_PATH/data/$DATASET_NAME
129 | ```
130 |
131 | 1. Prepare training images and background masks from a video.
132 |
133 | ```
134 | cd $ROOT_PATH/utils
135 | python generate_data.py --videopath /path/to/the/video
136 | ```
137 |
138 | 2. Use COLMAP to obtain camera poses.
139 |
140 | ```
141 | colmap feature_extractor \
142 | --database_path $DATASET_PATH/database.db \
143 | --image_path $DATASET_PATH/images_colmap \
144 | --ImageReader.mask_path $DATASET_PATH/background_mask \
145 | --ImageReader.single_camera 1
146 |
147 | colmap exhaustive_matcher \
148 | --database_path $DATASET_PATH/database.db
149 |
150 | mkdir $DATASET_PATH/sparse
151 | colmap mapper \
152 | --database_path $DATASET_PATH/database.db \
153 | --image_path $DATASET_PATH/images_colmap \
154 | --output_path $DATASET_PATH/sparse \
155 | --Mapper.num_threads 16 \
156 | --Mapper.init_min_tri_angle 4 \
157 | --Mapper.multiple_models 0 \
158 | --Mapper.extract_colors 0
159 | ```
160 |
161 | 3. Save camera poses into the format that NeRF reads.
162 |
163 | ```
164 | cd $ROOT_PATH/utils
165 | python generate_pose.py --dataset_path $DATASET_PATH
166 | ```
167 |
168 | 4. Estimate monocular depth.
169 |
170 | ```
171 | cd $ROOT_PATH/utils
172 | python generate_depth.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/midas_v21-f6b98070.pt
173 | ```
174 |
175 | 5. Predict optical flows.
176 |
177 | ```
178 | cd $ROOT_PATH/utils
179 | python generate_flow.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/raft-things.pth
180 | ```
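
Based on how `load_llff.py` reads them during training, each file in `$DATASET_PATH/flow` is an `.npz` archive named like `000_fwd.npz` / `001_bwd.npz` containing a two-channel `flow` array and a `mask` array. A quick sanity check (the path below is a placeholder):

```
# Illustrative check of one generated flow file; adjust the path to your dataset.
import numpy as np

data = np.load('/path/to/data/DATASET_NAME/flow/000_fwd.npz')
print(data['flow'].shape)  # per-pixel (u, v) flow, H x W x 2
print(data['mask'].shape)  # per-pixel validity mask, H x W
```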
181 |
182 | 6. Obtain motion mask (code adapted from NSFF).
183 |
184 | ```
185 | cd $ROOT_PATH/utils
186 | python generate_motion_mask.py --dataset_path $DATASET_PATH
187 | ```
188 |
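At this point the dataset folder should contain everything `load_llff.py` expects for training: `images/`, `disp/`, `flow/`, `motion_masks/`, and `poses_bounds.npy` (the COLMAP `database.db`, `images_colmap/`, and `sparse/` folders are only intermediate). A small check along these lines can catch a missed step; it is a sketch, not part of the repo:

```
# Hypothetical pre-training sanity check; it only relies on what load_llff.py reads.
import os
import numpy as np

dataset_path = '/path/to/data/DATASET_NAME'  # same folder as $DATASET_PATH above

for sub in ['images', 'disp', 'flow', 'motion_masks']:
    num_files = len(os.listdir(os.path.join(dataset_path, sub)))
    print('{}: {} files'.format(sub, num_files))

poses_bounds = np.load(os.path.join(dataset_path, 'poses_bounds.npy'))
# One row per frame: a flattened 3x5 pose/intrinsics matrix plus near/far bounds.
print('poses_bounds:', poses_bounds.shape)  # expected (num_frames, 17)
```
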
189 | 7. Train a model. Please change `expname` and `datadir` in `configs/config.txt`.
190 |
191 | ```
192 | cd $ROOT_PATH/
193 | python run_nerf.py --config configs/config.txt
194 | ```
195 |
196 | Explanation of each parameter (see the parser sketch after this list):
197 |
198 | - `expname`: experiment name
199 | - `basedir`: where to store ckpts and logs
200 | - `datadir`: input data directory
201 | - `factor`: downsample factor for the input images
202 | - `N_rand`: number of random rays per gradient step
203 | - `N_samples`: number of samples per ray
204 | - `netwidth`: channels per layer
205 | - `use_viewdirs`: whether to enable view dependence for StaticNeRF
206 | - `use_viewdirsDyn`: whether to enable view dependence for DynamicNeRF
207 | - `raw_noise_std`: std dev of noise added to regularize sigma_a output
208 | - `no_ndc`: do not use normalized device coordinates
209 | - `lindisp`: sampling linearly in disparity rather than depth
210 | - `i_video`: frequency of novel view-time synthesis video saving
211 | - `i_testset`: frequency of testset video saving
212 | - `N_iters`: number of training iterations
213 | - `i_img`: frequency of tensorboard image logging
214 | - `DyNeRF_blending`: whether to use DynamicNeRF to predict the blending weight
215 | - `pretrain`: whether to pre-train StaticNeRF
216 |
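These options follow the `configargparse` convention used by the NeRF codebases this repo builds on: every `key = value` line in a config file corresponds to a command-line flag of the same name, so any value can also be overridden on the command line (e.g. `--N_rand 512`). Below is a rough sketch of how such a parser is set up; the real argument definitions live in `run_nerf.py` and may differ in detail.

```
# Rough configargparse sketch -- the actual flags and defaults are defined in run_nerf.py.
import configargparse

parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument('--expname', type=str, help='experiment name')
parser.add_argument('--basedir', type=str, default='./logs', help='where to store ckpts and logs')
parser.add_argument('--datadir', type=str, help='input data directory')
parser.add_argument('--N_rand', type=int, default=1024, help='number of random rays per gradient step')
args = parser.parse_args()
# `python run_nerf.py --config configs/config.txt --N_rand 512` takes every value
# from the config file except N_rand, which is overridden on the command line.
```
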
217 | ## License
218 | This work is licensed under the MIT License. See [LICENSE](LICENSE) for details.
219 |
220 | If you find this code useful for your research, please consider citing the following paper:
221 |
222 | @inproceedings{Gao-ICCV-DynNeRF,
223 | author = {Gao, Chen and Saraf, Ayush and Kopf, Johannes and Huang, Jia-Bin},
224 | title = {Dynamic View Synthesis from Dynamic Monocular Video},
225 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision},
226 | year = {2021}
227 | }
228 |
229 | ## Acknowledgments
230 | Our training code is built upon
231 | [NeRF](https://github.com/bmild/nerf),
232 | [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch), and
233 | [NSFF](https://github.com/zl548/Neural-Scene-Flow-Fields).
234 | Our flow prediction code is modified from [RAFT](https://github.com/princeton-vl/RAFT).
235 | Our depth prediction code is modified from [MiDaS](https://github.com/isl-org/MiDaS).
236 |
--------------------------------------------------------------------------------
/configs/config.txt:
--------------------------------------------------------------------------------
1 | expname = xxxxxx_DyNeRF_pretrain_test
2 | basedir = ./logs
3 | datadir = ./data/xxxxxx/
4 |
5 | dataset_type = llff
6 |
7 | factor = 4
8 | N_rand = 1024
9 | N_samples = 64
10 | netwidth = 256
11 |
12 | i_video = 100000
13 | i_testset = 100000
14 | N_iters = 500001
15 | i_img = 500
16 |
17 | use_viewdirs = True
18 | use_viewdirsDyn = True
19 | raw_noise_std = 1e0
20 | no_ndc = False
21 | lindisp = False
22 |
23 | dynamic_loss_lambda = 1.0
24 | static_loss_lambda = 1.0
25 | full_loss_lambda = 3.0
26 | depth_loss_lambda = 0.04
27 | order_loss_lambda = 0.1
28 | flow_loss_lambda = 0.02
29 | slow_loss_lambda = 0.01
30 | smooth_loss_lambda = 0.1
31 | consistency_loss_lambda = 1.0
32 | mask_loss_lambda = 0.01
33 | sparse_loss_lambda = 0.001
34 | DyNeRF_blending = True
35 | pretrain = True
36 |
--------------------------------------------------------------------------------
/configs/config_Balloon1.txt:
--------------------------------------------------------------------------------
1 | expname = Balloon1_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Balloon1/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = False
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Balloon2.txt:
--------------------------------------------------------------------------------
1 | expname = Balloon2_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Balloon2/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Jumping.txt:
--------------------------------------------------------------------------------
1 | expname = Jumping_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Jumping/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = False
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Playground.txt:
--------------------------------------------------------------------------------
1 | expname = Playground_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Playground/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Skating.txt:
--------------------------------------------------------------------------------
1 | expname = Skating_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Skating/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Truck.txt:
--------------------------------------------------------------------------------
1 | expname = Truck_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Truck/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/configs/config_Umbrella.txt:
--------------------------------------------------------------------------------
1 | expname = Umbrella_H270_DyNeRF_pretrain
2 | basedir = ./logs
3 | datadir = ./data/Umbrella/
4 |
5 | dataset_type = llff
6 |
7 | factor = 2
8 | N_rand = 1024
9 | N_samples = 64
10 | N_importance = 0
11 | netwidth = 256
12 |
13 | i_video = 100000
14 | i_testset = 100000
15 | N_iters = 300001
16 | i_img = 500
17 |
18 | use_viewdirs = True
19 | use_viewdirsDyn = True
20 | raw_noise_std = 1e0
21 | no_ndc = False
22 | lindisp = False
23 |
24 | dynamic_loss_lambda = 1.0
25 | static_loss_lambda = 1.0
26 | full_loss_lambda = 3.0
27 | depth_loss_lambda = 0.04
28 | order_loss_lambda = 0.1
29 | flow_loss_lambda = 0.02
30 | slow_loss_lambda = 0.01
31 | smooth_loss_lambda = 0.1
32 | consistency_loss_lambda = 1.0
33 | mask_loss_lambda = 0.1
34 | sparse_loss_lambda = 0.001
35 | DyNeRF_blending = True
36 | pretrain = True
37 |
--------------------------------------------------------------------------------
/load_llff.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import imageio
4 | import numpy as np
5 |
6 | from utils.flow_utils import resize_flow
7 | from run_nerf_helpers import get_grid
8 |
9 |
10 | def _minify(basedir, factors=[], resolutions=[]):
11 | needtoload = False
12 | for r in factors:
13 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
14 | if not os.path.exists(imgdir):
15 | needtoload = True
16 | for r in resolutions:
17 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
18 | if not os.path.exists(imgdir):
19 | needtoload = True
20 | if not needtoload:
21 | return
22 |
23 | from shutil import copy
24 | from subprocess import check_output
25 |
26 | imgdir = os.path.join(basedir, 'images')
27 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
28 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
29 | imgdir_orig = imgdir
30 |
31 | wd = os.getcwd()
32 |
33 | for r in factors + resolutions:
34 | if isinstance(r, int):
35 | name = 'images_{}'.format(r)
36 | resizearg = '{}%'.format(100./r)
37 | else:
38 | name = 'images_{}x{}'.format(r[1], r[0])
39 | resizearg = '{}x{}'.format(r[1], r[0])
40 | imgdir = os.path.join(basedir, name)
41 | if os.path.exists(imgdir):
42 | continue
43 |
44 | print('Minifying', r, basedir)
45 |
46 | os.makedirs(imgdir)
47 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
48 |
49 | ext = imgs[0].split('.')[-1]
50 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
51 | print(args)
52 | os.chdir(imgdir)
53 | check_output(args, shell=True)
54 | os.chdir(wd)
55 |
56 | if ext != 'png':
57 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
58 | print('Removed duplicates')
59 | print('Done')
60 |
61 |
62 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
63 | print('factor ', factor)
64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
66 | bds = poses_arr[:, -2:].transpose([1,0])
67 |
68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
70 | sh = imageio.imread(img0).shape
71 |
72 | sfx = ''
73 |
74 | if factor is not None:
75 | sfx = '_{}'.format(factor)
76 | _minify(basedir, factors=[factor])
77 | factor = factor
78 | elif height is not None:
79 | factor = sh[0] / float(height)
80 | width = int(sh[1] / factor)
81 | if width % 2 == 1:
82 | width -= 1
83 | _minify(basedir, resolutions=[[height, width]])
84 | sfx = '_{}x{}'.format(width, height)
85 | elif width is not None:
86 | factor = sh[1] / float(width)
87 | height = int(sh[0] / factor)
88 | if height % 2 == 1:
89 | height -= 1
90 | _minify(basedir, resolutions=[[height, width]])
91 | sfx = '_{}x{}'.format(width, height)
92 | else:
93 | factor = 1
94 |
95 | imgdir = os.path.join(basedir, 'images' + sfx)
96 | if not os.path.exists(imgdir):
97 | print( imgdir, 'does not exist, returning' )
98 | return
99 |
100 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) \
101 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
102 | if poses.shape[-1] != len(imgfiles):
103 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
104 | return
105 |
106 | sh = imageio.imread(imgfiles[0]).shape
107 | num_img = len(imgfiles)
108 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
109 | poses[2, 4, :] = poses[2, 4, :] * 1./factor
110 |
111 | if not load_imgs:
112 | return poses, bds
113 |
114 | def imread(f):
115 | if f.endswith('png'):
116 | return imageio.imread(f, ignoregamma=True)
117 | else:
118 | return imageio.imread(f)
119 |
120 | imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
121 | imgs = np.stack(imgs, -1)
122 |
123 | assert imgs.shape[0] == sh[0]
124 | assert imgs.shape[1] == sh[1]
125 |
126 | disp_dir = os.path.join(basedir, 'disp')
127 |
128 | dispfiles = [os.path.join(disp_dir, f) \
129 | for f in sorted(os.listdir(disp_dir)) if f.endswith('npy')]
130 |
131 | disp = [cv2.resize(np.load(f),
132 | (sh[1], sh[0]),
133 | interpolation=cv2.INTER_NEAREST) for f in dispfiles]
134 | disp = np.stack(disp, -1)
135 |
136 | mask_dir = os.path.join(basedir, 'motion_masks')
137 | maskfiles = [os.path.join(mask_dir, f) \
138 | for f in sorted(os.listdir(mask_dir)) if f.endswith('png')]
139 |
140 | masks = [cv2.resize(imread(f)/255., (sh[1], sh[0]),
141 | interpolation=cv2.INTER_NEAREST) for f in maskfiles]
142 | masks = np.stack(masks, -1)
143 | masks = np.float32(masks > 1e-3)
144 |
145 | flow_dir = os.path.join(basedir, 'flow')
146 | flows_f = []
147 | flow_masks_f = []
148 | flows_b = []
149 | flow_masks_b = []
150 | for i in range(num_img):
151 | if i == num_img - 1:
152 | fwd_flow, fwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
153 | else:
154 | fwd_flow_path = os.path.join(flow_dir, '%03d_fwd.npz'%i)
155 | fwd_data = np.load(fwd_flow_path)
156 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
157 | fwd_flow = resize_flow(fwd_flow, sh[0], sh[1])
158 | fwd_mask = np.float32(fwd_mask)
159 | fwd_mask = cv2.resize(fwd_mask, (sh[1], sh[0]),
160 | interpolation=cv2.INTER_NEAREST)
161 | flows_f.append(fwd_flow)
162 | flow_masks_f.append(fwd_mask)
163 |
164 | if i == 0:
165 | bwd_flow, bwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
166 | else:
167 | bwd_flow_path = os.path.join(flow_dir, '%03d_bwd.npz'%i)
168 | bwd_data = np.load(bwd_flow_path)
169 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
170 | bwd_flow = resize_flow(bwd_flow, sh[0], sh[1])
171 | bwd_mask = np.float32(bwd_mask)
172 | bwd_mask = cv2.resize(bwd_mask, (sh[1], sh[0]),
173 | interpolation=cv2.INTER_NEAREST)
174 | flows_b.append(bwd_flow)
175 | flow_masks_b.append(bwd_mask)
176 |
177 | flows_f = np.stack(flows_f, -1)
178 | flow_masks_f = np.stack(flow_masks_f, -1)
179 | flows_b = np.stack(flows_b, -1)
180 | flow_masks_b = np.stack(flow_masks_b, -1)
181 |
182 | print(imgs.shape)
183 | print(disp.shape)
184 | print(masks.shape)
185 | print(flows_f.shape)
186 | print(flow_masks_f.shape)
187 |
188 | assert(imgs.shape[0] == disp.shape[0])
189 | assert(imgs.shape[0] == masks.shape[0])
190 | assert(imgs.shape[0] == flows_f.shape[0])
191 | assert(imgs.shape[0] == flow_masks_f.shape[0])
192 |
193 | assert(imgs.shape[1] == disp.shape[1])
194 | assert(imgs.shape[1] == masks.shape[1])
195 |
196 | return poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b
197 |
198 |
199 | def normalize(x):
200 | return x / np.linalg.norm(x)
201 |
202 | def viewmatrix(z, up, pos):
203 | vec2 = normalize(z)
204 | vec1_avg = up
205 | vec0 = normalize(np.cross(vec1_avg, vec2))
206 | vec1 = normalize(np.cross(vec2, vec0))
207 | m = np.stack([vec0, vec1, vec2, pos], 1)
208 | return m
209 |
210 |
211 | def poses_avg(poses):
212 |
213 | hwf = poses[0, :3, -1:]
214 |
215 | center = poses[:, :3, 3].mean(0)
216 | vec2 = normalize(poses[:, :3, 2].sum(0))
217 | up = poses[:, :3, 1].sum(0)
218 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
219 |
220 | return c2w
221 |
222 |
223 |
224 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
225 | render_poses = []
226 | rads = np.array(list(rads) + [1.])
227 | hwf = c2w[:,4:5]
228 |
229 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
230 | c = np.dot(c2w[:3, :4],
231 | np.array([np.cos(theta),
232 | -np.sin(theta),
233 | -np.sin(theta*zrate),
234 | 1.]) * rads)
235 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
236 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
237 | return render_poses
238 |
239 |
240 |
241 | def recenter_poses(poses):
242 |
243 | poses_ = poses+0
244 | bottom = np.reshape([0,0,0,1.], [1,4])
245 | c2w = poses_avg(poses)
246 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
247 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
248 | poses = np.concatenate([poses[:,:3,:4], bottom], -2)
249 |
250 | poses = np.linalg.inv(c2w) @ poses
251 | poses_[:,:3,:4] = poses[:,:3,:4]
252 | poses = poses_
253 | return poses
254 |
255 |
256 | def spherify_poses(poses, bds):
257 |
258 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
259 |
260 | rays_d = poses[:,:3,2:3]
261 | rays_o = poses[:,:3,3:4]
262 |
263 | def min_line_dist(rays_o, rays_d):
264 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
265 | b_i = -A_i @ rays_o
266 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
267 | return pt_mindist
268 |
269 | pt_mindist = min_line_dist(rays_o, rays_d)
270 |
271 | center = pt_mindist
272 | up = (poses[:,:3,3] - center).mean(0)
273 |
274 | vec0 = normalize(up)
275 | vec1 = normalize(np.cross([.1,.2,.3], vec0))
276 | vec2 = normalize(np.cross(vec0, vec1))
277 | pos = center
278 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
279 |
280 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
281 |
282 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
283 |
284 | sc = 1./rad
285 | poses_reset[:,:3,3] *= sc
286 | bds *= sc
287 | rad *= sc
288 |
289 | centroid = np.mean(poses_reset[:,:3,3], 0)
290 | zh = centroid[2]
291 | radcircle = np.sqrt(rad**2-zh**2)
292 | new_poses = []
293 |
294 | for th in np.linspace(0.,2.*np.pi, 120):
295 |
296 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
297 | up = np.array([0,0,-1.])
298 |
299 | vec2 = normalize(camorigin)
300 | vec0 = normalize(np.cross(vec2, up))
301 | vec1 = normalize(np.cross(vec2, vec0))
302 | pos = camorigin
303 | p = np.stack([vec0, vec1, vec2, pos], 1)
304 |
305 | new_poses.append(p)
306 |
307 | new_poses = np.stack(new_poses, 0)
308 |
309 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
310 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
311 |
312 | return poses_reset, new_poses, bds
313 |
314 |
315 | def load_llff_data(args, basedir,
316 | factor=2,
317 | recenter=True, bd_factor=.75,
318 | spherify=False, path_zflat=False,
319 | frame2dolly=10):
320 |
321 | poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b = \
322 | _load_data(basedir, factor=factor) # factor=2 downsamples original imgs by 2x
323 |
324 | print('Loaded', basedir, bds.min(), bds.max())
325 |
326 | # Correct rotation matrix ordering and move variable dim to axis 0
327 | poses = np.concatenate([poses[:, 1:2, :],
328 | -poses[:, 0:1, :],
329 | poses[:, 2:, :]], 1)
330 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
331 | images = np.moveaxis(imgs, -1, 0).astype(np.float32)
332 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
333 | disp = np.moveaxis(disp, -1, 0).astype(np.float32)
334 | masks = np.moveaxis(masks, -1, 0).astype(np.float32)
335 | flows_f = np.moveaxis(flows_f, -1, 0).astype(np.float32)
336 | flow_masks_f = np.moveaxis(flow_masks_f, -1, 0).astype(np.float32)
337 | flows_b = np.moveaxis(flows_b, -1, 0).astype(np.float32)
338 | flow_masks_b = np.moveaxis(flow_masks_b, -1, 0).astype(np.float32)
339 |
340 | # Rescale if bd_factor is provided
341 | sc = 1. if bd_factor is None else 1./(np.percentile(bds[:, 0], 5) * bd_factor)
342 |
343 | poses[:, :3, 3] *= sc
344 | bds *= sc
345 |
346 | if recenter:
347 | poses = recenter_poses(poses)
348 |
349 | # Only for rendering
350 | if frame2dolly == -1:
351 | c2w = poses_avg(poses)
352 | else:
353 | c2w = poses[frame2dolly, :, :]
354 |
355 | H, W, _ = c2w[:, -1]
356 |
357 | # Generate poses for novel views
358 | render_poses, render_focals = generate_path(c2w, args)
359 | render_poses = np.array(render_poses).astype(np.float32)
360 |
361 | grids = get_grid(int(H), int(W), len(poses), flows_f, flow_masks_f, flows_b, flow_masks_b) # [N, H, W, 8]
362 |
363 | return images, disp, masks, poses, bds,\
364 | render_poses, render_focals, grids
365 |
366 |
367 | def generate_path(c2w, args):
368 | hwf = c2w[:, 4:5]
369 | num_novelviews = args.num_novelviews
370 | max_disp = 48.0
371 | H, W, focal = hwf[:, 0]
372 |
373 | max_trans = max_disp / focal
374 | output_poses = []
375 | output_focals = []
376 |
377 | # Rendering teaser. Add translation.
378 | for i in range(num_novelviews):
379 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.x_trans_multiplier
380 | y_trans = max_trans * (np.cos(2.0 * np.pi * float(i) / float(num_novelviews)) - 1.) * args.y_trans_multiplier
381 | z_trans = 0.
382 |
383 | i_pose = np.concatenate([
384 | np.concatenate(
385 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
386 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
387 | ],axis=0)
388 |
389 | i_pose = np.linalg.inv(i_pose)
390 |
391 | ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
392 |
393 | render_pose = np.dot(ref_pose, i_pose)
394 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
395 | output_focals.append(focal)
396 |
397 | # Rendering teaser. Add zooming.
398 | if args.frame2dolly != -1:
399 | for i in range(num_novelviews // 2 + 1):
400 | x_trans = 0.
401 | y_trans = 0.
402 | # z_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.z_trans_multiplier
403 | z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
404 | i_pose = np.concatenate([
405 | np.concatenate(
406 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
407 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
408 | ],axis=0)
409 |
410 | i_pose = np.linalg.inv(i_pose) #torch.tensor(np.linalg.inv(i_pose)).float()
411 |
412 | ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
413 |
414 | render_pose = np.dot(ref_pose, i_pose)
415 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
416 | output_focals.append(focal)
417 | print(z_trans / max_trans / args.z_trans_multiplier)
418 |
419 | # Rendering teaser. Add dolly zoom.
420 | if args.frame2dolly != -1:
421 | for i in range(num_novelviews // 2 + 1):
422 | x_trans = 0.
423 | y_trans = 0.
424 | z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
425 | i_pose = np.concatenate([
426 | np.concatenate(
427 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
428 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
429 | ],axis=0)
430 |
431 | i_pose = np.linalg.inv(i_pose)
432 |
433 | ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
434 |
435 | render_pose = np.dot(ref_pose, i_pose)
436 | output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
437 | new_focal = focal - args.focal_decrease * z_trans / max_trans / args.z_trans_multiplier
438 | output_focals.append(new_focal)
439 | print(z_trans / max_trans / args.z_trans_multiplier, new_focal)
440 |
441 | return output_poses, output_focals
442 |
--------------------------------------------------------------------------------
/run_nerf_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import imageio
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | # Misc utils
12 | def img2mse(x, y, M=None):
13 |     if M is None:
14 | return torch.mean((x - y) ** 2)
15 | else:
16 | return torch.sum((x - y) ** 2 * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
17 |
18 |
19 | def img2mae(x, y, M=None):
20 |     if M is None:
21 | return torch.mean(torch.abs(x - y))
22 | else:
23 | return torch.sum(torch.abs(x - y) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
24 |
25 |
26 | def L1(x, M=None):
27 |     if M is None:
28 | return torch.mean(torch.abs(x))
29 | else:
30 | return torch.sum(torch.abs(x) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
31 |
32 |
33 | def L2(x, M=None):
34 |     if M is None:
35 | return torch.mean(x ** 2)
36 | else:
37 | return torch.sum((x ** 2) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
38 |
39 |
40 | def entropy(x):
41 | return -torch.sum(x * torch.log(x + 1e-19)) / x.shape[0]
42 |
43 |
44 | def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
45 |
46 |
47 | def to8b(x): return (255 * np.clip(x, 0, 1)).astype(np.uint8)
48 |
49 |
50 | class Embedder:
51 |
52 | def __init__(self, **kwargs):
53 |
54 | self.kwargs = kwargs
55 | self.create_embedding_fn()
56 |
57 | def create_embedding_fn(self):
58 |
59 | embed_fns = []
60 | d = self.kwargs['input_dims']
61 | out_dim = 0
62 | if self.kwargs['include_input']:
63 | embed_fns.append(lambda x: x)
64 | out_dim += d
65 |
66 | max_freq = self.kwargs['max_freq_log2']
67 | N_freqs = self.kwargs['num_freqs']
68 |
69 | if self.kwargs['log_sampling']:
70 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
71 | else:
72 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
73 |
74 | for freq in freq_bands:
75 | for p_fn in self.kwargs['periodic_fns']:
76 | embed_fns.append(lambda x, p_fn=p_fn,
77 | freq=freq : p_fn(x * freq))
78 | out_dim += d
79 |
80 | self.embed_fns = embed_fns
81 | self.out_dim = out_dim
82 |
83 | def embed(self, inputs):
84 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
85 |
86 |
87 | def get_embedder(multires, i=0, input_dims=3):
88 |
89 | if i == -1:
90 | return nn.Identity(), 3
91 |
92 | embed_kwargs = {
93 | 'include_input': True,
94 | 'input_dims': input_dims,
95 | 'max_freq_log2': multires-1,
96 | 'num_freqs': multires,
97 | 'log_sampling': True,
98 | 'periodic_fns': [torch.sin, torch.cos],
99 | }
100 |
101 | embedder_obj = Embedder(**embed_kwargs)
102 | def embed(x, eo=embedder_obj): return eo.embed(x)
103 | return embed, embedder_obj.out_dim
104 |
105 |
106 | # Dynamic NeRF model architecture
107 | class NeRF_d(nn.Module):
108 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirsDyn=True):
109 |         """Dynamic NeRF MLP: predicts RGB, density, forward/backward 3D scene
110 |         flow, and a blending weight for each embedded (x, y, z, t) input."""
111 | super(NeRF_d, self).__init__()
112 | self.D = D
113 | self.W = W
114 | self.input_ch = input_ch
115 | self.input_ch_views = input_ch_views
116 | self.skips = skips
117 | self.use_viewdirsDyn = use_viewdirsDyn
118 |
119 | self.pts_linears = nn.ModuleList(
120 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
121 |
122 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
123 |
124 | if self.use_viewdirsDyn:
125 | self.feature_linear = nn.Linear(W, W)
126 | self.alpha_linear = nn.Linear(W, 1)
127 | self.rgb_linear = nn.Linear(W//2, 3)
128 | else:
129 | self.output_linear = nn.Linear(W, output_ch)
130 |
131 | self.sf_linear = nn.Linear(W, 6)
132 | self.weight_linear = nn.Linear(W, 1)
133 |
134 | def forward(self, x):
135 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
136 | h = input_pts
137 | for i, l in enumerate(self.pts_linears):
138 | h = self.pts_linears[i](h)
139 | h = F.relu(h)
140 | if i in self.skips:
141 | h = torch.cat([input_pts, h], -1)
142 |
143 | # Scene flow should be unbounded. However, in NDC space the coordinate is
144 | # bounded in [-1, 1].
145 | sf = torch.tanh(self.sf_linear(h))
146 | blending = torch.sigmoid(self.weight_linear(h))
147 |
148 | if self.use_viewdirsDyn:
149 | alpha = self.alpha_linear(h)
150 | feature = self.feature_linear(h)
151 | h = torch.cat([feature, input_views], -1)
152 |
153 | for i, l in enumerate(self.views_linears):
154 | h = self.views_linears[i](h)
155 | h = F.relu(h)
156 |
157 | rgb = self.rgb_linear(h)
158 | outputs = torch.cat([rgb, alpha], -1)
159 | else:
160 | outputs = self.output_linear(h)
161 |
162 | return torch.cat([outputs, sf, blending], dim=-1)
163 |
164 |
165 | # Static NeRF model architecture
166 | class NeRF_s(nn.Module):
167 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=True):
168 | """
169 | """
170 | super(NeRF_s, self).__init__()
171 | self.D = D
172 | self.W = W
173 | self.input_ch = input_ch
174 | self.input_ch_views = input_ch_views
175 | self.skips = skips
176 | self.use_viewdirs = use_viewdirs
177 |
178 | self.pts_linears = nn.ModuleList(
179 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
180 |
181 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
182 |
183 | if self.use_viewdirs:
184 | self.feature_linear = nn.Linear(W, W)
185 | self.alpha_linear = nn.Linear(W, 1)
186 | self.rgb_linear = nn.Linear(W//2, 3)
187 | else:
188 | self.output_linear = nn.Linear(W, output_ch)
189 |
190 | self.weight_linear = nn.Linear(W, 1)
191 |
192 | def forward(self, x):
193 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
194 | h = input_pts
195 | for i, l in enumerate(self.pts_linears):
196 | h = self.pts_linears[i](h)
197 | h = F.relu(h)
198 | if i in self.skips:
199 | h = torch.cat([input_pts, h], -1)
200 |
201 | blending = torch.sigmoid(self.weight_linear(h))
202 | if self.use_viewdirs:
203 | alpha = self.alpha_linear(h)
204 | feature = self.feature_linear(h)
205 | h = torch.cat([feature, input_views], -1)
206 |
207 | for i, l in enumerate(self.views_linears):
208 | h = self.views_linears[i](h)
209 | h = F.relu(h)
210 |
211 | rgb = self.rgb_linear(h)
212 | outputs = torch.cat([rgb, alpha], -1)
213 | else:
214 | outputs = self.output_linear(h)
215 |
216 | return torch.cat([outputs, blending], -1)
217 |
218 |
219 | def batchify(fn, chunk):
220 | """Constructs a version of 'fn' that applies to smaller batches.
221 | """
222 | if chunk is None:
223 | return fn
224 |
225 | def ret(inputs):
226 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
227 | return ret
228 |
229 |
230 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
231 | """Prepares inputs and applies network 'fn'.
232 | """
233 |
234 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
235 |
236 | embedded = embed_fn(inputs_flat)
237 | if viewdirs is not None:
238 | input_dirs = viewdirs[:, None].expand(inputs[:, :, :3].shape)
239 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
240 | embedded_dirs = embeddirs_fn(input_dirs_flat)
241 | embedded = torch.cat([embedded, embedded_dirs], -1)
242 |
243 | outputs_flat = batchify(fn, netchunk)(embedded)
244 | outputs = torch.reshape(outputs_flat, list(
245 | inputs.shape[:-1]) + [outputs_flat.shape[-1]])
246 | return outputs
247 |
248 |
249 | def create_nerf(args):
250 | """Instantiate NeRF's MLP model.
251 | """
252 |
253 | embed_fn_d, input_ch_d = get_embedder(args.multires, args.i_embed, 4)
254 | # 10 * 2 * 4 + 4 = 84
255 | # L * (sin, cos) * (x, y, z, t) + (x, y, z, t)
256 |
257 | input_ch_views = 0
258 | embeddirs_fn = None
259 | if args.use_viewdirs:
260 | embeddirs_fn, input_ch_views = get_embedder(
261 | args.multires_views, args.i_embed, 3)
262 | # 4 * 2 * 3 + 3 = 27
263 | # L * (sin, cos) * (3 Cartesian viewing direction unit vector from [theta, phi]) + (3 Cartesian viewing direction unit vector from [theta, phi])
264 | output_ch = 5 if args.N_importance > 0 else 4
265 | skips = [4]
266 | model_d = NeRF_d(D=args.netdepth, W=args.netwidth,
267 | input_ch=input_ch_d, output_ch=output_ch, skips=skips,
268 | input_ch_views=input_ch_views,
269 | use_viewdirsDyn=args.use_viewdirsDyn).to(device)
270 |
271 | device_ids = list(range(torch.cuda.device_count()))
272 | model_d = torch.nn.DataParallel(model_d, device_ids=device_ids)
273 | grad_vars = list(model_d.parameters())
274 |
275 | embed_fn_s, input_ch_s = get_embedder(args.multires, args.i_embed, 3)
276 | # 10 * 2 * 3 + 3 = 63
277 | # L * (sin, cos) * (x, y, z) + (x, y, z)
278 |
279 | model_s = NeRF_s(D=args.netdepth, W=args.netwidth,
280 | input_ch=input_ch_s, output_ch=output_ch, skips=skips,
281 | input_ch_views=input_ch_views,
282 | use_viewdirs=args.use_viewdirs).to(device)
283 |
284 | model_s = torch.nn.DataParallel(model_s, device_ids=device_ids)
285 | grad_vars += list(model_s.parameters())
286 |
287 | model_fine = None
288 | if args.N_importance > 0:
289 | raise NotImplementedError
290 |
291 | def network_query_fn_d(inputs, viewdirs, network_fn): return run_network(
292 | inputs, viewdirs, network_fn,
293 | embed_fn=embed_fn_d,
294 | embeddirs_fn=embeddirs_fn,
295 | netchunk=args.netchunk)
296 |
297 | def network_query_fn_s(inputs, viewdirs, network_fn): return run_network(
298 | inputs, viewdirs, network_fn,
299 | embed_fn=embed_fn_s,
300 | embeddirs_fn=embeddirs_fn,
301 | netchunk=args.netchunk)
302 |
303 | render_kwargs_train = {
304 | 'network_query_fn_d': network_query_fn_d,
305 | 'network_query_fn_s': network_query_fn_s,
306 | 'network_fn_d': model_d,
307 | 'network_fn_s': model_s,
308 | 'perturb': args.perturb,
309 | 'N_importance': args.N_importance,
310 | 'N_samples': args.N_samples,
311 | 'use_viewdirs': args.use_viewdirs,
312 | 'raw_noise_std': args.raw_noise_std,
313 | 'inference': False,
314 | 'DyNeRF_blending': args.DyNeRF_blending,
315 | }
316 |
317 | # NDC only good for LLFF-style forward facing data
318 | if args.dataset_type != 'llff' or args.no_ndc:
319 | print('Not ndc!')
320 | render_kwargs_train['ndc'] = False
321 | render_kwargs_train['lindisp'] = args.lindisp
322 | else:
323 | render_kwargs_train['ndc'] = True
324 |
325 | render_kwargs_test = {
326 | k: render_kwargs_train[k] for k in render_kwargs_train}
327 | render_kwargs_test['perturb'] = False
328 | render_kwargs_test['raw_noise_std'] = 0.
329 | render_kwargs_test['inference'] = True
330 |
331 | # Create optimizer
332 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
333 |
334 | start = 0
335 | basedir = args.basedir
336 | expname = args.expname
337 |
338 | if args.ft_path is not None and args.ft_path != 'None':
339 | ckpts = [args.ft_path]
340 | else:
341 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
342 | print('Found ckpts', ckpts)
343 | if len(ckpts) > 0 and not args.no_reload:
344 | ckpt_path = ckpts[-1]
345 | print('Reloading from', ckpt_path)
346 | ckpt = torch.load(ckpt_path)
347 |
348 | start = ckpt['global_step'] + 1
349 | # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
350 | model_d.load_state_dict(ckpt['network_fn_d_state_dict'])
351 | model_s.load_state_dict(ckpt['network_fn_s_state_dict'])
352 | print('Resetting step to', start)
353 |
354 | if model_fine is not None:
355 | raise NotImplementedError
356 |
357 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
358 |
359 |
360 | # Ray helpers
361 | def get_rays(H, W, focal, c2w):
362 | """Get ray origins, directions from a pinhole camera."""
363 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
364 | i = i.t()
365 | j = j.t()
366 | dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
367 | # Rotate ray directions from camera frame to the world frame
368 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
369 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
370 | rays_o = c2w[:3, -1].expand(rays_d.shape)
371 | return rays_o, rays_d
372 |
373 |
374 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
375 | """Normalized device coordinate rays.
376 | Space such that the canvas is a cube with sides [-1, 1] in each axis.
377 | Args:
378 | H: int. Height in pixels.
379 | W: int. Width in pixels.
380 | focal: float. Focal length of pinhole camera.
381 | near: float or array of shape[batch_size]. Near depth bound for the scene.
382 | rays_o: array of shape [batch_size, 3]. Camera origin.
383 | rays_d: array of shape [batch_size, 3]. Ray direction.
384 | Returns:
385 | rays_o: array of shape [batch_size, 3]. Camera origin in NDC.
386 | rays_d: array of shape [batch_size, 3]. Ray direction in NDC.
387 | """
388 | # Shift ray origins to near plane
389 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
390 | rays_o = rays_o + t[..., None] * rays_d
391 |
392 | # Projection
393 | o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]
394 | o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]
395 | o2 = 1. + 2. * near / rays_o[..., 2]
396 |
397 | d0 = -1./(W/(2.*focal)) * \
398 | (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
399 | d1 = -1./(H/(2.*focal)) * \
400 | (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
401 | d2 = -2. * near / rays_o[..., 2]
402 |
403 | rays_o = torch.stack([o0, o1, o2], -1)
404 | rays_d = torch.stack([d0, d1, d2], -1)
405 |
406 | return rays_o, rays_d
407 |
408 |
409 | def get_grid(H, W, num_img, flows_f, flow_masks_f, flows_b, flow_masks_b):
410 |
411 | # |--------------------| |--------------------|
412 | # | j | | v |
413 | # | i * | | u * |
414 | # | | | |
415 | # |--------------------| |--------------------|
416 |
417 | i, j = np.meshgrid(np.arange(W, dtype=np.float32),
418 | np.arange(H, dtype=np.float32), indexing='xy')
419 |
420 | grid = np.empty((0, H, W, 8), np.float32)
421 | for idx in range(num_img):
422 | grid = np.concatenate((grid, np.stack([i,
423 | j,
424 | flows_f[idx, :, :, 0],
425 | flows_f[idx, :, :, 1],
426 | flow_masks_f[idx, :, :],
427 | flows_b[idx, :, :, 0],
428 | flows_b[idx, :, :, 1],
429 | flow_masks_b[idx, :, :]], -1)[None, ...]))
430 | return grid
431 |
432 |
433 | def NDC2world(pts, H, W, f):
434 |
435 | # NDC coordinate to world coordinate
436 | pts_z = 2 / (torch.clamp(pts[..., 2:], min=-1., max=1-1e-3) - 1)
437 | pts_x = - pts[..., 0:1] * pts_z * W / 2 / f
438 | pts_y = - pts[..., 1:2] * pts_z * H / 2 / f
439 | pts_world = torch.cat([pts_x, pts_y, pts_z], -1)
440 |
441 | return pts_world
442 |
443 |
444 | def render_3d_point(H, W, f, pose, weights, pts):
445 | """Render 3D position along each ray and project it to the image plane.
446 | """
447 |
448 | c2w = pose
449 | w2c = c2w[:3, :3].transpose(0, 1) # same as np.linalg.inv(c2w[:3, :3])
450 |
451 | # Rendered 3D position in NDC coordinate
452 | pts_map_NDC = torch.sum(weights[..., None] * pts, -2)
453 |
454 | # NDC coordinate to world coordinate
455 | pts_map_world = NDC2world(pts_map_NDC, H, W, f)
456 |
457 | # World coordinate to camera coordinate
458 | # Translate
459 | pts_map_world = pts_map_world - c2w[:, 3]
460 | # Rotate
461 | pts_map_cam = torch.sum(pts_map_world[..., None, :] * w2c[:3, :3], -1)
462 |
463 | # Camera coordinate to 2D image coordinate
464 | pts_plane = torch.cat([pts_map_cam[..., 0:1] / (- pts_map_cam[..., 2:]) * f + W * .5,
465 | - pts_map_cam[..., 1:2] / (- pts_map_cam[..., 2:]) * f + H * .5],
466 | -1)
467 |
468 | return pts_plane
469 |
470 |
471 | def induce_flow(H, W, focal, pose_neighbor, weights, pts_3d_neighbor, pts_2d):
472 |
473 | # Render 3D position along each ray and project it to the neighbor frame's image plane.
474 | pts_2d_neighbor = render_3d_point(H, W, focal,
475 | pose_neighbor,
476 | weights,
477 | pts_3d_neighbor)
478 | induced_flow = pts_2d_neighbor - pts_2d
479 |
480 | return induced_flow
481 |
482 |
483 | def compute_depth_loss(dyn_depth, gt_depth):
484 |
485 | t_d = torch.median(dyn_depth)
486 | s_d = torch.mean(torch.abs(dyn_depth - t_d))
487 | dyn_depth_norm = (dyn_depth - t_d) / s_d
488 |
489 | t_gt = torch.median(gt_depth)
490 | s_gt = torch.mean(torch.abs(gt_depth - t_gt))
491 | gt_depth_norm = (gt_depth - t_gt) / s_gt
492 |
493 | return torch.mean((dyn_depth_norm - gt_depth_norm) ** 2)
494 |
495 |
496 | def normalize_depth(depth):
497 | return torch.clamp(depth / percentile(depth, 97), 0., 1.)
498 |
499 |
500 | def percentile(t, q):
501 | """
502 | Return the ``q``-th percentile of the flattened input tensor's data.
503 |
504 | CAUTION:
505 | * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
506 | * Values are not interpolated, which corresponds to
507 | ``numpy.percentile(..., interpolation="nearest")``.
508 |
509 | :param t: Input tensor.
510 | :param q: Percentile to compute, which must be between 0 and 100 inclusive.
511 | :return: Resulting value (scalar).
512 | """
513 |
514 | k = 1 + round(.01 * float(q) * (t.numel() - 1))
515 | result = t.view(-1).kthvalue(k).values.item()
516 | return result
517 |
518 |
519 | def save_res(moviebase, ret, fps=None):
520 |
521 |     if fps is None:
522 | if len(ret['rgbs']) < 25:
523 | fps = 4
524 | else:
525 | fps = 24
526 |
527 | for k in ret:
528 | if 'rgbs' in k:
529 | imageio.mimwrite(moviebase + k + '.mp4',
530 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
531 | # imageio.mimsave(moviebase + k + '.gif',
532 | # to8b(ret[k]), format='gif', fps=fps)
533 | elif 'depths' in k:
534 | imageio.mimwrite(moviebase + k + '.mp4',
535 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
536 | # imageio.mimsave(moviebase + k + '.gif',
537 | # to8b(ret[k]), format='gif', fps=fps)
538 | elif 'disps' in k:
539 | imageio.mimwrite(moviebase + k + '.mp4',
540 | to8b(ret[k] / np.max(ret[k])), fps=fps, quality=8, macro_block_size=1)
541 | # imageio.mimsave(moviebase + k + '.gif',
542 | # to8b(ret[k] / np.max(ret[k])), format='gif', fps=fps)
543 | elif 'sceneflow_' in k:
544 | imageio.mimwrite(moviebase + k + '.mp4',
545 | to8b(norm_sf(ret[k])), fps=fps, quality=8, macro_block_size=1)
546 | # imageio.mimsave(moviebase + k + '.gif',
547 | # to8b(norm_sf(ret[k])), format='gif', fps=fps)
548 | elif 'flows' in k:
549 | imageio.mimwrite(moviebase + k + '.mp4',
550 | ret[k], fps=fps, quality=8, macro_block_size=1)
551 | # imageio.mimsave(moviebase + k + '.gif',
552 | # ret[k], format='gif', fps=fps)
553 | elif 'dynamicness' in k:
554 | imageio.mimwrite(moviebase + k + '.mp4',
555 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
556 | # imageio.mimsave(moviebase + k + '.gif',
557 | # to8b(ret[k]), format='gif', fps=fps)
558 | elif 'disocclusions' in k:
559 | imageio.mimwrite(moviebase + k + '.mp4',
560 | to8b(ret[k][..., 0]), fps=fps, quality=8, macro_block_size=1)
561 | # imageio.mimsave(moviebase + k + '.gif',
562 | # to8b(ret[k][..., 0]), format='gif', fps=fps)
563 | elif 'blending' in k:
564 | blending = ret[k][..., None]
565 | blending = np.moveaxis(blending, [0, 1, 2, 3], [1, 2, 0, 3])
566 | imageio.mimwrite(moviebase + k + '.mp4',
567 | to8b(blending), fps=fps, quality=8, macro_block_size=1)
568 | # imageio.mimsave(moviebase + k + '.gif',
569 | # to8b(blending), format='gif', fps=fps)
570 | elif 'weights' in k:
571 | imageio.mimwrite(moviebase + k + '.mp4',
572 | to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
573 | else:
574 | raise NotImplementedError
575 |
576 |
577 | def norm_sf_channel(sf_ch):
578 |
579 | # Make sure zero scene flow is not shifted
580 | sf_ch[sf_ch >= 0] = sf_ch[sf_ch >= 0] / sf_ch.max() / 2
581 | sf_ch[sf_ch < 0] = sf_ch[sf_ch < 0] / np.abs(sf_ch.min()) / 2
582 | sf_ch = sf_ch + 0.5
583 | return sf_ch
584 |
585 |
586 | def norm_sf(sf):
587 |
588 | sf = np.concatenate((norm_sf_channel(sf[..., 0:1]),
589 | norm_sf_channel(sf[..., 1:2]),
590 | norm_sf_channel(sf[..., 2:3])), -1)
591 | sf = np.moveaxis(sf, [0, 1, 2, 3], [1, 2, 0, 3])
592 | return sf
593 |
594 |
595 | # Spatial smoothness (adapted from NSFF)
596 | def compute_sf_smooth_s_loss(pts1, pts2, H, W, f):
597 |
598 | N_samples = pts1.shape[1]
599 |
600 | # NDC coordinate to world coordinate
601 | pts1_world = NDC2world(pts1[..., :int(N_samples * 0.95), :], H, W, f)
602 | pts2_world = NDC2world(pts2[..., :int(N_samples * 0.95), :], H, W, f)
603 |
604 | # scene flow in world coordinate
605 | scene_flow_world = pts1_world - pts2_world
606 |
607 | return L1(scene_flow_world[..., :-1, :] - scene_flow_world[..., 1:, :])
608 |
609 |
610 | # Temporal smoothness
611 | def compute_sf_smooth_loss(pts, pts_f, pts_b, H, W, f):
612 |
613 | N_samples = pts.shape[1]
614 |
615 | pts_world = NDC2world(pts[..., :int(N_samples * 0.9), :], H, W, f)
616 | pts_f_world = NDC2world(pts_f[..., :int(N_samples * 0.9), :], H, W, f)
617 | pts_b_world = NDC2world(pts_b[..., :int(N_samples * 0.9), :], H, W, f)
618 |
619 | # scene flow in world coordinate
620 | sceneflow_f = pts_f_world - pts_world
621 | sceneflow_b = pts_b_world - pts_world
622 |
623 | # For a 3D point, its forward and backward sceneflow should be opposite.
624 | return L2(sceneflow_f + sceneflow_b)
625 |
--------------------------------------------------------------------------------
/utils/RAFT/__init__.py:
--------------------------------------------------------------------------------
1 | # from .demo import RAFT_infer
2 | from .raft import RAFT
3 |
--------------------------------------------------------------------------------
/utils/RAFT/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class CorrLayer(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, fmap1, fmap2, coords, r):
66 | fmap1 = fmap1.contiguous()
67 | fmap2 = fmap2.contiguous()
68 | coords = coords.contiguous()
69 | ctx.save_for_backward(fmap1, fmap2, coords)
70 | ctx.r = r
71 |         corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, ctx.r)
72 | return corr
73 |
74 | @staticmethod
75 | def backward(ctx, grad_corr):
76 | fmap1, fmap2, coords = ctx.saved_tensors
77 | grad_corr = grad_corr.contiguous()
78 | fmap1_grad, fmap2_grad, coords_grad = \
79 |             alt_cuda_corr.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
80 | return fmap1_grad, fmap2_grad, coords_grad, None
81 |
82 |
83 | class AlternateCorrBlock:
84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
85 | self.num_levels = num_levels
86 | self.radius = radius
87 |
88 | self.pyramid = [(fmap1, fmap2)]
89 | for i in range(self.num_levels):
90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
92 | self.pyramid.append((fmap1, fmap2))
93 |
94 | def __call__(self, coords):
95 |
96 | coords = coords.permute(0, 2, 3, 1)
97 | B, H, W, _ = coords.shape
98 |
99 | corr_list = []
100 | for i in range(self.num_levels):
101 | r = self.radius
102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
104 |
105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
106 |             corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
107 | corr_list.append(corr.squeeze(1))
108 |
109 | corr = torch.stack(corr_list, dim=1)
110 | corr = corr.reshape(B, -1, H, W)
111 | return corr / 16.0
112 |
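A hedged usage sketch for `CorrBlock` (shapes are illustrative; assumes the repo root is on `PYTHONPATH`): querying the 4-level correlation pyramid at zero flow returns `num_levels * (2*radius+1)**2` channels per pixel.
```
import torch
from utils.RAFT.corr import CorrBlock

B, D, H, W = 1, 256, 46, 62                       # 1/8-resolution feature maps
fmap1, fmap2 = torch.randn(B, D, H, W), torch.randn(B, D, H, W)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

# Identity coordinates (zero flow), channel order (x, y) as in coords_grid.
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W))
coords = torch.stack([xs, ys], dim=0).float()[None]   # (B, 2, H, W)

feat = corr_fn(coords)
print(feat.shape)   # torch.Size([1, 324, 46, 62]) = 4 levels * 9 * 9 lookups
```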
--------------------------------------------------------------------------------
/utils/RAFT/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 |     """ Create the data loader for the corresponding training set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
/utils/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os
4 | import cv2
5 | import glob
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 | from .raft import RAFT
11 | from .utils import flow_viz
12 | from .utils.utils import InputPadder
13 |
14 |
15 |
16 | DEVICE = 'cuda'
17 |
18 | def load_image(imfile):
19 | img = np.array(Image.open(imfile)).astype(np.uint8)
20 | img = torch.from_numpy(img).permute(2, 0, 1).float()
21 | return img
22 |
23 |
24 | def load_image_list(image_files):
25 | images = []
26 | for imfile in sorted(image_files):
27 | images.append(load_image(imfile))
28 |
29 | images = torch.stack(images, dim=0)
30 | images = images.to(DEVICE)
31 |
32 | padder = InputPadder(images.shape)
33 | return padder.pad(images)[0]
34 |
35 |
36 | def viz(img, flo):
37 | img = img[0].permute(1,2,0).cpu().numpy()
38 | flo = flo[0].permute(1,2,0).cpu().numpy()
39 |
40 | # map flow to rgb image
41 | flo = flow_viz.flow_to_image(flo)
42 | # img_flo = np.concatenate([img, flo], axis=0)
43 | img_flo = flo
44 |
45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
47 | # cv2.waitKey()
48 |
49 |
50 | def demo(args):
51 | model = torch.nn.DataParallel(RAFT(args))
52 | model.load_state_dict(torch.load(args.model))
53 |
54 | model = model.module
55 | model.to(DEVICE)
56 | model.eval()
57 |
58 | with torch.no_grad():
59 | images = glob.glob(os.path.join(args.path, '*.png')) + \
60 | glob.glob(os.path.join(args.path, '*.jpg'))
61 |
62 | images = load_image_list(images)
63 | for i in range(images.shape[0]-1):
64 | image1 = images[i,None]
65 | image2 = images[i+1,None]
66 |
67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
68 | viz(image1, flow_up)
69 |
70 |
71 | def RAFT_infer(args):
72 | model = torch.nn.DataParallel(RAFT(args))
73 | model.load_state_dict(torch.load(args.model))
74 |
75 | model = model.module
76 | model.to(DEVICE)
77 | model.eval()
78 |
79 | return model
80 |
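A hedged inference sketch built on `RAFT_infer` (requires a GPU; the checkpoint and frame paths are placeholders, not files shipped with this repo):
```
import argparse
import torch
from utils.RAFT.demo import RAFT_infer, load_image
from utils.RAFT.utils.utils import InputPadder

# Placeholder paths; download RAFT weights separately.
args = argparse.Namespace(small=False, mixed_precision=False,
                          model='path/to/raft-things.pth')
model = RAFT_infer(args)                       # weights loaded, eval mode, on CUDA

img1 = load_image('frame_000.png')[None].cuda()
img2 = load_image('frame_001.png')[None].cuda()
padder = InputPadder(img1.shape)               # pad so H and W are divisible by 8
img1, img2 = padder.pad(img1, img2)

with torch.no_grad():
    _, flow_up = model(img1, img2, iters=20, test_mode=True)
flow = padder.unpad(flow_up)                   # (1, 2, H, W) flow in pixels
```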
--------------------------------------------------------------------------------
/utils/RAFT/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
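A quick shape check for the feature encoder (a sketch; input dimensions are arbitrary): three stride-2 stages reduce the image to 1/8 resolution, which is the grid the correlation volume and flow field live on.
```
import torch
from utils.RAFT.extractor import BasicEncoder

enc = BasicEncoder(output_dim=256, norm_fn='instance')
x = torch.randn(1, 3, 376, 1240)       # arbitrary image-sized input
print(enc(x).shape)                    # torch.Size([1, 256, 47, 155]) -- 1/8 resolution
```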
--------------------------------------------------------------------------------
/utils/RAFT/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 |         if not hasattr(args, 'dropout'):
42 | args.dropout = 0
43 |
44 |         if not hasattr(args, 'alternate_corr'):
45 | args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 |
59 | def freeze_bn(self):
60 | for m in self.modules():
61 | if isinstance(m, nn.BatchNorm2d):
62 | m.eval()
63 |
64 | def initialize_flow(self, img):
65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
66 | N, C, H, W = img.shape
67 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
68 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
69 |
70 | # optical flow computed as difference: flow = coords1 - coords0
71 | return coords0, coords1
72 |
73 | def upsample_flow(self, flow, mask):
74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
75 | N, _, H, W = flow.shape
76 | mask = mask.view(N, 1, 9, 8, 8, H, W)
77 | mask = torch.softmax(mask, dim=2)
78 |
79 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
81 |
82 | up_flow = torch.sum(mask * up_flow, dim=2)
83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
84 | return up_flow.reshape(N, 2, 8*H, 8*W)
85 |
86 |
87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
88 | """ Estimate optical flow between pair of frames """
89 |
90 | image1 = 2 * (image1 / 255.0) - 1.0
91 | image2 = 2 * (image2 / 255.0) - 1.0
92 |
93 | image1 = image1.contiguous()
94 | image2 = image2.contiguous()
95 |
96 | hdim = self.hidden_dim
97 | cdim = self.context_dim
98 |
99 | # run the feature network
100 | with autocast(enabled=self.args.mixed_precision):
101 | fmap1, fmap2 = self.fnet([image1, image2])
102 |
103 | fmap1 = fmap1.float()
104 | fmap2 = fmap2.float()
105 | if self.args.alternate_corr:
106 |             corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
107 | else:
108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
109 |
110 | # run the context network
111 | with autocast(enabled=self.args.mixed_precision):
112 | cnet = self.cnet(image1)
113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
114 | net = torch.tanh(net)
115 | inp = torch.relu(inp)
116 |
117 | coords0, coords1 = self.initialize_flow(image1)
118 |
119 | if flow_init is not None:
120 | coords1 = coords1 + flow_init
121 |
122 | flow_predictions = []
123 | for itr in range(iters):
124 | coords1 = coords1.detach()
125 | corr = corr_fn(coords1) # index correlation volume
126 |
127 | flow = coords1 - coords0
128 | with autocast(enabled=self.args.mixed_precision):
129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
130 |
131 | # F(t+1) = F(t) + \Delta(t)
132 | coords1 = coords1 + delta_flow
133 |
134 | # upsample predictions
135 | if up_mask is None:
136 | flow_up = upflow8(coords1 - coords0)
137 | else:
138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
139 |
140 | flow_predictions.append(flow_up)
141 |
142 | if test_mode:
143 | return coords1 - coords0, flow_up
144 |
145 | return flow_predictions
146 |
--------------------------------------------------------------------------------
/utils/RAFT/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
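A shape sketch for the full-size update operator (dummy zero tensors, default `corr_levels=4`, `corr_radius=4`): one GRU step returns the new hidden state, a 64*9-channel convex-upsampling mask, and a flow increment.
```
import argparse
import torch
from utils.RAFT.update import BasicUpdateBlock

args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

N, H, W = 1, 46, 62
net  = torch.zeros(N, 128, H, W)            # hidden state
inp  = torch.zeros(N, 128, H, W)            # context features
corr = torch.zeros(N, 4 * 9 * 9, H, W)      # sampled correlation volume
flow = torch.zeros(N, 2, H, W)              # current flow estimate

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)      # (1, 576, 46, 62) (1, 2, 46, 62)
```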
--------------------------------------------------------------------------------
/utils/RAFT/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_viz import flow_to_image
2 | from .frame_utils import writeFlow
3 |
--------------------------------------------------------------------------------
/utils/RAFT/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
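A hedged usage sketch for the dense-flow augmentor (random arrays stand in for real frames): photometric jitter, random erasing, scaling/flipping, and a final crop to `crop_size`.
```
import numpy as np
from utils.RAFT.utils.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 496), min_scale=-0.2, max_scale=0.5)
img1 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)
flow = np.random.randn(436, 1024, 2).astype(np.float32)

img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, flow.shape)    # (368, 496, 3) (368, 496, 2)
```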
--------------------------------------------------------------------------------
/utils/RAFT/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional (u, v) flow field of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
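A hedged example of the color-wheel visualization on a synthetic flow field that points away from the image center:
```
import numpy as np
from utils.RAFT.utils.flow_viz import flow_to_image

H, W = 128, 192
ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
flow = np.stack([xs - W / 2, ys - H / 2], axis=-1).astype(np.float32)
rgb = flow_to_image(flow)
print(rgb.shape, rgb.dtype)      # (128, 192, 3) uint8
```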
--------------------------------------------------------------------------------
/utils/RAFT/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
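A round-trip sketch for the Middlebury `.flo` helpers (writes a temporary file; exact recovery holds on little-endian machines):
```
import numpy as np
from utils.RAFT.utils.frame_utils import writeFlow, readFlow

flow = np.random.randn(120, 160, 2).astype(np.float32)
writeFlow('/tmp/example.flo', flow)
print(np.allclose(flow, readFlow('/tmp/example.flo')))   # True
```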
--------------------------------------------------------------------------------
/utils/RAFT/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
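A shape sketch for the coordinate-grid and flow-upsampling helpers used by RAFT's iterative updates:
```
import torch
from utils.RAFT.utils.utils import coords_grid, upflow8

coords = coords_grid(2, 46, 62)                 # (2, 2, 46, 62), channels ordered (x, y)
print(coords[:, 0].max(), coords[:, 1].max())   # tensor(61.) tensor(45.)

flow_small = torch.randn(2, 2, 46, 62)
print(upflow8(flow_small).shape)                # torch.Size([2, 2, 368, 496])
```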
--------------------------------------------------------------------------------
/utils/colmap_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
31 |
32 | import os
33 | import sys
34 | import collections
35 | import numpy as np
36 | import struct
37 |
38 |
39 | CameraModel = collections.namedtuple(
40 | "CameraModel", ["model_id", "model_name", "num_params"])
41 | Camera = collections.namedtuple(
42 | "Camera", ["id", "model", "width", "height", "params"])
43 | BaseImage = collections.namedtuple(
44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
45 | Point3D = collections.namedtuple(
46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
47 |
48 | class Image(BaseImage):
49 | def qvec2rotmat(self):
50 | return qvec2rotmat(self.qvec)
51 |
52 |
53 | CAMERA_MODELS = {
54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
61 | CameraModel(model_id=7, model_name="FOV", num_params=5),
62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
65 | }
66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
67 | for camera_model in CAMERA_MODELS])
68 |
69 |
70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
71 | """Read and unpack the next bytes from a binary file.
72 | :param fid:
73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
75 | :param endian_character: Any of {@, =, <, >, !}
76 | :return: Tuple of read and unpacked values.
77 | """
78 | data = fid.read(num_bytes)
79 | return struct.unpack(endian_character + format_char_sequence, data)
80 |
81 |
82 | def read_cameras_text(path):
83 | """
84 | see: src/base/reconstruction.cc
85 | void Reconstruction::WriteCamerasText(const std::string& path)
86 | void Reconstruction::ReadCamerasText(const std::string& path)
87 | """
88 | cameras = {}
89 | with open(path, "r") as fid:
90 | while True:
91 | line = fid.readline()
92 | if not line:
93 | break
94 | line = line.strip()
95 | if len(line) > 0 and line[0] != "#":
96 | elems = line.split()
97 | camera_id = int(elems[0])
98 | model = elems[1]
99 | width = int(elems[2])
100 | height = int(elems[3])
101 | params = np.array(tuple(map(float, elems[4:])))
102 | cameras[camera_id] = Camera(id=camera_id, model=model,
103 | width=width, height=height,
104 | params=params)
105 | return cameras
106 |
107 |
108 | def read_cameras_binary(path_to_model_file):
109 | """
110 | see: src/base/reconstruction.cc
111 | void Reconstruction::WriteCamerasBinary(const std::string& path)
112 | void Reconstruction::ReadCamerasBinary(const std::string& path)
113 | """
114 | cameras = {}
115 | with open(path_to_model_file, "rb") as fid:
116 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
117 | for camera_line_index in range(num_cameras):
118 | camera_properties = read_next_bytes(
119 | fid, num_bytes=24, format_char_sequence="iiQQ")
120 | camera_id = camera_properties[0]
121 | model_id = camera_properties[1]
122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
123 | width = camera_properties[2]
124 | height = camera_properties[3]
125 | num_params = CAMERA_MODEL_IDS[model_id].num_params
126 | params = read_next_bytes(fid, num_bytes=8*num_params,
127 | format_char_sequence="d"*num_params)
128 | cameras[camera_id] = Camera(id=camera_id,
129 | model=model_name,
130 | width=width,
131 | height=height,
132 | params=np.array(params))
133 | assert len(cameras) == num_cameras
134 | return cameras
135 |
136 |
137 | def read_images_text(path):
138 | """
139 | see: src/base/reconstruction.cc
140 | void Reconstruction::ReadImagesText(const std::string& path)
141 | void Reconstruction::WriteImagesText(const std::string& path)
142 | """
143 | images = {}
144 | with open(path, "r") as fid:
145 | while True:
146 | line = fid.readline()
147 | if not line:
148 | break
149 | line = line.strip()
150 | if len(line) > 0 and line[0] != "#":
151 | elems = line.split()
152 | image_id = int(elems[0])
153 | qvec = np.array(tuple(map(float, elems[1:5])))
154 | tvec = np.array(tuple(map(float, elems[5:8])))
155 | camera_id = int(elems[8])
156 | image_name = elems[9]
157 | elems = fid.readline().split()
158 | xys = np.column_stack([tuple(map(float, elems[0::3])),
159 | tuple(map(float, elems[1::3]))])
160 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
161 | images[image_id] = Image(
162 | id=image_id, qvec=qvec, tvec=tvec,
163 | camera_id=camera_id, name=image_name,
164 | xys=xys, point3D_ids=point3D_ids)
165 | return images
166 |
167 |
168 | def read_images_binary(path_to_model_file):
169 | """
170 | see: src/base/reconstruction.cc
171 | void Reconstruction::ReadImagesBinary(const std::string& path)
172 | void Reconstruction::WriteImagesBinary(const std::string& path)
173 | """
174 | images = {}
175 | with open(path_to_model_file, "rb") as fid:
176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
177 | for image_index in range(num_reg_images):
178 | binary_image_properties = read_next_bytes(
179 | fid, num_bytes=64, format_char_sequence="idddddddi")
180 | image_id = binary_image_properties[0]
181 | qvec = np.array(binary_image_properties[1:5])
182 | tvec = np.array(binary_image_properties[5:8])
183 | camera_id = binary_image_properties[8]
184 | image_name = ""
185 | current_char = read_next_bytes(fid, 1, "c")[0]
186 | while current_char != b"\x00": # look for the ASCII 0 entry
187 | image_name += current_char.decode("utf-8")
188 | current_char = read_next_bytes(fid, 1, "c")[0]
189 | num_points2D = read_next_bytes(fid, num_bytes=8,
190 | format_char_sequence="Q")[0]
191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
192 | format_char_sequence="ddq"*num_points2D)
193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
194 | tuple(map(float, x_y_id_s[1::3]))])
195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
196 | images[image_id] = Image(
197 | id=image_id, qvec=qvec, tvec=tvec,
198 | camera_id=camera_id, name=image_name,
199 | xys=xys, point3D_ids=point3D_ids)
200 | return images
201 |
202 |
203 | def read_points3D_text(path):
204 | """
205 | see: src/base/reconstruction.cc
206 | void Reconstruction::ReadPoints3DText(const std::string& path)
207 | void Reconstruction::WritePoints3DText(const std::string& path)
208 | """
209 | points3D = {}
210 | with open(path, "r") as fid:
211 | while True:
212 | line = fid.readline()
213 | if not line:
214 | break
215 | line = line.strip()
216 | if len(line) > 0 and line[0] != "#":
217 | elems = line.split()
218 | point3D_id = int(elems[0])
219 | xyz = np.array(tuple(map(float, elems[1:4])))
220 | rgb = np.array(tuple(map(int, elems[4:7])))
221 | error = float(elems[7])
222 | image_ids = np.array(tuple(map(int, elems[8::2])))
223 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
225 | error=error, image_ids=image_ids,
226 | point2D_idxs=point2D_idxs)
227 | return points3D
228 |
229 |
230 | def read_points3d_binary(path_to_model_file):
231 | """
232 | see: src/base/reconstruction.cc
233 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
234 | void Reconstruction::WritePoints3DBinary(const std::string& path)
235 | """
236 | points3D = {}
237 | with open(path_to_model_file, "rb") as fid:
238 | num_points = read_next_bytes(fid, 8, "Q")[0]
239 | for point_line_index in range(num_points):
240 | binary_point_line_properties = read_next_bytes(
241 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
242 | point3D_id = binary_point_line_properties[0]
243 | xyz = np.array(binary_point_line_properties[1:4])
244 | rgb = np.array(binary_point_line_properties[4:7])
245 | error = np.array(binary_point_line_properties[7])
246 | track_length = read_next_bytes(
247 | fid, num_bytes=8, format_char_sequence="Q")[0]
248 | track_elems = read_next_bytes(
249 | fid, num_bytes=8*track_length,
250 | format_char_sequence="ii"*track_length)
251 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
253 | points3D[point3D_id] = Point3D(
254 | id=point3D_id, xyz=xyz, rgb=rgb,
255 | error=error, image_ids=image_ids,
256 | point2D_idxs=point2D_idxs)
257 | return points3D
258 |
259 |
260 | def read_model(path, ext):
261 | if ext == ".txt":
262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
263 | images = read_images_text(os.path.join(path, "images" + ext))
264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
265 | else:
266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
267 | images = read_images_binary(os.path.join(path, "images" + ext))
268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
269 | return cameras, images, points3D
270 |
271 |
272 | def qvec2rotmat(qvec):
273 | return np.array([
274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
283 |
284 |
285 | def rotmat2qvec(R):
286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
287 | K = np.array([
288 | [Rxx - Ryy - Rzz, 0, 0, 0],
289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
292 | eigvals, eigvecs = np.linalg.eigh(K)
293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
294 | if qvec[0] < 0:
295 | qvec *= -1
296 | return qvec
297 |
298 |
299 | def main():
300 | if len(sys.argv) != 3:
301 | print("Usage: python read_model.py path/to/model/folder [.txt,.bin]")
302 | return
303 |
304 | cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2])
305 |
306 | print("num_cameras:", len(cameras))
307 | print("num_images:", len(images))
308 | print("num_points3D:", len(points3D))
309 |
310 |
311 | if __name__ == "__main__":
312 | main()
313 |
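A minimal usage sketch (the scene path is a placeholder) showing how the readers above combine with `qvec2rotmat` to build 4x4 world-to-camera matrices, which is essentially what `generate_pose.py` and `generate_motion_mask.py` do further down:

```
# Illustrative sketch only: load a COLMAP reconstruction and build
# world-to-camera extrinsics per registered image ("path/to/scene" is a placeholder).
import numpy as np
from colmap_utils import read_model, qvec2rotmat

cameras, images, points3D = read_model("path/to/scene/sparse/0", ext=".bin")

w2c = {}
for image_id, im in images.items():
    R = qvec2rotmat(im.qvec)              # 3x3 rotation from the stored quaternion
    t = im.tvec.reshape(3, 1)             # 3x1 translation
    bottom = np.array([[0.0, 0.0, 0.0, 1.0]])
    w2c[im.name] = np.concatenate([np.concatenate([R, t], axis=1), bottom], axis=0)
```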
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import lpips
4 | import torch
5 | import numpy as np
6 | from skimage.metrics import structural_similarity
7 |
8 |
9 | def im2tensor(img):
10 | return torch.Tensor(img.transpose(2, 0, 1) / 127.5 - 1.0)[None, ...]
11 |
12 |
13 | def create_dir(dir):
14 | if not os.path.exists(dir):
15 | os.makedirs(dir)
16 |
17 |
18 | def readimage(data_dir, sequence, time, method):
19 | img = cv2.imread(os.path.join(data_dir, method, sequence, 'v000_t' + str(time).zfill(3) + '.png'))
20 | return img
21 |
22 |
23 | def calculate_metrics(data_dir, sequence, methods, lpips_loss):
24 |
25 | PSNRs = np.zeros((len(methods)))
26 | SSIMs = np.zeros((len(methods)))
27 | LPIPSs = np.zeros((len(methods)))
28 |
29 | nFrame = 0
30 |
31 | # Yoon's results do not include v000_t000 and v000_t011. Omit these two
32 | # frames if evaluating Yoon's method.
33 | if 'Yoon' in methods:
34 | time_start = 1
35 | time_end = 11
36 | else:
37 | time_start = 0
38 | time_end = 12
39 |
40 | for time in range(time_start, time_end): # Fix view v0, change time
41 |
42 | nFrame += 1
43 |
44 | img_true = readimage(data_dir, sequence, time, 'gt')
45 |
46 | for method_idx, method in enumerate(methods):
47 |
48 | if 'Yoon' in methods and sequence == 'Truck' and time == 10:
49 | break
50 |
51 | img = readimage(data_dir, sequence, time, method)
52 | PSNR = cv2.PSNR(img_true, img)
53 | SSIM = structural_similarity(img_true, img, multichannel=True)
54 | LPIPS = lpips_loss.forward(im2tensor(img_true), im2tensor(img)).item()
55 |
56 | PSNRs[method_idx] += PSNR
57 | SSIMs[method_idx] += SSIM
58 | LPIPSs[method_idx] += LPIPS
59 |
60 | PSNRs = PSNRs / nFrame
61 | SSIMs = SSIMs / nFrame
62 | LPIPSs = LPIPSs / nFrame
63 |
64 | return PSNRs, SSIMs, LPIPSs
65 |
66 |
67 | if __name__ == '__main__':
68 |
69 | lpips_loss = lpips.LPIPS(net='alex') # best forward scores
70 | data_dir = '../results'
71 | sequences = ['Balloon1', 'Balloon2', 'Jumping', 'Playground', 'Skating', 'Truck', 'Umbrella']
72 | # methods = ['NeRF', 'NeRF_t', 'Yoon', 'NR', 'NSFF', 'Ours']
73 | methods = ['NeRF', 'NeRF_t', 'NR', 'NSFF', 'Ours']
74 |
75 | PSNRs_total = np.zeros((len(methods)))
76 | SSIMs_total = np.zeros((len(methods)))
77 | LPIPSs_total = np.zeros((len(methods)))
78 | for sequence in sequences:
79 | print(sequence)
80 | PSNRs, SSIMs, LPIPSs = calculate_metrics(data_dir, sequence, methods, lpips_loss)
81 | for method_idx, method in enumerate(methods):
82 | print(method.ljust(7) + '%.2f'%(PSNRs[method_idx]) + ' / %.4f'%(SSIMs[method_idx]) + ' / %.3f'%(LPIPSs[method_idx]))
83 |
84 | PSNRs_total += PSNRs
85 | SSIMs_total += SSIMs
86 | LPIPSs_total += LPIPSs
87 |
88 | PSNRs_total = PSNRs_total / len(sequences)
89 | SSIMs_total = SSIMs_total / len(sequences)
90 | LPIPSs_total = LPIPSs_total / len(sequences)
91 | print('Avg.')
92 | for method_idx, method in enumerate(methods):
93 | print(method.ljust(7) + '%.2f'%(PSNRs_total[method_idx]) + ' / %.4f'%(SSIMs_total[method_idx]) + ' / %.3f'%(LPIPSs_total[method_idx]))
94 |
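Note that `im2tensor` rescales 8-bit images from [0, 255] into the [-1, 1] range that LPIPS expects; both inputs keep OpenCV's BGR channel order, so the comparison stays consistent. A small illustrative check of that range:

```
# Illustrative check only (not part of the pipeline); im2tensor is the helper defined above.
import numpy as np
from evaluation import im2tensor

white = np.full((4, 4, 3), 255, dtype=np.uint8)
black = np.zeros_like(white)
assert im2tensor(white).max().item() == 1.0    # 255 -> 1.0
assert im2tensor(black).min().item() == -1.0   # 0 -> -1.0
```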
--------------------------------------------------------------------------------
/utils/flow_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | from os.path import *
6 | UNKNOWN_FLOW_THRESH = 1e7
7 |
8 | def flow_to_image(flow, global_max=None):
9 | """
10 | Convert flow into middlebury color code image
11 | :param flow: optical flow map
12 | :return: optical flow image in middlebury color
13 | """
14 | u = flow[:, :, 0]
15 | v = flow[:, :, 1]
16 |
17 | maxu = -999.
18 | maxv = -999.
19 | minu = 999.
20 | minv = 999.
21 |
22 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
23 | u[idxUnknow] = 0
24 | v[idxUnknow] = 0
25 |
26 | maxu = max(maxu, np.max(u))
27 | minu = min(minu, np.min(u))
28 |
29 | maxv = max(maxv, np.max(v))
30 | minv = min(minv, np.min(v))
31 |
32 | rad = np.sqrt(u ** 2 + v ** 2)
33 |
34 |     if global_max is None:
35 | maxrad = max(-1, np.max(rad))
36 | else:
37 | maxrad = global_max
38 |
39 | u = u/(maxrad + np.finfo(float).eps)
40 | v = v/(maxrad + np.finfo(float).eps)
41 |
42 | img = compute_color(u, v)
43 |
44 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
45 | img[idx] = 0
46 |
47 | return np.uint8(img)
48 |
49 |
50 | def compute_color(u, v):
51 | """
52 | compute optical flow color map
53 | :param u: optical flow horizontal map
54 | :param v: optical flow vertical map
55 | :return: optical flow in color code
56 | """
57 | [h, w] = u.shape
58 | img = np.zeros([h, w, 3])
59 | nanIdx = np.isnan(u) | np.isnan(v)
60 | u[nanIdx] = 0
61 | v[nanIdx] = 0
62 |
63 | colorwheel = make_color_wheel()
64 | ncols = np.size(colorwheel, 0)
65 |
66 | rad = np.sqrt(u**2+v**2)
67 |
68 | a = np.arctan2(-v, -u) / np.pi
69 |
70 | fk = (a+1) / 2 * (ncols - 1) + 1
71 |
72 | k0 = np.floor(fk).astype(int)
73 |
74 | k1 = k0 + 1
75 | k1[k1 == ncols+1] = 1
76 | f = fk - k0
77 |
78 | for i in range(0, np.size(colorwheel,1)):
79 | tmp = colorwheel[:, i]
80 | col0 = tmp[k0-1] / 255
81 | col1 = tmp[k1-1] / 255
82 | col = (1-f) * col0 + f * col1
83 |
84 | idx = rad <= 1
85 | col[idx] = 1-rad[idx]*(1-col[idx])
86 | notidx = np.logical_not(idx)
87 |
88 | col[notidx] *= 0.75
89 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
90 |
91 | return img
92 |
93 |
94 | def make_color_wheel():
95 | """
96 | Generate color wheel according Middlebury color code
97 | :return: Color wheel
98 | """
99 | RY = 15
100 | YG = 6
101 | GC = 4
102 | CB = 11
103 | BM = 13
104 | MR = 6
105 |
106 | ncols = RY + YG + GC + CB + BM + MR
107 |
108 | colorwheel = np.zeros([ncols, 3])
109 |
110 | col = 0
111 |
112 | # RY
113 | colorwheel[0:RY, 0] = 255
114 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
115 | col += RY
116 |
117 | # YG
118 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
119 | colorwheel[col:col+YG, 1] = 255
120 | col += YG
121 |
122 | # GC
123 | colorwheel[col:col+GC, 1] = 255
124 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
125 | col += GC
126 |
127 | # CB
128 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
129 | colorwheel[col:col+CB, 2] = 255
130 | col += CB
131 |
132 | # BM
133 | colorwheel[col:col+BM, 2] = 255
134 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
135 |     col += BM
136 |
137 | # MR
138 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
139 | colorwheel[col:col+MR, 0] = 255
140 |
141 | return colorwheel
142 |
143 |
144 | def resize_flow(flow, H_new, W_new):
145 | H_old, W_old = flow.shape[0:2]
146 | flow_resized = cv2.resize(flow, (W_new, H_new), interpolation=cv2.INTER_LINEAR)
147 |     flow_resized[:, :, 0] *= W_new / W_old  # channel 0 is horizontal (x) flow
148 |     flow_resized[:, :, 1] *= H_new / H_old  # channel 1 is vertical (y) flow
149 | return flow_resized
150 |
151 |
152 |
153 | def warp_flow(img, flow):
154 | h, w = flow.shape[:2]
155 | flow_new = flow.copy()
156 | flow_new[:,:,0] += np.arange(w)
157 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
158 |
159 | res = cv2.remap(img, flow_new, None,
160 | cv2.INTER_CUBIC,
161 | borderMode=cv2.BORDER_CONSTANT)
162 | return res
163 |
164 |
165 | def consistCheck(flowB, flowF):
166 |
167 | # |--------------------| |--------------------|
168 | # | y | | v |
169 | # | x * | | u * |
170 | # | | | |
171 | # |--------------------| |--------------------|
172 |
173 | # sub: numPix * [y x t]
174 |
175 | imgH, imgW, _ = flowF.shape
176 |
177 | (fy, fx) = np.mgrid[0 : imgH, 0 : imgW].astype(np.float32)
178 | fxx = fx + flowB[:, :, 0] # horizontal
179 | fyy = fy + flowB[:, :, 1] # vertical
180 |
181 | u = (fxx + cv2.remap(flowF[:, :, 0], fxx, fyy, cv2.INTER_LINEAR) - fx)
182 | v = (fyy + cv2.remap(flowF[:, :, 1], fxx, fyy, cv2.INTER_LINEAR) - fy)
183 | BFdiff = (u ** 2 + v ** 2) ** 0.5
184 |
185 | return BFdiff, np.stack((u, v), axis=2)
186 |
187 |
188 | def read_optical_flow(basedir, img_i_name, read_fwd):
189 | flow_dir = os.path.join(basedir, 'flow')
190 |
191 | fwd_flow_path = os.path.join(flow_dir, '%s_fwd.npz'%img_i_name[:-4])
192 | bwd_flow_path = os.path.join(flow_dir, '%s_bwd.npz'%img_i_name[:-4])
193 |
194 | if read_fwd:
195 | fwd_data = np.load(fwd_flow_path)
196 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
197 | return fwd_flow, fwd_mask
198 | else:
199 | bwd_data = np.load(bwd_flow_path)
200 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
201 | return bwd_flow, bwd_mask
202 |
203 |
204 | def compute_epipolar_distance(T_21, K, p_1, p_2):
205 | R_21 = T_21[:3, :3]
206 | t_21 = T_21[:3, 3]
207 |
208 | E_mat = np.dot(skew(t_21), R_21)
209 | # compute bearing vector
210 | inv_K = np.linalg.inv(K)
211 |
212 | F_mat = np.dot(np.dot(inv_K.T, E_mat), inv_K)
213 |
214 | l_2 = np.dot(F_mat, p_1)
215 | algebric_e_distance = np.sum(p_2 * l_2, axis=0)
216 | n_term = np.sqrt(l_2[0, :]**2 + l_2[1, :]**2) + 1e-8
217 | geometric_e_distance = algebric_e_distance/n_term
218 | geometric_e_distance = np.abs(geometric_e_distance)
219 |
220 | return geometric_e_distance
221 |
222 |
223 | def skew(x):
224 | return np.array([[0, -x[2], x[1]],
225 | [x[2], 0, -x[0]],
226 | [-x[1], x[0], 0]])
227 |
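For reference, `compute_epipolar_distance` implements the standard point-to-epipolar-line distance. Given the relative pose $T_{21} = [R_{21} \mid t_{21}]$ and intrinsics $K$, it forms

$$E = [t_{21}]_\times R_{21}, \qquad F = K^{-\top} E K^{-1}, \qquad d(p_1, p_2) = \frac{\lvert p_2^\top F p_1 \rvert}{\sqrt{(F p_1)_x^2 + (F p_1)_y^2} + 10^{-8}},$$

where $p_1$ and $p_2$ are homogeneous pixel coordinates and `skew` builds the cross-product matrix $[t]_\times$. A large $d$ means the correspondence violates the static-scene epipolar constraint, which is how `generate_motion_mask.py` later flags moving pixels.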
--------------------------------------------------------------------------------
/utils/generate_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | import glob
5 | import torch
6 | import torchvision
7 | import skimage.morphology
8 | import argparse
9 |
10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | def create_dir(dir):
14 | if not os.path.exists(dir):
15 | os.makedirs(dir)
16 |
17 |
18 | def multi_view_multi_time(args):
19 | """
20 | Generating multi view multi time data
21 | """
22 |
23 | Maskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
24 | threshold = 0.5
25 |
26 | videoname, ext = os.path.splitext(os.path.basename(args.videopath))
27 |
28 | imgs = []
29 | reader = imageio.get_reader(args.videopath)
30 | for i, im in enumerate(reader):
31 | imgs.append(im)
32 |
33 | imgs = np.array(imgs)
34 | num_frames, H, W, _ = imgs.shape
35 | imgs = imgs[::int(np.ceil(num_frames / 100))]
36 |
37 | create_dir(os.path.join(args.data_dir, videoname, 'images'))
38 | create_dir(os.path.join(args.data_dir, videoname, 'images_colmap'))
39 | create_dir(os.path.join(args.data_dir, videoname, 'background_mask'))
40 |
41 | for idx, img in enumerate(imgs):
42 | print(idx)
43 | imageio.imwrite(os.path.join(args.data_dir, videoname, 'images', str(idx).zfill(3) + '.png'), img)
44 | imageio.imwrite(os.path.join(args.data_dir, videoname, 'images_colmap', str(idx).zfill(3) + '.jpg'), img)
45 |
46 | # Get coarse background mask
47 | img = torchvision.transforms.functional.to_tensor(img).to(device)
48 | background_mask = torch.FloatTensor(H, W).fill_(1.0).to(device)
49 | objPredictions = Maskrcnn([img])[0]
50 |
51 | for intMask in range(len(objPredictions['masks'])):
52 | if objPredictions['scores'][intMask].item() > threshold:
53 | if objPredictions['labels'][intMask].item() == 1: # person
54 | background_mask[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
55 |
56 | background_mask_np = ((background_mask.cpu().numpy() > 0.1) * 255).astype(np.uint8)
57 | imageio.imwrite(os.path.join(args.data_dir, videoname, 'background_mask', str(idx).zfill(3) + '.jpg.png'), background_mask_np)
58 |
59 |
60 | if __name__ == '__main__':
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument("--videopath", type=str,
63 | help='video path')
64 | parser.add_argument("--data_dir", type=str, default='../data/',
65 | help='where to store data')
66 |
67 | args = parser.parse_args()
68 |
69 | multi_view_multi_time(args)
70 |
--------------------------------------------------------------------------------
/utils/generate_depth.py:
--------------------------------------------------------------------------------
1 | """Compute depth maps for images in the input folder.
2 | """
3 | import os
4 | import cv2
5 | import glob
6 | import torch
7 | import argparse
8 | import numpy as np
9 |
10 | from torchvision.transforms import Compose
11 | from midas.midas_net import MidasNet
12 | from midas.transforms import Resize, NormalizeImage, PrepareForNet
13 |
14 |
15 | def create_dir(dir):
16 | if not os.path.exists(dir):
17 | os.makedirs(dir)
18 |
19 |
20 | def read_image(path):
21 | """Read image and output RGB image (0-1).
22 |
23 | Args:
24 | path (str): path to file
25 |
26 | Returns:
27 | array: RGB image (0-1)
28 | """
29 | img = cv2.imread(path)
30 |
31 | if img.ndim == 2:
32 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
33 |
34 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
35 |
36 | return img
37 |
38 |
39 | def run(input_path, output_path, output_img_path, model_path):
40 | """Run MonoDepthNN to compute depth maps.
41 | Args:
42 | input_path (str): path to input folder
43 | output_path (str): path to output folder
44 | model_path (str): path to saved model
45 | """
46 | print("initialize")
47 |
48 | # select device
49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50 | print("device: %s" % device)
51 |
52 | # load network
53 | model = MidasNet(model_path, non_negative=True)
54 | sh = cv2.imread(sorted(glob.glob(os.path.join(input_path, "*")))[0]).shape
55 | net_w, net_h = sh[1], sh[0]
56 |
57 | resize_mode="upper_bound"
58 |
59 | transform = Compose(
60 | [
61 | Resize(
62 | net_w,
63 | net_h,
64 | resize_target=None,
65 | keep_aspect_ratio=True,
66 | ensure_multiple_of=32,
67 | resize_method=resize_mode,
68 | image_interpolation_method=cv2.INTER_CUBIC,
69 | ),
70 | NormalizeImage(mean=[0.485, 0.456, 0.406],
71 | std=[0.229, 0.224, 0.225]),
72 | PrepareForNet(),
73 | ]
74 | )
75 |
76 | model.eval()
77 | model.to(device)
78 |
79 | # get input
80 | img_names = sorted(glob.glob(os.path.join(input_path, "*")))
81 | num_images = len(img_names)
82 |
83 | # create output folder
84 | os.makedirs(output_path, exist_ok=True)
85 |
86 | print("start processing")
87 |
88 | for ind, img_name in enumerate(img_names):
89 |
90 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
91 |
92 | # input
93 | img = read_image(img_name)
94 | img_input = transform({"image": img})["image"]
95 |
96 | # compute
97 | with torch.no_grad():
98 | sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
99 | prediction = model.forward(sample)
100 | prediction = (
101 | torch.nn.functional.interpolate(
102 | prediction.unsqueeze(1),
103 | size=[net_h, net_w],
104 | mode="bicubic",
105 | align_corners=False,
106 | )
107 | .squeeze()
108 | .cpu()
109 | .numpy()
110 | )
111 |
112 | # output
113 | filename = os.path.join(
114 | output_path, os.path.splitext(os.path.basename(img_name))[0]
115 | )
116 |
117 | print(filename + '.npy')
118 | np.save(filename + '.npy', prediction.astype(np.float32))
119 |
120 | depth_min = prediction.min()
121 | depth_max = prediction.max()
122 |
123 | max_val = (2**(8*2))-1
124 |
125 | if depth_max - depth_min > np.finfo("float").eps:
126 | out = max_val * (prediction - depth_min) / (depth_max - depth_min)
127 | else:
128 |             out = np.zeros(prediction.shape, dtype=prediction.dtype)
129 |
130 | cv2.imwrite(os.path.join(output_img_path, os.path.splitext(os.path.basename(img_name))[0] + '.png'), out.astype("uint16"))
131 |
132 |
133 | if __name__ == "__main__":
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument("--dataset_path", type=str, help='Dataset path')
136 | parser.add_argument('--model', help="restore midas checkpoint")
137 | args = parser.parse_args()
138 |
139 | input_path = os.path.join(args.dataset_path, 'images')
140 | output_path = os.path.join(args.dataset_path, 'disp')
141 | output_img_path = os.path.join(args.dataset_path, 'disp_png')
142 | create_dir(output_path)
143 | create_dir(output_img_path)
144 |
145 | # set torch options
146 | torch.backends.cudnn.enabled = True
147 | torch.backends.cudnn.benchmark = True
148 |
149 | # compute depth maps
150 | run(input_path, output_path, output_img_path, args.model)
151 |
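The `.npy` files hold MiDaS's raw (relative, inverse-depth) predictions, while the 16-bit PNGs are the same maps min-max normalized to [0, 65535] purely for visualization. A minimal read-back sketch (paths are placeholders):

```
# Illustrative only; "scene" stands in for a dataset directory.
import cv2
import numpy as np

disp = np.load("scene/disp/000.npy")                                   # raw MiDaS prediction (float32)
disp_png = cv2.imread("scene/disp_png/000.png", cv2.IMREAD_UNCHANGED)  # uint16 visualization
disp_norm = disp_png.astype(np.float32) / 65535.0                      # min-max normalized copy in [0, 1]
```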
--------------------------------------------------------------------------------
/utils/generate_flow.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import cv2
4 | import glob
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 |
9 | from RAFT.raft import RAFT
10 | from RAFT.utils import flow_viz
11 | from RAFT.utils.utils import InputPadder
12 |
13 | from flow_utils import *
14 |
15 | DEVICE = 'cuda'
16 |
17 |
18 | def create_dir(dir):
19 | if not os.path.exists(dir):
20 | os.makedirs(dir)
21 |
22 |
23 | def load_image(imfile):
24 | img = np.array(Image.open(imfile)).astype(np.uint8)
25 | img = torch.from_numpy(img).permute(2, 0, 1).float()
26 | return img[None].to(DEVICE)
27 |
28 |
29 | def warp_flow(img, flow):
30 | h, w = flow.shape[:2]
31 | flow_new = flow.copy()
32 | flow_new[:,:,0] += np.arange(w)
33 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
34 |
35 | res = cv2.remap(img, flow_new, None, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
36 | return res
37 |
38 |
39 | def compute_fwdbwd_mask(fwd_flow, bwd_flow):
40 | alpha_1 = 0.5
41 | alpha_2 = 0.5
42 |
43 | bwd2fwd_flow = warp_flow(bwd_flow, fwd_flow)
44 | fwd_lr_error = np.linalg.norm(fwd_flow + bwd2fwd_flow, axis=-1)
45 | fwd_mask = fwd_lr_error < alpha_1 * (np.linalg.norm(fwd_flow, axis=-1) \
46 | + np.linalg.norm(bwd2fwd_flow, axis=-1)) + alpha_2
47 |
48 | fwd2bwd_flow = warp_flow(fwd_flow, bwd_flow)
49 | bwd_lr_error = np.linalg.norm(bwd_flow + fwd2bwd_flow, axis=-1)
50 |
51 | bwd_mask = bwd_lr_error < alpha_1 * (np.linalg.norm(bwd_flow, axis=-1) \
52 | + np.linalg.norm(fwd2bwd_flow, axis=-1)) + alpha_2
53 |
54 | return fwd_mask, bwd_mask
55 |
56 | def run(args, input_path, output_path, output_img_path):
57 | model = torch.nn.DataParallel(RAFT(args))
58 | model.load_state_dict(torch.load(args.model))
59 |
60 | model = model.module
61 | model.to(DEVICE)
62 | model.eval()
63 |
64 | with torch.no_grad():
65 | images = glob.glob(os.path.join(input_path, '*.png')) + \
66 | glob.glob(os.path.join(input_path, '*.jpg'))
67 |
68 | images = sorted(images)
69 | for i in range(len(images) - 1):
70 | print(i)
71 | image1 = load_image(images[i])
72 | image2 = load_image(images[i + 1])
73 |
74 | padder = InputPadder(image1.shape)
75 | image1, image2 = padder.pad(image1, image2)
76 |
77 | _, flow_fwd = model(image1, image2, iters=20, test_mode=True)
78 | _, flow_bwd = model(image2, image1, iters=20, test_mode=True)
79 |
80 | flow_fwd = padder.unpad(flow_fwd[0]).cpu().numpy().transpose(1, 2, 0)
81 | flow_bwd = padder.unpad(flow_bwd[0]).cpu().numpy().transpose(1, 2, 0)
82 |
83 | mask_fwd, mask_bwd = compute_fwdbwd_mask(flow_fwd, flow_bwd)
84 |
85 | # Save flow
86 | np.savez(os.path.join(output_path, '%03d_fwd.npz'%i), flow=flow_fwd, mask=mask_fwd)
87 | np.savez(os.path.join(output_path, '%03d_bwd.npz'%(i + 1)), flow=flow_bwd, mask=mask_bwd)
88 |
89 | # Save flow_img
90 | Image.fromarray(flow_viz.flow_to_image(flow_fwd)).save(os.path.join(output_img_path, '%03d_fwd.png'%i))
91 | Image.fromarray(flow_viz.flow_to_image(flow_bwd)).save(os.path.join(output_img_path, '%03d_bwd.png'%(i + 1)))
92 |
93 | Image.fromarray(mask_fwd).save(os.path.join(output_img_path, '%03d_fwd_mask.png'%i))
94 | Image.fromarray(mask_bwd).save(os.path.join(output_img_path, '%03d_bwd_mask.png'%(i + 1)))
95 |
96 |
97 | if __name__ == '__main__':
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--dataset_path", type=str, help='Dataset path')
100 | parser.add_argument('--model', help="restore RAFT checkpoint")
101 | parser.add_argument('--small', action='store_true', help='use small model')
102 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
103 | args = parser.parse_args()
104 |
105 | input_path = os.path.join(args.dataset_path, 'images')
106 | output_path = os.path.join(args.dataset_path, 'flow')
107 | output_img_path = os.path.join(args.dataset_path, 'flow_png')
108 | create_dir(output_path)
109 | create_dir(output_img_path)
110 |
111 | run(args, input_path, output_path, output_img_path)
112 |
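For reference, `compute_fwdbwd_mask` keeps the forward flow at a pixel $p$ only if the forward-backward consistency check

$$\big\| f_{\text{fwd}}(p) + f_{\text{bwd}}\big(p + f_{\text{fwd}}(p)\big) \big\| \;<\; \alpha_1 \Big( \big\| f_{\text{fwd}}(p) \big\| + \big\| f_{\text{bwd}}\big(p + f_{\text{fwd}}(p)\big) \big\| \Big) + \alpha_2$$

passes, with $\alpha_1 = \alpha_2 = 0.5$ as set above; the backward mask is defined symmetrically. Pixels that fail the check (typically occlusions or unreliable flow) are recorded in the saved masks so downstream code can ignore them.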
--------------------------------------------------------------------------------
/utils/generate_motion_mask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import PIL
4 | import glob
5 | import torch
6 | import argparse
7 | import numpy as np
8 |
9 | from colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
10 |
11 | import skimage.morphology
12 | import torchvision
13 | from flow_utils import read_optical_flow, compute_epipolar_distance, skew
14 |
15 |
16 |
17 | def create_dir(dir):
18 | if not os.path.exists(dir):
19 | os.makedirs(dir)
20 |
21 |
22 | def extract_poses(im):
23 | R = im.qvec2rotmat()
24 | t = im.tvec.reshape([3,1])
25 | bottom = np.array([0,0,0,1.]).reshape([1,4])
26 |
27 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
28 |
29 | return m
30 |
31 |
32 | def load_colmap_data(realdir):
33 |
34 | camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')
35 | camdata = read_cameras_binary(camerasfile)
36 |
37 | list_of_keys = list(camdata.keys())
38 | cam = camdata[list_of_keys[0]]
39 |     print('Cameras', len(camdata))
40 |
41 | h, w, f = cam.height, cam.width, cam.params[0]
42 | # w, h, f = factor * w, factor * h, factor * f
43 | hwf = np.array([h,w,f]).reshape([3,1])
44 |
45 | imagesfile = os.path.join(realdir, 'sparse/0/images.bin')
46 | imdata = read_images_binary(imagesfile)
47 |
48 | w2c_mats = []
49 | # bottom = np.array([0,0,0,1.]).reshape([1,4])
50 |
51 | names = [imdata[k].name for k in imdata]
52 | img_keys = [k for k in imdata]
53 |
54 | print( 'Images #', len(names))
55 | perm = np.argsort(names)
56 |
57 | return imdata, perm, img_keys, hwf
58 |
59 |
60 | def run_maskrcnn(model, img_path, intWidth=1024, intHeight=576):
61 |
62 | # intHeight = 576
63 | # intWidth = 1024
64 |
65 | threshold = 0.5
66 |
67 | o_image = PIL.Image.open(img_path)
68 | image = o_image.resize((intWidth, intHeight), PIL.Image.ANTIALIAS)
69 |
70 | image_tensor = torchvision.transforms.functional.to_tensor(image).cuda()
71 |
72 | tenHumans = torch.FloatTensor(intHeight, intWidth).fill_(1.0).cuda()
73 |
74 | objPredictions = model([image_tensor])[0]
75 |
76 | for intMask in range(objPredictions['masks'].size(0)):
77 | if objPredictions['scores'][intMask].item() > threshold:
78 | if objPredictions['labels'][intMask].item() == 1: # person
79 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
80 |
81 | if objPredictions['labels'][intMask].item() == 4: # motorcycle
82 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
83 |
84 | if objPredictions['labels'][intMask].item() == 2: # bicycle
85 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
86 |
87 | if objPredictions['labels'][intMask].item() == 8: # truck
88 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
89 |
90 | if objPredictions['labels'][intMask].item() == 28: # umbrella
91 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
92 |
93 | if objPredictions['labels'][intMask].item() == 17: # cat
94 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
95 |
96 | if objPredictions['labels'][intMask].item() == 18: # dog
97 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
98 |
99 | if objPredictions['labels'][intMask].item() == 36: # snowboard
100 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
101 |
102 | if objPredictions['labels'][intMask].item() == 41: # skateboard
103 | tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
104 |
105 | npyMask = skimage.morphology.erosion(tenHumans.cpu().numpy(),
106 | skimage.morphology.disk(1))
107 | npyMask = ((npyMask < 1e-3) * 255.0).clip(0.0, 255.0).astype(np.uint8)
108 | return npyMask
109 |
110 |
111 | def motion_segmentation(basedir, threshold,
112 | input_semantic_w=1024,
113 | input_semantic_h=576):
114 |
115 | points3dfile = os.path.join(basedir, 'sparse/0/points3D.bin')
116 | pts3d = read_points3d_binary(points3dfile)
117 |
118 | img_dir = glob.glob(basedir + '/images_colmap')[0]
119 | img0 = glob.glob(glob.glob(img_dir)[0] + '/*jpg')[0]
120 | shape_0 = cv2.imread(img0).shape
121 |
122 | resized_height, resized_width = shape_0[0], shape_0[1]
123 |
124 | imdata, perm, img_keys, hwf = load_colmap_data(basedir)
125 | scale_x, scale_y = resized_width / float(hwf[1]), resized_height / float(hwf[0])
126 |
127 | K = np.eye(3)
128 | K[0, 0] = hwf[2]
129 | K[0, 2] = hwf[1] / 2.
130 | K[1, 1] = hwf[2]
131 | K[1, 2] = hwf[0] / 2.
132 |
133 | xx = range(0, resized_width)
134 | yy = range(0, resized_height)
135 | xv, yv = np.meshgrid(xx, yy)
136 | p_ref = np.float32(np.stack((xv, yv), axis=-1))
137 | p_ref_h = np.reshape(p_ref, (-1, 2))
138 | p_ref_h = np.concatenate((p_ref_h, np.ones((p_ref_h.shape[0], 1))), axis=-1).T
139 |
140 | num_frames = len(perm)
141 |
142 | if os.path.isdir(os.path.join(basedir, 'images_colmap')):
143 | num_colmap_frames = len(glob.glob(os.path.join(basedir, 'images_colmap', '*.jpg')))
144 | num_data_frames = len(glob.glob(os.path.join(basedir, 'images', '*.png')))
145 |
146 | if num_colmap_frames != num_data_frames:
147 | num_frames = num_data_frames
148 |
149 |
150 | save_mask_dir = os.path.join(basedir, 'motion_segmentation')
151 | create_dir(save_mask_dir)
152 |
153 | for i in range(0, num_frames):
154 | im_prev = imdata[img_keys[perm[max(0, i - 1)]]]
155 | im_ref = imdata[img_keys[perm[i]]]
156 | im_post = imdata[img_keys[perm[min(num_frames -1, i + 1)]]]
157 |
158 | print(im_prev.name, im_ref.name, im_post.name)
159 |
160 | T_prev_G = extract_poses(im_prev)
161 | T_ref_G = extract_poses(im_ref)
162 | T_post_G = extract_poses(im_post)
163 |
164 | T_ref2prev = np.dot(T_prev_G, np.linalg.inv(T_ref_G))
165 | T_ref2post = np.dot(T_post_G, np.linalg.inv(T_ref_G))
166 | # load optical flow
167 |
168 | if i == 0:
169 | fwd_flow, _ = read_optical_flow(basedir,
170 | im_ref.name,
171 | read_fwd=True)
172 | bwd_flow = np.zeros_like(fwd_flow)
173 | elif i == num_frames - 1:
174 | bwd_flow, _ = read_optical_flow(basedir,
175 | im_ref.name,
176 | read_fwd=False)
177 | fwd_flow = np.zeros_like(bwd_flow)
178 | else:
179 | fwd_flow, _ = read_optical_flow(basedir,
180 | im_ref.name,
181 | read_fwd=True)
182 | bwd_flow, _ = read_optical_flow(basedir,
183 | im_ref.name,
184 | read_fwd=False)
185 |
186 | p_post = p_ref + fwd_flow
187 | p_post_h = np.reshape(p_post, (-1, 2))
188 | p_post_h = np.concatenate((p_post_h, np.ones((p_post_h.shape[0], 1))), axis=-1).T
189 |
190 | fwd_e_dist = compute_epipolar_distance(T_ref2post, K,
191 | p_ref_h, p_post_h)
192 | fwd_e_dist = np.reshape(fwd_e_dist, (fwd_flow.shape[0], fwd_flow.shape[1]))
193 |
194 | p_prev = p_ref + bwd_flow
195 | p_prev_h = np.reshape(p_prev, (-1, 2))
196 | p_prev_h = np.concatenate((p_prev_h, np.ones((p_prev_h.shape[0], 1))), axis=-1).T
197 |
198 | bwd_e_dist = compute_epipolar_distance(T_ref2prev, K,
199 | p_ref_h, p_prev_h)
200 | bwd_e_dist = np.reshape(bwd_e_dist, (bwd_flow.shape[0], bwd_flow.shape[1]))
201 |
202 | e_dist = np.maximum(bwd_e_dist, fwd_e_dist)
203 |
204 | motion_mask = skimage.morphology.binary_opening(e_dist > threshold, skimage.morphology.disk(1))
205 |
206 | cv2.imwrite(os.path.join(save_mask_dir, im_ref.name.replace('.jpg', '.png')), np.uint8(255 * (0. + motion_mask)))
207 |
208 | # RUN SEMANTIC SEGMENTATION
209 | img_dir = os.path.join(basedir, 'images')
210 | img_path_list = sorted(glob.glob(os.path.join(img_dir, '*.jpg'))) \
211 | + sorted(glob.glob(os.path.join(img_dir, '*.png')))
212 | semantic_mask_dir = os.path.join(basedir, 'semantic_mask')
213 | netMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
214 | create_dir(semantic_mask_dir)
215 |
216 |
217 | for i in range(0, len(img_path_list)):
218 | img_path = img_path_list[i]
219 | img_name = img_path.split('/')[-1]
220 | semantic_mask = run_maskrcnn(netMaskrcnn, img_path,
221 | input_semantic_w,
222 | input_semantic_h)
223 | cv2.imwrite(os.path.join(semantic_mask_dir,
224 | img_name.replace('.jpg', '.png')),
225 | semantic_mask)
226 |
227 | # combine them
228 | save_mask_dir = os.path.join(basedir, 'motion_masks')
229 | create_dir(save_mask_dir)
230 |
231 | mask_dir = os.path.join(basedir, 'motion_segmentation')
232 | mask_path_list = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
233 |
234 | semantic_dir = os.path.join(basedir, 'semantic_mask')
235 |
236 | for mask_path in mask_path_list:
237 | print(mask_path)
238 |
239 | motion_mask = cv2.imread(mask_path)
240 | motion_mask = cv2.resize(motion_mask, (resized_width, resized_height),
241 | interpolation=cv2.INTER_NEAREST)
242 | motion_mask = motion_mask[:, :, 0] > 0.1
243 |
244 | # combine from motion segmentation
245 | semantic_mask = cv2.imread(os.path.join(semantic_dir, mask_path.split('/')[-1]))
246 | semantic_mask = cv2.resize(semantic_mask, (resized_width, resized_height),
247 | interpolation=cv2.INTER_NEAREST)
248 | semantic_mask = semantic_mask[:, :, 0] > 0.1
249 | motion_mask = semantic_mask | motion_mask
250 |
251 | motion_mask = skimage.morphology.dilation(motion_mask, skimage.morphology.disk(2))
252 | cv2.imwrite(os.path.join(save_mask_dir, '%s'%mask_path.split('/')[-1]),
253 | np.uint8(np.clip((motion_mask), 0, 1) * 255) )
254 |
255 | # delete old mask dir
256 | os.system('rm -r %s'%mask_dir)
257 | os.system('rm -r %s'%semantic_dir)
258 |
259 |
260 | if __name__ == '__main__':
261 | parser = argparse.ArgumentParser()
262 | parser.add_argument("--dataset_path", type=str, help='Dataset path')
263 | parser.add_argument("--epi_threshold", type=float,
264 | default=1.0,
265 | help='epipolar distance threshold for physical motion segmentation')
266 |
267 | parser.add_argument("--input_flow_w", type=int,
268 | default=768,
269 | help='input image width for optical flow, \
270 | the height will be computed based on original aspect ratio ')
271 |
272 | parser.add_argument("--input_semantic_w", type=int,
273 | default=1024,
274 | help='input image width for semantic segmentation')
275 |
276 | parser.add_argument("--input_semantic_h", type=int,
277 | default=576,
278 | help='input image height for semantic segmentation')
279 | args = parser.parse_args()
280 |
281 | motion_segmentation(args.dataset_path, args.epi_threshold,
282 | args.input_semantic_w,
283 | args.input_semantic_h)
284 |
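Putting the pieces together, the final mask written to `motion_masks` is (up to the small morphological opening applied to the epipolar term)

$$M = \text{dilate}\Big( M_{\text{semantic}} \;\cup\; \big[\, \max(d_{\text{fwd}}, d_{\text{bwd}}) > \tau \,\big] \Big),$$

where $d_{\text{fwd}}$ and $d_{\text{bwd}}$ are the epipolar distances computed against the next and previous frames and $\tau$ is `--epi_threshold` (default 1.0).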
--------------------------------------------------------------------------------
/utils/generate_pose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import numpy as np
5 | from colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
6 |
7 |
8 | def load_colmap_data(realdir):
9 |
10 | camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')
11 | camdata = read_cameras_binary(camerasfile)
12 |
13 | list_of_keys = list(camdata.keys())
14 | cam = camdata[list_of_keys[0]]
15 |     print('Cameras', len(camdata))
16 |
17 | h, w, f = cam.height, cam.width, cam.params[0]
18 | # w, h, f = factor * w, factor * h, factor * f
19 | hwf = np.array([h,w,f]).reshape([3,1])
20 |
21 | imagesfile = os.path.join(realdir, 'sparse/0/images.bin')
22 | imdata = read_images_binary(imagesfile)
23 |
24 | w2c_mats = []
25 | bottom = np.array([0,0,0,1.]).reshape([1,4])
26 |
27 | names = [imdata[k].name for k in imdata]
28 | img_keys = [k for k in imdata]
29 |
30 | print('Images #', len(names))
31 | perm = np.argsort(names)
32 |
33 | points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin')
34 | pts3d = read_points3d_binary(points3dfile)
35 |
36 | bounds_mats = []
37 |
38 | for i in perm[0:len(img_keys)]:
39 |
40 | im = imdata[img_keys[i]]
41 | print(im.name)
42 | R = im.qvec2rotmat()
43 | t = im.tvec.reshape([3,1])
44 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
45 | w2c_mats.append(m)
46 |
47 | pts_3d_idx = im.point3D_ids
48 | pts_3d_vis_idx = pts_3d_idx[pts_3d_idx >= 0]
49 |
50 | #
51 | depth_list = []
52 | for k in range(len(pts_3d_vis_idx)):
53 | point_info = pts3d[pts_3d_vis_idx[k]]
54 |
55 | P_g = point_info.xyz
56 | P_c = np.dot(R, P_g.reshape(3, 1)) + t.reshape(3, 1)
57 | depth_list.append(P_c[2])
58 |
59 | zs = np.array(depth_list)
60 | close_depth, inf_depth = np.percentile(zs, 5), np.percentile(zs, 95)
61 | bounds = np.array([close_depth, inf_depth])
62 | bounds_mats.append(bounds)
63 |
64 | w2c_mats = np.stack(w2c_mats, 0)
65 | c2w_mats = np.linalg.inv(w2c_mats)
66 |
67 | poses = c2w_mats[:, :3, :4].transpose([1,2,0])
68 | poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis],
69 | [1,1,poses.shape[-1]])], 1)
70 |
71 | # must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t]
72 | poses = np.concatenate([poses[:, 1:2, :],
73 | poses[:, 0:1, :],
74 | -poses[:, 2:3, :],
75 | poses[:, 3:4, :],
76 | poses[:, 4:5, :]], 1)
77 |
78 | save_arr = []
79 |
80 | for i in range((poses.shape[2])):
81 | save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0))
82 |
83 | save_arr = np.array(save_arr)
84 | print(save_arr.shape)
85 |
86 | # Use all frames to calculate COLMAP camera poses.
87 | if os.path.isdir(os.path.join(realdir, 'images_colmap')):
88 | num_colmap_frames = len(glob.glob(os.path.join(realdir, 'images_colmap', '*.jpg')))
89 | num_data_frames = len(glob.glob(os.path.join(realdir, 'images', '*.png')))
90 |
91 | assert num_colmap_frames == save_arr.shape[0]
92 | np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr[:num_data_frames, :])
93 | else:
94 | np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr)
95 |
96 |
97 | if __name__ == '__main__':
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--dataset_path", type=str,
100 | help='Dataset path')
101 |
102 | args = parser.parse_args()
103 |
104 | load_colmap_data(args.dataset_path)
105 |
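Each row of the saved `poses_bounds.npy` packs 17 numbers: a flattened 3x5 matrix (the 3x4 camera-to-world pose in the LLFF axis convention with the [h, w, f] column appended) followed by the near/far depth bounds. A minimal unpacking sketch (the path is a placeholder):

```
# Illustrative only; "scene/poses_bounds.npy" is a placeholder path.
import numpy as np

arr = np.load("scene/poses_bounds.npy")      # shape (num_frames, 17)
poses = arr[:, :15].reshape(-1, 3, 5)        # 3x4 c2w pose + [h, w, f] column per frame
bounds = arr[:, 15:]                         # [near, far] depth bounds per frame
hwf = poses[0, :, 4]                         # image height, width, focal length
```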
--------------------------------------------------------------------------------
/utils/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 | Args:
8 | path (str): file path
9 | """
10 | parameters = torch.load(path, map_location=torch.device('cpu'))
11 |
12 | if "optimizer" in parameters:
13 | parameters = parameters["model"]
14 |
15 | self.load_state_dict(parameters)
16 |
--------------------------------------------------------------------------------
/utils/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 | ) # ViT-H/16 - 85.0% Top1 (backbone)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
--------------------------------------------------------------------------------
/utils/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
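Note on `MidasNet` (defined in `/utils/midas/midas_net.py` above): a minimal inference sketch follows. The import path is an assumption; with `path=None` the decoder weights stay random, and in this repository a downloaded MiDaS checkpoint would normally be passed as `path`. Building the ResNeXt-101 WSL encoder may fetch weights through `torch.hub`, so network access can be required.

```
import torch
from utils.midas.midas_net import MidasNet  # assumed import path

model = MidasNet(path=None, features=256, non_negative=True)
model.eval()

# Input is a normalized RGB batch whose spatial size is a multiple of 32
# (the encoder downsamples by 32 and the decoder upsamples back to full size).
img = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    disparity = model(img)  # relative inverse depth, shape (B, H, W)

print(disparity.shape)  # torch.Size([1, 384, 384])
```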
/utils/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 | """Rezise the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 | """Normlize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
--------------------------------------------------------------------------------
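Note on `/utils/midas/transforms.py` above: each transform operates on a `dict` sample with an `"image"` key (and optionally `"mask"`, `"disparity"`, `"depth"`). The sketch below shows a typical preprocessing pipeline; the 384x384 target size, ImageNet mean/std, interpolation choice, and the `frame.png` filename are illustrative assumptions rather than values taken from this repository.

```
import cv2
import torch
from torchvision.transforms import Compose
from utils.midas.transforms import Resize, NormalizeImage, PrepareForNet  # assumed import path

transform = Compose([
    Resize(
        384, 384,
        resize_target=False,          # resize only the image
        keep_aspect_ratio=True,
        ensure_multiple_of=32,        # keep H, W divisible by the encoder stride
        resize_method="upper_bound",
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

img = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB) / 255.0  # HWC float RGB in [0, 1]
sample = transform({"image": img})
net_input = torch.from_numpy(sample["image"]).unsqueeze(0)  # (1, 3, H, W) tensor
```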
/utils/midas/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | class Slice(nn.Module):
10 | def __init__(self, start_index=1):
11 | super(Slice, self).__init__()
12 | self.start_index = start_index
13 |
14 | def forward(self, x):
15 | return x[:, self.start_index :]
16 |
17 |
18 | class AddReadout(nn.Module):
19 | def __init__(self, start_index=1):
20 | super(AddReadout, self).__init__()
21 | self.start_index = start_index
22 |
23 | def forward(self, x):
24 | if self.start_index == 2:
25 | readout = (x[:, 0] + x[:, 1]) / 2
26 | else:
27 | readout = x[:, 0]
28 | return x[:, self.start_index :] + readout.unsqueeze(1)
29 |
30 |
31 | class ProjectReadout(nn.Module):
32 | def __init__(self, in_features, start_index=1):
33 | super(ProjectReadout, self).__init__()
34 | self.start_index = start_index
35 |
36 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37 |
38 | def forward(self, x):
39 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40 | features = torch.cat((x[:, self.start_index :], readout), -1)
41 |
42 | return self.project(features)
43 |
44 |
45 | class Transpose(nn.Module):
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0 = dim0
49 | self.dim1 = dim1
50 |
51 | def forward(self, x):
52 | x = x.transpose(self.dim0, self.dim1)
53 | return x
54 |
55 |
56 | def forward_vit(pretrained, x):
57 | b, c, h, w = x.shape
58 |
59 | glob = pretrained.model.forward_flex(x)
60 |
61 | layer_1 = pretrained.activations["1"]
62 | layer_2 = pretrained.activations["2"]
63 | layer_3 = pretrained.activations["3"]
64 | layer_4 = pretrained.activations["4"]
65 |
66 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70 |
71 | unflatten = nn.Sequential(
72 | nn.Unflatten(
73 | 2,
74 | torch.Size(
75 | [
76 | h // pretrained.model.patch_size[1],
77 | w // pretrained.model.patch_size[0],
78 | ]
79 | ),
80 | )
81 | )
82 |
83 | if layer_1.ndim == 3:
84 | layer_1 = unflatten(layer_1)
85 | if layer_2.ndim == 3:
86 | layer_2 = unflatten(layer_2)
87 | if layer_3.ndim == 3:
88 | layer_3 = unflatten(layer_3)
89 | if layer_4.ndim == 3:
90 | layer_4 = unflatten(layer_4)
91 |
92 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96 |
97 | return layer_1, layer_2, layer_3, layer_4
98 |
99 |
100 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
101 | posemb_tok, posemb_grid = (
102 | posemb[:, : self.start_index],
103 | posemb[0, self.start_index :],
104 | )
105 |
106 | gs_old = int(math.sqrt(len(posemb_grid)))
107 |
108 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111 |
112 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113 |
114 | return posemb
115 |
116 |
117 | def forward_flex(self, x):
118 | b, c, h, w = x.shape
119 |
120 | pos_embed = self._resize_pos_embed(
121 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122 | )
123 |
124 | B = x.shape[0]
125 |
126 | if hasattr(self.patch_embed, "backbone"):
127 | x = self.patch_embed.backbone(x)
128 | if isinstance(x, (list, tuple)):
129 | x = x[-1] # last feature if backbone outputs list/tuple of features
130 |
131 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132 |
133 | if getattr(self, "dist_token", None) is not None:
134 | cls_tokens = self.cls_token.expand(
135 | B, -1, -1
136 | ) # stole cls_tokens impl from Phil Wang, thanks
137 | dist_token = self.dist_token.expand(B, -1, -1)
138 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
139 | else:
140 | cls_tokens = self.cls_token.expand(
141 | B, -1, -1
142 | ) # stole cls_tokens impl from Phil Wang, thanks
143 | x = torch.cat((cls_tokens, x), dim=1)
144 |
145 | x = x + pos_embed
146 | x = self.pos_drop(x)
147 |
148 | for blk in self.blocks:
149 | x = blk(x)
150 |
151 | x = self.norm(x)
152 |
153 | return x
154 |
155 |
156 | activations = {}
157 |
158 |
159 | def get_activation(name):
160 | def hook(model, input, output):
161 | activations[name] = output
162 |
163 | return hook
164 |
165 |
166 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
167 | if use_readout == "ignore":
168 | readout_oper = [Slice(start_index)] * len(features)
169 | elif use_readout == "add":
170 | readout_oper = [AddReadout(start_index)] * len(features)
171 | elif use_readout == "project":
172 | readout_oper = [
173 | ProjectReadout(vit_features, start_index) for out_feat in features
174 | ]
175 | else:
176 |         raise ValueError(
177 |             "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
178 |         )
179 |
180 | return readout_oper
181 |
182 |
183 | def _make_vit_b16_backbone(
184 | model,
185 | features=[96, 192, 384, 768],
186 | size=[384, 384],
187 | hooks=[2, 5, 8, 11],
188 | vit_features=768,
189 | use_readout="ignore",
190 | start_index=1,
191 | ):
192 | pretrained = nn.Module()
193 |
194 | pretrained.model = model
195 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199 |
200 | pretrained.activations = activations
201 |
202 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203 |
204 |     # Project each hooked token sequence back to a 2D feature map and rescale it to one of four strides.
205 | pretrained.act_postprocess1 = nn.Sequential(
206 | readout_oper[0],
207 | Transpose(1, 2),
208 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209 | nn.Conv2d(
210 | in_channels=vit_features,
211 | out_channels=features[0],
212 | kernel_size=1,
213 | stride=1,
214 | padding=0,
215 | ),
216 | nn.ConvTranspose2d(
217 | in_channels=features[0],
218 | out_channels=features[0],
219 | kernel_size=4,
220 | stride=4,
221 | padding=0,
222 | bias=True,
223 | dilation=1,
224 | groups=1,
225 | ),
226 | )
227 |
228 | pretrained.act_postprocess2 = nn.Sequential(
229 | readout_oper[1],
230 | Transpose(1, 2),
231 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232 | nn.Conv2d(
233 | in_channels=vit_features,
234 | out_channels=features[1],
235 | kernel_size=1,
236 | stride=1,
237 | padding=0,
238 | ),
239 | nn.ConvTranspose2d(
240 | in_channels=features[1],
241 | out_channels=features[1],
242 | kernel_size=2,
243 | stride=2,
244 | padding=0,
245 | bias=True,
246 | dilation=1,
247 | groups=1,
248 | ),
249 | )
250 |
251 | pretrained.act_postprocess3 = nn.Sequential(
252 | readout_oper[2],
253 | Transpose(1, 2),
254 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255 | nn.Conv2d(
256 | in_channels=vit_features,
257 | out_channels=features[2],
258 | kernel_size=1,
259 | stride=1,
260 | padding=0,
261 | ),
262 | )
263 |
264 | pretrained.act_postprocess4 = nn.Sequential(
265 | readout_oper[3],
266 | Transpose(1, 2),
267 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268 | nn.Conv2d(
269 | in_channels=vit_features,
270 | out_channels=features[3],
271 | kernel_size=1,
272 | stride=1,
273 | padding=0,
274 | ),
275 | nn.Conv2d(
276 | in_channels=features[3],
277 | out_channels=features[3],
278 | kernel_size=3,
279 | stride=2,
280 | padding=1,
281 | ),
282 | )
283 |
284 | pretrained.model.start_index = start_index
285 | pretrained.model.patch_size = [16, 16]
286 |
287 | # We inject this function into the VisionTransformer instances so that
288 | # we can use it with interpolated position embeddings without modifying the library source.
289 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290 | pretrained.model._resize_pos_embed = types.MethodType(
291 | _resize_pos_embed, pretrained.model
292 | )
293 |
294 | return pretrained
295 |
296 |
297 | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299 |
300 |     hooks = [5, 11, 17, 23] if hooks is None else hooks
301 | return _make_vit_b16_backbone(
302 | model,
303 | features=[256, 512, 1024, 1024],
304 | hooks=hooks,
305 | vit_features=1024,
306 | use_readout=use_readout,
307 | )
308 |
309 |
310 | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312 |
313 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
314 | return _make_vit_b16_backbone(
315 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316 | )
317 |
318 |
319 | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321 |
322 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
323 | return _make_vit_b16_backbone(
324 | model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325 | )
326 |
327 |
328 | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329 | model = timm.create_model(
330 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331 | )
332 |
333 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
334 | return _make_vit_b16_backbone(
335 | model,
336 | features=[96, 192, 384, 768],
337 | hooks=hooks,
338 | use_readout=use_readout,
339 | start_index=2,
340 | )
341 |
342 |
343 | def _make_vit_b_rn50_backbone(
344 | model,
345 | features=[256, 512, 768, 768],
346 | size=[384, 384],
347 | hooks=[0, 1, 8, 11],
348 | vit_features=768,
349 | use_vit_only=False,
350 | use_readout="ignore",
351 | start_index=1,
352 | ):
353 | pretrained = nn.Module()
354 |
355 | pretrained.model = model
356 |
357 |     if use_vit_only:
358 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360 | else:
361 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362 | get_activation("1")
363 | )
364 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365 | get_activation("2")
366 | )
367 |
368 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370 |
371 | pretrained.activations = activations
372 |
373 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374 |
375 |     if use_vit_only:
376 | pretrained.act_postprocess1 = nn.Sequential(
377 | readout_oper[0],
378 | Transpose(1, 2),
379 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380 | nn.Conv2d(
381 | in_channels=vit_features,
382 | out_channels=features[0],
383 | kernel_size=1,
384 | stride=1,
385 | padding=0,
386 | ),
387 | nn.ConvTranspose2d(
388 | in_channels=features[0],
389 | out_channels=features[0],
390 | kernel_size=4,
391 | stride=4,
392 | padding=0,
393 | bias=True,
394 | dilation=1,
395 | groups=1,
396 | ),
397 | )
398 |
399 | pretrained.act_postprocess2 = nn.Sequential(
400 | readout_oper[1],
401 | Transpose(1, 2),
402 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403 | nn.Conv2d(
404 | in_channels=vit_features,
405 | out_channels=features[1],
406 | kernel_size=1,
407 | stride=1,
408 | padding=0,
409 | ),
410 | nn.ConvTranspose2d(
411 | in_channels=features[1],
412 | out_channels=features[1],
413 | kernel_size=2,
414 | stride=2,
415 | padding=0,
416 | bias=True,
417 | dilation=1,
418 | groups=1,
419 | ),
420 | )
421 | else:
422 | pretrained.act_postprocess1 = nn.Sequential(
423 | nn.Identity(), nn.Identity(), nn.Identity()
424 | )
425 | pretrained.act_postprocess2 = nn.Sequential(
426 | nn.Identity(), nn.Identity(), nn.Identity()
427 | )
428 |
429 | pretrained.act_postprocess3 = nn.Sequential(
430 | readout_oper[2],
431 | Transpose(1, 2),
432 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433 | nn.Conv2d(
434 | in_channels=vit_features,
435 | out_channels=features[2],
436 | kernel_size=1,
437 | stride=1,
438 | padding=0,
439 | ),
440 | )
441 |
442 | pretrained.act_postprocess4 = nn.Sequential(
443 | readout_oper[3],
444 | Transpose(1, 2),
445 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446 | nn.Conv2d(
447 | in_channels=vit_features,
448 | out_channels=features[3],
449 | kernel_size=1,
450 | stride=1,
451 | padding=0,
452 | ),
453 | nn.Conv2d(
454 | in_channels=features[3],
455 | out_channels=features[3],
456 | kernel_size=3,
457 | stride=2,
458 | padding=1,
459 | ),
460 | )
461 |
462 | pretrained.model.start_index = start_index
463 | pretrained.model.patch_size = [16, 16]
464 |
465 | # We inject this function into the VisionTransformer instances so that
466 | # we can use it with interpolated position embeddings without modifying the library source.
467 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468 |
469 | # We inject this function into the VisionTransformer instances so that
470 | # we can use it with interpolated position embeddings without modifying the library source.
471 | pretrained.model._resize_pos_embed = types.MethodType(
472 | _resize_pos_embed, pretrained.model
473 | )
474 |
475 | return pretrained
476 |
477 |
478 | def _make_pretrained_vitb_rn50_384(
479 | pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480 | ):
481 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482 |
483 |     hooks = [0, 1, 8, 11] if hooks is None else hooks
484 | return _make_vit_b_rn50_backbone(
485 | model,
486 | features=[256, 512, 768, 768],
487 | size=[384, 384],
488 | hooks=hooks,
489 | use_vit_only=use_vit_only,
490 | use_readout=use_readout,
491 | )
492 |
--------------------------------------------------------------------------------
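Note on `/utils/midas/vit.py` above: `_make_pretrained_*` builds a timm ViT with forward hooks on four transformer blocks, and `forward_vit` post-processes the hooked activations into four 2D feature maps. The sketch below assumes the module is importable as `utils.midas.vit` and that the installed `timm` provides `vit_large_patch16_384`; with a 384x384 input and the default hooks `[5, 11, 17, 23]`, the maps come out at strides 4, 8, 16, and 32.

```
import torch
from utils.midas.vit import _make_pretrained_vitl16_384, forward_vit  # assumed import path

# pretrained=False avoids downloading weights; pass True to load the timm checkpoint.
backbone = _make_pretrained_vitl16_384(pretrained=False, use_readout="ignore")

x = torch.randn(1, 3, 384, 384)  # H and W must be multiples of the 16x16 patch size
layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)

print(layer_1.shape)  # torch.Size([1, 256, 96, 96])
print(layer_2.shape)  # torch.Size([1, 512, 48, 48])
print(layer_3.shape)  # torch.Size([1, 1024, 24, 24])
print(layer_4.shape)  # torch.Size([1, 1024, 12, 12])
```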