├── LICENSE ├── README.md ├── data ├── 7Scenes │ ├── .DS_Store │ ├── ._.DS_Store │ ├── chess │ │ ├── pose_avg_stats.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json │ ├── fire │ │ ├── pose_avg_stats.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json │ ├── heads │ │ ├── pose_avg_stats.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json │ ├── office │ │ ├── pose_avg_stats.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json │ ├── pumpkin │ │ ├── pose_avg_stats.txt │ │ ├── pose_avg_stats_old.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json │ ├── redkitchen │ │ ├── pose_avg_stats.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json │ └── stairs │ │ ├── pose_avg_stats.txt │ │ ├── pose_stats.txt │ │ ├── stats.txt │ │ └── world_setup.json ├── Cambridge │ ├── GreatCourt │ │ ├── pose_avg_stats.txt │ │ └── world_setup.json │ ├── KingsCollege │ │ ├── pose_avg_stats.txt │ │ └── world_setup.json │ ├── OldHospital │ │ ├── pose_avg_stats.txt │ │ └── world_setup.json │ ├── ShopFacade │ │ ├── pose_avg_stats.txt │ │ └── world_setup.json │ └── StMarysChurch │ │ ├── pose_avg_stats.txt │ │ └── world_setup.json └── deepslam_data │ └── .gitignore ├── dataset_loaders ├── cambridge_scenes.py ├── load_7Scenes.py ├── load_Cambridge.py ├── seven_scenes.py └── utils │ └── color.py ├── imgs └── DFNet.png ├── requirements.txt └── script ├── config_dfnet.txt ├── config_dfnetdm.txt ├── config_nerfh.txt ├── dm ├── __init__.py ├── callbacks.py ├── direct_pose_model.py ├── options.py ├── pose_model.py └── prepare_data.py ├── feature ├── dfnet.py ├── direct_feature_matching.py ├── efficientnet.py ├── misc.py ├── model.py └── options.py ├── models ├── __init__.py ├── losses.py ├── metrics.py ├── nerf.py ├── nerfw.py ├── options.py ├── ray_utils.py └── rendering.py ├── run_feature.py ├── run_nerf.py ├── train.py └── utils ├── set_sys_path.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Active Vision Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFNet: Enhance Absolute Pose Regression with Direct Feature Matching 2 | 3 | **[Shuai Chen](https://scholar.google.com/citations?user=c0xTh_YAAAAJ&hl=en), [Xinghui Li](https://scholar.google.com/citations?user=XLlgbBoAAAAJ&hl=en), [Zirui Wang](https://scholar.google.com/citations?user=zCBKqa8AAAAJ&hl=en), and [Victor Adrian Prisacariu](https://scholar.google.com/citations?user=GmWA-LoAAAAJ&hl=en) (ECCV 2022)** 4 | 5 | **[Project Page](https://dfnet.active.vision) | [Paper](https://arxiv.org/abs/2204.00559)** 6 | 7 | [![DFNet](imgs/DFNet.png)](https://arxiv.org/abs/2204.00559) 8 | 9 | ## Setup 10 | ### Installing Requirements 11 | We tested our code with CUDA 11.3+, PyTorch 1.11.0+, and Python 3.7+ using [Docker](https://docs.docker.com/engine/install/ubuntu/). 12 | 13 | The rest of the dependencies are listed in `requirements.txt`. 14 | 15 | ### Data Preparation 16 | - **7-Scenes** 17 | 18 | We use a data preparation procedure similar to [MapNet](https://github.com/NVlabs/geomapnet). You can download the [7-Scenes](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/) dataset into the `data/deepslam_data` directory. 19 | 20 | Alternatively, you can use a symlink: 21 | 22 | ```sh 23 | cd data/deepslam_data && ln -s 7SCENES_DIR 7Scenes 24 | ``` 25 | 26 | Note that we additionally computed pose-averaging stats (pose_avg_stats.txt) and manually tuned world_setup.json files in `data/7Scenes` to align the 7-Scenes coordinate system with NeRF's coordinate system. You can generate your own re-alignment to a new pose_avg_stats.txt using the `--save_pose_avg_stats` option. 27 | 28 | - **Cambridge Landmarks** 29 | 30 | You can download the Cambridge Landmarks dataset using the script [here](https://github.com/vislearn/dsacstar/blob/master/datasets/setup_cambridge.py). Please also copy the pose_avg_stats.txt and world_setup.json files into `data/Cambridge/CAMBRIDGE_SCENES`, as provided in the source code. 31 | 32 | ## Training 33 | 34 | As stated in the paper, our method relies on a pretrained Histogram-assisted NeRF model and a DFNet model. We provide example config files in the repo. The following commands show how to train each model. 35 | 36 | - NeRF model 37 | 38 | ```sh 39 | python run_nerf.py --config config_nerfh.txt 40 | ``` 41 | 42 | - DFNet model 43 | 44 | ```sh 45 | python run_feature.py --config config_dfnet.txt 46 | ``` 47 | 48 | - Direct Feature Matching (DFNetdm) 49 | 50 | ```sh 51 | python train.py --config config_dfnetdm.txt 52 | ``` 53 | 54 | ## Evaluation 55 | We provide methods to evaluate our models. 56 | - To evaluate the NeRF model in PSNR, simply add the `--render_test` argument. 57 | 58 | ```sh 59 | python run_nerf.py --config config_nerfh.txt --render_test 60 | ``` 61 | 62 | - To evaluate APR performance of the DFNet model, simply add `--eval --testskip=1 --pretrain_model_path=../logs/PATH_TO_CHECKPOINT`.
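(Note: `--testskip=1` evaluates on every test frame; a larger value keeps only every N-th test image, i.e. roughly 1/testskip of the test split, as implemented in the dataset loaders below.)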
For example: 63 | 64 | ```sh 65 | python run_feature.py --config config_dfnet.txt --eval --testskip=1 --pretrain_model_path=../logs/heads/dfnet/checkpoint.pt 66 | ``` 67 | 68 | - The same applies when evaluating APR performance of the DFNetdm model: 69 | 70 | ```sh 71 | python train.py --config config_dfnetdm.txt --eval --testskip=1 --pretrain_model_path=../logs/heads/dfnetdm/checkpoint.pt 72 | ``` 73 | 74 | ## Pre-trained model 75 | We provide the 7-Scenes and Cambridge pre-trained models [here](https://www.robots.ox.ac.uk/~shuaic/DFNet2022/pretrain_models.tar.gz). Some models achieve slightly better results than reported in our paper. We suggest placing the models in a new `./logs/` directory at the root of the project. 76 | 77 | Note that we additionally provide models for Cambridge's Great Court scene, although we did not include those results in the main paper, to keep comparisons with other works fair. 78 | 79 | Due to limited resources, the pre-trained models were trained on RTX 3080 Ti or GTX 1080 Ti GPUs. We have noticed that a model's performance may vary slightly (better or worse) when inference is run on a different type of GPU, even with the exact same weights. Therefore, all experiments in the paper are reported on the same GPUs the models were trained on. 80 | 81 | ## Acknowledgement 82 | We thank Michael Hobley, Theo Costain, Lixiong Chen, and Kejie Li for their generous discussions on this work. 83 | 84 | Most of our code is built upon [Direct-PoseNet](https://github.com/ActiveVisionLab/direct-posenet). Part of our Histogram-assisted NeRF implementation is adapted from the reproduced NeRFW code [here](https://github.com/kwea123/nerf_pl/tree/nerfw). We thank [@kwea123](https://github.com/kwea123) for this excellent work! 85 | 86 | ## Citation 87 | Please cite our paper and star this repo if you find our work helpful. Thanks! 88 | ``` 89 | @inproceedings{chen2022dfnet, 90 | title={DFNet: Enhance Absolute Pose Regression with Direct Feature Matching}, 91 | author={Chen, Shuai and Li, Xinghui and Wang, Zirui and Prisacariu, Victor}, 92 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 93 | year={2022} 94 | } 95 | ``` 96 | This code builds on previous camera relocalization pipelines, namely Direct-PoseNet.
Please consider citing: 97 | ``` 98 | @inproceedings{chen2021direct, 99 | title={Direct-PoseNet: Absolute pose regression with photometric consistency}, 100 | author={Chen, Shuai and Wang, Zirui and Prisacariu, Victor}, 101 | booktitle={2021 International Conference on 3D Vision (3DV)}, 102 | pages={1175--1185}, 103 | year={2021}, 104 | organization={IEEE} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /data/7Scenes/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/data/7Scenes/.DS_Store -------------------------------------------------------------------------------- /data/7Scenes/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/data/7Scenes/._.DS_Store -------------------------------------------------------------------------------- /data/7Scenes/chess/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.350092088047137207e-01 1.857172640673737662e-01 -3.021040835171091565e-01 4.217527911023845610e-01 2 | -3.778758022057916027e-02 8.992282578255548220e-01 4.358447419771065423e-01 -7.391797421190476891e-01 3 | 3.526044217412146464e-01 -3.961030650668416753e-01 8.478045078986057304e-01 2.995973868145236363e-01 4 | -------------------------------------------------------------------------------- /data/7Scenes/chess/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/chess/stats.txt: -------------------------------------------------------------------------------- 1 | 5.009650708326967017e-01 4.413125411911532625e-01 4.458285283490354689e-01 2 | 4.329720281018845096e-02 5.278270383679337097e-02 4.760929057962018374e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/chess/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":2, 4 | "pose_scale": 0.5, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 1.0] 7 | } -------------------------------------------------------------------------------- /data/7Scenes/fire/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.682675339923704216e-01 1.308464847914461437e-01 -2.129252921427031431e-01 4.483261509061373107e-02 2 | -9.619441241477342738e-03 8.708690325271444266e-01 4.914209952122894909e-01 -4.201599145869548413e-01 3 | 2.497307529451176511e-01 -4.737787728496888895e-01 8.444928806274848432e-01 7.115903476590906829e-01 4 | -------------------------------------------------------------------------------- /data/7Scenes/fire/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/fire/stats.txt: -------------------------------------------------------------------------------- 1 | 5.222627479256024552e-01 4.620521564670138082e-01 4.212473626365915158e-01 2 | 
5.550322239689903236e-02 5.943252514694064015e-02 5.525370066993806617e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/fire/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":3, 4 | "pose_scale": 1, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 1.0] 7 | } -------------------------------------------------------------------------------- /data/7Scenes/heads/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.821477519358328134e-01 -8.872136230840418913e-02 1.658743899386847798e-01 -1.157979895409090576e-02 2 | 3.111216594703735891e-02 9.462589180611266082e-01 3.219100699261678855e-01 -1.317583753459090623e-01 3 | -1.855204207020725304e-01 -3.110025399573565497e-01 9.321263828701550347e-01 9.636789777181822836e-02 4 | -------------------------------------------------------------------------------- /data/7Scenes/heads/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/heads/stats.txt: -------------------------------------------------------------------------------- 1 | 4.570619554738562518e-01 4.504317877348855137e-01 4.586057516467524908e-01 2 | 7.874170624948270691e-02 7.747845434384653673e-02 7.183367877515742239e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/heads/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":2.5, 4 | "pose_scale": 1, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 1.0] 7 | } -------------------------------------------------------------------------------- /data/7Scenes/office/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.192216692444369341e-01 -2.380853847231172160e-01 3.136030490488191935e-01 -2.580896777083788868e-02 2 | 8.889970878427481960e-02 9.014018431780023155e-01 4.237588452095970570e-01 -8.784274026985845474e-01 3 | -3.835731541303978309e-01 -3.616490933163600818e-01 8.497538283137727744e-01 1.063082783627855354e+00 4 | -------------------------------------------------------------------------------- /data/7Scenes/office/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/office/stats.txt: -------------------------------------------------------------------------------- 1 | 4.703657901067226921e-01 4.414751487847252132e-01 4.351020758221028628e-01 2 | 7.105139804377599844e-02 7.191485421006868495e-02 6.783299267371162289e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/office/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":2, 4 | "pose_scale": 0.5, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 0.5] 7 | } -------------------------------------------------------------------------------- /data/7Scenes/pumpkin/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 
| 9.994189485731389544e-01 7.232278172020349324e-03 -3.330854823320933411e-02 -6.206658357107126822e-02 2 | -8.310756492857687694e-03 9.994418989496620664e-01 -3.235462795969396704e-02 -7.690604500476190264e-01 3 | 3.305596102789840757e-02 3.261264749044923833e-02 9.989212775109888032e-01 4.472261112787878634e-01 4 | -------------------------------------------------------------------------------- /data/7Scenes/pumpkin/pose_avg_stats_old.txt: -------------------------------------------------------------------------------- 1 | 9.867033112503645897e-01 8.544426416488330733e-02 -1.382600929006191914e-01 7.374091044342952206e-02 2 | -9.057380802494104099e-02 9.953998641700174677e-01 -3.123292669878471872e-02 -7.475368646794867677e-01 3 | 1.349554032539169723e-01 4.334037530562200036e-02 9.899033543740218821e-01 3.342737444938814195e-01 4 | -------------------------------------------------------------------------------- /data/7Scenes/pumpkin/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/pumpkin/stats.txt: -------------------------------------------------------------------------------- 1 | 5.503370888799515859e-01 4.492568432042766124e-01 4.579284152018213705e-01 2 | 4.053158612557544727e-02 4.899782680513672939e-02 3.385843494567825074e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/pumpkin/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":2.5, 4 | "pose_scale": 0.5, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 1.0] 7 | } -------------------------------------------------------------------------------- /data/7Scenes/redkitchen/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.397923675833245172e-01 1.720816216688166589e-01 -2.952595829367096747e-01 1.068215764548530455e-01 2 | -1.553194639452672166e-01 9.846599603919916621e-01 7.950236801879838333e-02 -4.759053293419227559e-01 3 | 3.044111856550024142e-01 -2.885615852243439416e-02 9.521035406737251572e-01 9.771192826975949597e-01 4 | -------------------------------------------------------------------------------- /data/7Scenes/redkitchen/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/redkitchen/stats.txt: -------------------------------------------------------------------------------- 1 | 5.262172203420504291e-01 4.400453064527823366e-01 4.320846191351511711e-01 2 | 4.872459633076364760e-02 6.484063059696282272e-02 5.724255797232574716e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/redkitchen/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":2, 4 | "pose_scale": 0.5, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 0.5] 7 | } -------------------------------------------------------------------------------- /data/7Scenes/stairs/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.981044025604641767e-01 6.017846147761087006e-02 
-1.289008780444797844e-02 4.026755230123485463e-02 2 | -4.703056619089211743e-02 8.809109637765394352e-01 4.709394862846526530e-01 -9.428126168765132986e-01 3 | 3.969543340464730397e-02 -4.694405464725816546e-01 8.820713383249345618e-01 3.852607943350118691e-01 4 | -------------------------------------------------------------------------------- /data/7Scenes/stairs/pose_stats.txt: -------------------------------------------------------------------------------- 1 | 0.0000000 0.0000000 0.0000000 2 | 1.0000000 1.0000000 1.0000000 3 | -------------------------------------------------------------------------------- /data/7Scenes/stairs/stats.txt: -------------------------------------------------------------------------------- 1 | 4.472714732115506964e-01 4.312183359438830910e-01 4.291487246732026972e-01 2 | 3.258580609153208241e-02 2.618736971489385446e-02 1.208855922484347589e-02 3 | -------------------------------------------------------------------------------- /data/7Scenes/stairs/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0, 3 | "far":4, 4 | "pose_scale": 1, 5 | "pose_scale2": 1, 6 | "move_all_cam_vec": [0.0, 0.0, 0.0] 7 | } -------------------------------------------------------------------------------- /data/Cambridge/GreatCourt/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 3.405292754461299309e-01 4.953070871398960184e-01 7.991937825040464904e-01 4.704508373014760281e+01 2 | -9.402316354416121458e-01 1.812586354202151695e-01 2.882876667504056800e-01 3.467785281210451842e+01 3 | -2.069849976503225410e-03 -8.495976674371187309e-01 5.274272643753655787e-01 1.101080132352710184e+00 4 | -------------------------------------------------------------------------------- /data/Cambridge/GreatCourt/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0.0, 3 | "far":10.0, 4 | "pose_scale": 0.3027, 5 | "pose_scale2": 0.2, 6 | "move_all_cam_vec": [0.0, 0.0, 0.0] 7 | } -------------------------------------------------------------------------------- /data/Cambridge/KingsCollege/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.995083419588323137e-01 -1.453974655309233331e-02 2.777895111190991154e-02 2.004095163645802913e+01 2 | -2.395968310872182219e-02 2.172811532927548528e-01 9.758149588979971867e-01 -2.354010655332784197e+01 3 | -2.022394471995193205e-02 -9.760007664924973403e-01 2.168259575466510436e-01 1.650110331018928678e+00 4 | -------------------------------------------------------------------------------- /data/Cambridge/KingsCollege/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0.0, 3 | "far":10.0, 4 | "pose_scale": 0.3027, 5 | "pose_scale2": 0.2, 6 | "move_all_cam_vec": [0.0, 0.0, 0.0] 7 | } -------------------------------------------------------------------------------- /data/Cambridge/OldHospital/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 9.997941252129602940e-01 6.239930741698496326e-03 1.930726428032739084e-02 1.319547963328867723e+01 2 | -3.333807443587469103e-03 -8.880897259859261705e-01 4.596580515189216398e-01 -6.473184854291670343e-01 3 | 2.001481745059596404e-02 -4.596277862168271500e-01 -8.878860879751624413e-01 2.310333011616541654e+01 4 | 
-------------------------------------------------------------------------------- /data/Cambridge/OldHospital/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0.0, 3 | "far":10.0, 4 | "pose_scale": 0.3027, 5 | "pose_scale2": 0.2, 6 | "move_all_cam_vec": [0.0, 0.0, 5.0] 7 | } -------------------------------------------------------------------------------- /data/Cambridge/ShopFacade/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | 2.084004683986779016e-01 1.972095064159990266e-02 9.778447365901210553e-01 -4.512817941282106560e+00 2 | -9.780353328393808221e-01 8.307943784904847639e-03 2.082735359757174609e-01 1.914896116567694540e+00 3 | -4.016526979027209426e-03 -9.997710048685441997e-01 2.101916590087021808e-02 1.768500113487243564e+00 4 | -------------------------------------------------------------------------------- /data/Cambridge/ShopFacade/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0.0, 3 | "far":20.0, 4 | "pose_scale": 0.3027, 5 | "pose_scale2": 0.32, 6 | "move_all_cam_vec": [0.0, 0.0, 2.5] 7 | } -------------------------------------------------------------------------------- /data/Cambridge/StMarysChurch/pose_avg_stats.txt: -------------------------------------------------------------------------------- 1 | -6.692001528162709878e-01 7.430812642562667492e-01 1.179059789653581552e-03 1.114036505648812359e+01 2 | 3.891382817260490012e-02 3.662925707351961935e-02 -9.985709847092467673e-01 -5.441265972613005403e-02 3 | -7.420625778515127502e-01 -6.681979738352623599e-01 -5.342844106669619036e-02 1.708768320112491068e+01 4 | -------------------------------------------------------------------------------- /data/Cambridge/StMarysChurch/world_setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "near":0.0, 3 | "far":10.0, 4 | "pose_scale": 0.3027, 5 | "pose_scale2": 0.2, 6 | "move_all_cam_vec": [0.0, 0.0, 0.0] 7 | } -------------------------------------------------------------------------------- /data/deepslam_data/.gitignore: -------------------------------------------------------------------------------- 1 | 7Scenes 2 | -------------------------------------------------------------------------------- /dataset_loaders/cambridge_scenes.py: -------------------------------------------------------------------------------- 1 | """ 2 | pytorch data loader for the Cambridge Landmark dataset 3 | """ 4 | import os 5 | import os.path as osp 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | from torch.utils import data 10 | import sys 11 | import pdb 12 | import cv2 13 | from dataset_loaders.utils.color import rgb_to_yuv 14 | from torchvision.utils import save_image 15 | import json 16 | 17 | sys.path.insert(0, '../') 18 | #from common.pose_utils import process_poses 19 | 20 | ## NUMPY 21 | def qlog(q): 22 | """ 23 | Applies logarithm map to q 24 | :param q: (4,) 25 | :return: (3,) 26 | """ 27 | if all(q[1:] == 0): 28 | q = np.zeros(3) 29 | else: 30 | q = np.arccos(q[0]) * q[1:] / np.linalg.norm(q[1:]) 31 | return q 32 | 33 | def process_poses_rotmat(poses_in, rot_mat): 34 | """ 35 | processes the position + quaternion raw pose from dataset to position + rotation matrix 36 | :param poses_in: N x 7 37 | :param rot_mat: N x 3 x 3 38 | :return: processed poses N x 12 39 | """ 40 | 41 | poses = np.zeros((poses_in.shape[0], 3, 4)) 42 | 
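# fill the rotation block from rot_mat and the translation column from the first three entries of each raw pose, then flatten each 3x4 [R|t] matrix into a 12-vector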
poses[:,:3,:3] = rot_mat 43 | poses[...,:3,3] = poses_in[:, :3] 44 | poses = poses.reshape(poses_in.shape[0], 12) 45 | return poses 46 | 47 | from torchvision.datasets.folder import default_loader 48 | def load_image(filename, loader=default_loader): 49 | try: 50 | img = loader(filename) 51 | except IOError as e: 52 | print('Could not load image {:s}, IOError: {:s}'.format(filename, e)) 53 | return None 54 | except: 55 | print('Could not load image {:s}, unexpected error'.format(filename)) 56 | return None 57 | return img 58 | 59 | def load_depth_image(filename): 60 | try: 61 | img_depth = Image.fromarray(np.array(Image.open(filename)).astype("uint16")) 62 | except IOError as e: 63 | print('Could not load image {:s}, IOError: {:s}'.format(filename, e)) 64 | return None 65 | return img_depth 66 | 67 | def normalize(x): 68 | return x / np.linalg.norm(x) 69 | 70 | def viewmatrix(z, up, pos): 71 | vec2 = normalize(z) 72 | vec1_avg = up 73 | vec0 = normalize(np.cross(vec1_avg, vec2)) 74 | vec1 = normalize(np.cross(vec2, vec0)) 75 | m = np.stack([vec0, vec1, vec2, pos], 1) 76 | return m 77 | 78 | def normalize_recenter_pose(poses, sc, hwf): 79 | ''' normalize xyz into [-1, 1], and recenter pose ''' 80 | target_pose = poses.reshape(poses.shape[0],3,4) 81 | target_pose[:,:3,3] = target_pose[:,:3,3] * sc 82 | 83 | 84 | x_norm = target_pose[:,0,3] 85 | y_norm = target_pose[:,1,3] 86 | z_norm = target_pose[:,2,3] 87 | 88 | tpose_ = target_pose+0 89 | 90 | # find the center of pose 91 | center = np.array([x_norm.mean(), y_norm.mean(), z_norm.mean()]) 92 | bottom = np.reshape([0,0,0,1.], [1,4]) 93 | 94 | # pose avg 95 | vec2 = normalize(tpose_[:, :3, 2].sum(0)) 96 | up = tpose_[:, :3, 1].sum(0) 97 | hwf=np.array(hwf).transpose() 98 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 99 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 100 | 101 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [tpose_.shape[0],1,1]) 102 | poses = np.concatenate([tpose_[:,:3,:4], bottom], -2) 103 | poses = np.linalg.inv(c2w) @ poses 104 | return poses[:,:3,:].reshape(poses.shape[0],12) 105 | 106 | def downscale_pose(poses, sc): 107 | ''' downscale translation pose to [-1:1] only ''' 108 | target_pose = poses.reshape(poses.shape[0],3,4) 109 | target_pose[:,:3,3] = target_pose[:,:3,3] * sc 110 | return target_pose.reshape(poses.shape[0],12) 111 | 112 | class Cambridge2(data.Dataset): 113 | def __init__(self, scene, data_path, train, transform=None, 114 | target_transform=None, mode=0, seed=7, 115 | skip_images=False, df=2., trainskip=1, testskip=1, hwf=[480,854,744.], 116 | ret_idx=False, fix_idx=False, ret_hist=False, hist_bin=10): 117 | """ 118 | :param scene: scene name ['chess', 'pumpkin', ...] 119 | :param data_path: root 7scenes data directory. 120 | Usually '../data/deepslam_data/7Scenes' 121 | :param train: if True, return the training images. 
If False, returns the 122 | testing images 123 | :param transform: transform to apply to the images 124 | :param target_transform: transform to apply to the poses 125 | :param mode: (Obsolete) 0: just color image, 1: color image in NeRF 0-1 and resized 126 | :param skip_images: If True, skip loading images and return None instead 127 | :param df: downscale factor 128 | :param trainskip: due to 7scenes are so big, now can use less training sets # of trainset = 1/trainskip 129 | :param testskip: skip part of testset, # of testset = 1/testskip 130 | :param hwf: H,W,Focal from COLMAP 131 | """ 132 | 133 | self.transform = transform 134 | self.target_transform = target_transform 135 | self.df = df 136 | 137 | self.H, self.W, self.focal = hwf 138 | self.H = int(self.H) 139 | self.W = int(self.W) 140 | np.random.seed(seed) 141 | 142 | self.train = train 143 | self.ret_idx = ret_idx 144 | self.fix_idx = fix_idx 145 | self.ret_hist = ret_hist 146 | self.hist_bin = hist_bin # histogram bin size 147 | 148 | if self.train: 149 | root_dir = osp.join(data_path, scene) + '/train' 150 | else: 151 | root_dir = osp.join(data_path, scene) + '/test' 152 | 153 | rgb_dir = root_dir + '/rgb/' 154 | 155 | pose_dir = root_dir + '/poses/' 156 | 157 | world_setup_fn = osp.join(data_path, scene) + '/world_setup.json' 158 | 159 | # collect poses and image names 160 | self.rgb_files = os.listdir(rgb_dir) 161 | self.rgb_files = [rgb_dir + f for f in self.rgb_files] 162 | self.rgb_files.sort() 163 | 164 | self.pose_files = os.listdir(pose_dir) 165 | self.pose_files = [pose_dir + f for f in self.pose_files] 166 | self.pose_files.sort() 167 | 168 | # remove some abnormal data, need to fix later 169 | if scene == 'ShopFacade' and self.train: 170 | del self.rgb_files[42] 171 | del self.rgb_files[35] 172 | del self.pose_files[42] 173 | del self.pose_files[35] 174 | 175 | if len(self.rgb_files) != len(self.pose_files): 176 | raise Exception('RGB file count does not match pose file count!') 177 | 178 | # read json file 179 | with open(world_setup_fn, 'r') as myfile: 180 | data=myfile.read() 181 | 182 | # parse json file 183 | obj = json.loads(data) 184 | self.near = obj['near'] 185 | self.far = obj['far'] 186 | self.pose_scale = obj['pose_scale'] 187 | self.pose_scale2 = obj['pose_scale2'] 188 | self.move_all_cam_vec = obj['move_all_cam_vec'] 189 | 190 | # trainskip and testskip 191 | frame_idx = np.arange(len(self.rgb_files)) 192 | if train and trainskip > 1: 193 | frame_idx_tmp = frame_idx[::trainskip] 194 | frame_idx = frame_idx_tmp 195 | elif not train and testskip > 1: 196 | frame_idx_tmp = frame_idx[::testskip] 197 | frame_idx = frame_idx_tmp 198 | self.gt_idx = frame_idx 199 | 200 | self.rgb_files = [self.rgb_files[i] for i in frame_idx] 201 | self.pose_files = [self.pose_files[i] for i in frame_idx] 202 | 203 | if len(self.rgb_files) != len(self.pose_files): 204 | raise Exception('RGB file count does not match pose file count!') 205 | 206 | # read poses 207 | poses = [] 208 | for i in range(len(self.pose_files)): 209 | pose = np.loadtxt(self.pose_files[i]) 210 | poses.append(pose) 211 | poses = np.array(poses) # [N, 4, 4] 212 | self.poses = poses[:, :3, :4].reshape(poses.shape[0], 12) 213 | 214 | # debug read one img and get the shape of the img 215 | img = load_image(self.rgb_files[0]) 216 | img_np = (np.array(img) / 255.).astype(np.float32) # (480,854,3) 217 | 218 | self.H, self.W = img_np.shape[:2] 219 | if self.df != 1.: 220 | self.H = int(self.H//self.df) 221 | self.W = int(self.W//self.df) 222 | self.focal = 
self.focal/self.df 223 | 224 | def __len__(self): 225 | return len(self.rgb_files) 226 | 227 | def __getitem__(self, index): 228 | img = load_image(self.rgb_files[index]) 229 | pose = self.poses[index] 230 | if self.df != 1.: 231 | img_np = (np.array(img) / 255.).astype(np.float32) 232 | dims = (self.W, self.H) 233 | img_half_res = cv2.resize(img_np, dims, interpolation=cv2.INTER_AREA) # (H, W, 3) 234 | img = img_half_res 235 | 236 | if self.transform is not None: 237 | img = self.transform(img) 238 | 239 | if self.target_transform is not None: 240 | pose = self.target_transform(pose) 241 | 242 | if self.ret_idx: 243 | if self.train and self.fix_idx==False: 244 | return img, pose, index 245 | else: 246 | return img, pose, 0 247 | if self.ret_hist: 248 | yuv = rgb_to_yuv(img) 249 | y_img = yuv[0] # extract y channel only 250 | hist = torch.histc(y_img, bins=self.hist_bin, min=0., max=1.) # compute intensity histogram 251 | hist = hist/(hist.sum())*100 # convert to histogram density, in terms of percentage per bin 252 | hist = torch.round(hist) 253 | return img, pose, hist 254 | 255 | return img, pose 256 | 257 | def main(): 258 | """ 259 | visualizes the dataset 260 | """ 261 | #from common.vis_utils import show_batch, show_stereo_batch 262 | from torchvision.utils import make_grid 263 | import torchvision.transforms as transforms 264 | seq = 'ShopFacade' 265 | mode = 1 266 | # num_workers = 1 267 | 268 | # transformer 269 | data_transform = transforms.Compose([ 270 | transforms.ToTensor(), 271 | ]) 272 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x)) 273 | kwargs = dict(ret_hist=True) 274 | dset = Cambridge2(seq, '../data/Cambridge/', True, data_transform, target_transform=target_transform, mode=mode, df=7.15, trainskip=2, **kwargs) 275 | print('Loaded Cambridge sequence {:s}, length = {:d}'.format(seq, len(dset))) 276 | 277 | data_loader = data.DataLoader(dset, batch_size=4, shuffle=False) 278 | 279 | batch_count = 0 280 | N = 2 281 | for batch in data_loader: 282 | print('Minibatch {:d}'.format(batch_count)) 283 | pdb.set_trace() 284 | 285 | batch_count += 1 286 | if batch_count >= N: 287 | break 288 | 289 | if __name__ == '__main__': 290 | main() 291 | -------------------------------------------------------------------------------- /dataset_loaders/load_Cambridge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import torch.cuda 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms, models 10 | 11 | from dataset_loaders.cambridge_scenes import Cambridge2, normalize_recenter_pose, load_image 12 | import pdb 13 | 14 | #from dataset_loaders.frustum.frustum_util import initK, generate_sampling_frustum, compute_frustums_overlap 15 | 16 | #focal_length = 555 # This is an approximate https://github.com/NVlabs/geomapnet/issues/8 17 | # Official says (585,585) 18 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 19 | 20 | # translation z axis 21 | trans_t = lambda t : np.array([ 22 | [1,0,0,0], 23 | [0,1,0,0], 24 | [0,0,1,t], 25 | [0,0,0,1]]).astype(float) 26 | 27 | # x rotation 28 | rot_phi = lambda phi : np.array([ 29 | [1,0,0,0], 30 | [0,np.cos(phi),-np.sin(phi),0], 31 | [0,np.sin(phi), np.cos(phi),0], 32 | [0,0,0,1]]).astype(float) 33 | 34 | # y rotation 35 | rot_theta = lambda th : np.array([ 36 | [np.cos(th),0,-np.sin(th),0], 37 | [0,1,0,0], 38 | [np.sin(th),0, np.cos(th),0], 39 | 
[0,0,0,1]]).astype(float) 40 | 41 | # z rotation 42 | rot_psi = lambda psi : np.array([ 43 | [np.cos(psi),-np.sin(psi),0,0], 44 | [np.sin(psi),np.cos(psi),0,0], 45 | [0,0,1,0], 46 | [0,0,0,1]]).astype(float) 47 | 48 | def is_inside_frustum(p, x_res, y_res): 49 | return (0 < p[0]) & (p[0] < x_res) & (0 < p[1]) & (p[1] < y_res) 50 | 51 | def initK(f, cx, cy): 52 | K = np.eye(3, 3) 53 | K[0, 0] = K[1, 1] = f 54 | K[0, 2] = cx 55 | K[1, 2] = cy 56 | return K 57 | 58 | def generate_sampling_frustum(step, depth, K, f, cx, cy, x_res, y_res): 59 | #pdb.set_trace() 60 | x_max = depth * (x_res - cx) / f 61 | x_min = -depth * cx / f 62 | y_max = depth * (y_res - cy) / f 63 | y_min = -depth * cy / f 64 | 65 | zs = np.arange(0, depth, step) 66 | xs = np.arange(x_min, x_max, step) 67 | ys = np.arange(y_min, y_max, step) 68 | 69 | X0 = [] 70 | for z in zs: 71 | for x in xs: 72 | for y in ys: 73 | P = np.array([x, y, z]) 74 | p = np.dot(K, P) 75 | if p[2] < 0.00001: 76 | continue 77 | p = p / p[2] 78 | if is_inside_frustum(p, x_res, y_res): 79 | X0.append(P) 80 | X0 = np.array(X0) 81 | return X0 82 | 83 | def compute_frustums_overlap(pose0, pose1, sampling_frustum, K, x_res, y_res): 84 | R0 = pose0[0:3, 0:3] 85 | t0 = pose0[0:3, 3] 86 | R1 = pose1[0:3, 0:3] 87 | t1 = pose1[0:3, 3] 88 | 89 | R10 = np.dot(R1.T, R0) 90 | t10 = np.dot(R1.T, t0 - t1) 91 | 92 | _P = np.dot(R10, sampling_frustum.T).T + t10 93 | p = np.dot(K, _P.T).T 94 | pn = p[:, 2] 95 | p = np.divide(p, pn[:, None]) 96 | res = np.apply_along_axis(is_inside_frustum, 1, p, x_res, y_res) 97 | return np.sum(res) / float(res.shape[0]) 98 | 99 | def perturb_rotation(c2w, theta, phi, psi=0): 100 | last_row = np.tile(np.array([0, 0, 0, 1]), (1, 1)) # (N_images, 1, 4) 101 | c2w = np.concatenate([c2w, last_row], 0) # (N_images, 4, 4) homogeneous coordinate 102 | c2w = rot_phi(phi/180.*np.pi) @ c2w 103 | c2w = rot_theta(theta/180.*np.pi) @ c2w 104 | c2w = rot_psi(psi/180.*np.pi) @ c2w 105 | c2w = c2w[:3,:4] 106 | return c2w 107 | 108 | def viewmatrix(z, up, pos): 109 | vec2 = normalize(z) 110 | vec1_avg = up 111 | vec0 = normalize(np.cross(vec1_avg, vec2)) 112 | vec1 = normalize(np.cross(vec2, vec0)) 113 | m = np.stack([vec0, vec1, vec2, pos], 1) 114 | return m 115 | 116 | def normalize(v): 117 | """Normalize a vector.""" 118 | return v / np.linalg.norm(v) 119 | 120 | def average_poses(poses): 121 | """ 122 | Calculate the average pose, which is then used to center all poses 123 | using @center_poses. Its computation is as follows: 124 | 1. Compute the center: the average of pose centers. 125 | 2. Compute the z axis: the normalized average z axis. 126 | 3. Compute axis y': the average y axis. 127 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 128 | 5. Compute the y axis: z cross product x. 129 | Note that at step 3, we cannot directly use y' as y axis since it's 130 | not necessarily orthogonal to z axis. We need to pass from x to y. 131 | Inputs: 132 | poses: (N_images, 3, 4) 133 | Outputs: 134 | pose_avg: (3, 4) the average pose 135 | """ 136 | # 1. Compute the center 137 | center = poses[..., 3].mean(0) # (3) 138 | # 2. Compute the z axis 139 | z = normalize(poses[..., 2].mean(0)) # (3) 140 | # 3. Compute axis y' (no need to normalize as it's not the final output) 141 | y_ = poses[..., 1].mean(0) # (3) 142 | # 4. Compute the x axis 143 | x = normalize(np.cross(y_, z)) # (3) 144 | # 5. 
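# representative "focus depth" for the spiral path: a weighted harmonic mean of the scaled near/far bounds (same heuristic as the original NeRF/LLFF code)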
Compute the y axis (as z and x are normalized, y is already of norm 1) 145 | y = np.cross(z, x) # (3) 146 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 147 | return pose_avg 148 | 149 | def center_poses(poses, pose_avg_from_file=None): 150 | """ 151 | Center the poses so that we can use NDC. 152 | See https://github.com/bmild/nerf/issues/34 153 | 154 | Inputs: 155 | poses: (N_images, 3, 4) 156 | pose_avg_from_file: if not None, pose_avg is loaded from pose_avg_stats.txt 157 | 158 | Outputs: 159 | poses_centered: (N_images, 3, 4) the centered poses 160 | pose_avg: (3, 4) the average pose 161 | """ 162 | 163 | 164 | if pose_avg_from_file is None: 165 | pose_avg = average_poses(poses) # (3, 4) # this need to be fixed throughout dataset 166 | else: 167 | pose_avg = pose_avg_from_file 168 | 169 | pose_avg_homo = np.eye(4) 170 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation (4,4) 171 | # by simply adding 0, 0, 0, 1 as the last row 172 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 173 | poses_homo = \ 174 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 175 | 176 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 177 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 178 | 179 | return poses_centered, pose_avg #np.linalg.inv(pose_avg_homo) 180 | 181 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 182 | render_poses = [] 183 | rads = np.array(list(rads) + [1.]) 184 | hwf = c2w[:,4:5] # it's empty here... 185 | 186 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]: 187 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 188 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 189 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 190 | return render_poses 191 | 192 | def average_poses(poses): 193 | """ 194 | Same as in SingleCamVideoStatic.py 195 | Inputs: 196 | poses: (N_images, 3, 4) 197 | Outputs: 198 | pose_avg: (3, 4) the average pose 199 | """ 200 | center = poses[..., 3].mean(0) # (3) 201 | z = normalize(poses[..., 2].mean(0)) # (3) 202 | y_ = poses[..., 1].mean(0) # (3) 203 | x = normalize(np.cross(y_, z)) # (3) 204 | y = np.cross(z, x) # (3) 205 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 206 | return pose_avg 207 | 208 | def generate_render_pose(poses, bds): 209 | idx = np.random.choice(poses.shape[0]) 210 | c2w=poses[idx] 211 | print(c2w[:3,:4]) 212 | 213 | ## Get spiral 214 | # Get average pose 215 | up = normalize(poses[:, :3, 1].sum(0)) 216 | 217 | # Find a reasonable "focus depth" for this dataset 218 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 
219 | dt = .75 220 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 221 | focal = mean_dz 222 | 223 | # Get radii for spiral path 224 | shrink_factor = .8 225 | zdelta = close_depth * .2 226 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 227 | rads = np.percentile(np.abs(tt), 20, 0) # views of 20 degrees 228 | c2w_path = c2w 229 | N_views = 120 # number of views in video 230 | N_rots = 2 231 | 232 | # Generate poses for spiral path 233 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 234 | return render_poses 235 | 236 | def perturb_render_pose(poses, bds, x, angle): 237 | """ 238 | Inputs: 239 | poses: (3, 4) 240 | bds: bounds 241 | x: translational perturb range 242 | angle: rotation angle perturb range in degrees 243 | Outputs: 244 | new_c2w: (N_views, 3, 4) new poses 245 | """ 246 | idx = np.random.choice(poses.shape[0]) 247 | c2w=poses[idx] 248 | 249 | N_views = 10 # number of views in video 250 | new_c2w = np.zeros((N_views, 3, 4)) 251 | 252 | # perturb translational pose 253 | for i in range(N_views): 254 | new_c2w[i] = c2w 255 | new_c2w[i,:,3] = new_c2w[i,:,3] + np.random.uniform(-x,x,3) # perturb pos between -1 to 1 256 | theta=np.random.uniform(-angle,angle,1) # in degrees 257 | phi=np.random.uniform(-angle,angle,1) # in degrees 258 | psi=np.random.uniform(-angle,angle,1) # in degrees 259 | new_c2w[i] = perturb_rotation(new_c2w[i], theta, phi, psi) 260 | return new_c2w, idx 261 | 262 | def remove_overlap_data(train_set, val_set): 263 | ''' Remove some overlap data in val set so that train set and val set do not have overlap ''' 264 | train = train_set.gt_idx 265 | val = val_set.gt_idx 266 | 267 | # find redundant data index in val_set 268 | index = np.where(np.in1d(val, train) == True) # this is a tuple 269 | # delete redundant data 270 | val_set.gt_idx = np.delete(val_set.gt_idx, index) 271 | val_set.poses = np.delete(val_set.poses, index, axis=0) 272 | for i in sorted(index[0], reverse=True): 273 | val_set.c_imgs.pop(i) 274 | val_set.d_imgs.pop(i) 275 | return train_set, val_set 276 | 277 | def fix_coord(args, train_set, val_set, pose_avg_stats_file='', rescale_coord=True): 278 | ''' fix coord for 7 Scenes to align with llff style dataset ''' 279 | 280 | # This is only to store a pre-calculated pose average stats of the dataset 281 | if args.save_pose_avg_stats: 282 | pdb.set_trace() 283 | if pose_avg_stats_file == '': 284 | print('pose_avg_stats_file location unspecified, please double check...') 285 | sys.exit() 286 | 287 | all_poses = train_set.poses 288 | all_poses = all_poses.reshape(all_poses.shape[0], 3, 4) 289 | all_poses, pose_avg = center_poses(all_poses) 290 | 291 | # save pose_avg to pose_avg_stats.txt 292 | np.savetxt(pose_avg_stats_file, pose_avg) 293 | print('pose_avg_stats.txt successfully saved') 294 | sys.exit() 295 | 296 | # get all poses (train+val) 297 | train_poses = train_set.poses 298 | 299 | val_poses = val_set.poses 300 | all_poses = np.concatenate([train_poses, val_poses]) 301 | 302 | # Center the poses for ndc 303 | all_poses = all_poses.reshape(all_poses.shape[0], 3, 4) 304 | 305 | # Here we use either pre-stored pose average stats or calculate pose average stats on the flight to center the poses 306 | if args.load_pose_avg_stats: 307 | pose_avg_from_file = np.loadtxt(pose_avg_stats_file) 308 | all_poses, pose_avg = center_poses(all_poses, pose_avg_from_file) 309 | else: 310 | all_poses, pose_avg = center_poses(all_poses) 311 | 312 | # Correct axis to LLFF Style y,z -> -y,-z 313 | 
last_row = np.tile(np.array([0, 0, 0, 1]), (len(all_poses), 1, 1)) # (N_images, 1, 4) 314 | all_poses = np.concatenate([all_poses, last_row], 1) 315 | 316 | # rotate tpose 90 degrees at x axis # only corrected translation position 317 | all_poses = rot_phi(180/180.*np.pi) @ all_poses 318 | 319 | # correct view direction except mirror with gt view 320 | all_poses[:,:3,:3] = -all_poses[:,:3,:3] 321 | 322 | # camera direction mirror at x axis mod1 R' = R @ mirror matrix 323 | # ref: https://gamedev.stackexchange.com/questions/149062/how-to-mirror-reflect-flip-a-4d-transformation-matrix 324 | all_poses[:,:3,:3] = all_poses[:,:3,:3] @ np.array([[-1,0,0],[0,1,0],[0,0,1]]) 325 | 326 | all_poses = all_poses[:,:3,:4] 327 | 328 | bounds = np.array([train_set.near, train_set.far]) # manual tuned 329 | 330 | if rescale_coord: 331 | sc=train_set.pose_scale # manual tuned factor, align with colmap scale 332 | all_poses[:,:3,3] *= sc 333 | 334 | ### quite ugly ### 335 | # move center of camera pose 336 | if train_set.move_all_cam_vec != [0.,0.,0.]: 337 | all_poses[:, :3, 3] += train_set.move_all_cam_vec 338 | 339 | if train_set.pose_scale2 != 1.0: 340 | all_poses[:,:3,3] *= train_set.pose_scale2 341 | # end of new mod1 342 | 343 | # Return all poses to dataset loaders 344 | all_poses = all_poses.reshape(all_poses.shape[0], 12) 345 | train_set.poses = all_poses[:train_poses.shape[0]] 346 | val_set.poses = all_poses[train_poses.shape[0]:] 347 | return train_set, val_set, bounds 348 | 349 | def load_Cambridge_dataloader(args): 350 | ''' Data loader for Pose Regression Network ''' 351 | if args.pose_only: # if train posenet is true 352 | pass 353 | else: 354 | raise Exception('load_Cambridge_dataloader() currently only support PoseNet Training, not NeRF training') 355 | data_dir, scene = osp.split(args.datadir) # ../data/7Scenes, chess 356 | dataset_folder, dataset = osp.split(data_dir) # ../data, 7Scenes 357 | 358 | # transformer 359 | data_transform = transforms.Compose([ 360 | transforms.ToTensor(), 361 | ]) 362 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x)) 363 | 364 | ret_idx = False # return frame index 365 | fix_idx = False # return frame index=0 in training 366 | ret_hist = False 367 | 368 | if 'NeRFH' in args: 369 | if args.NeRFH == True: 370 | ret_idx = True 371 | if args.fix_index: 372 | fix_idx = True 373 | 374 | # encode hist experiment 375 | if args.encode_hist: 376 | ret_idx = False 377 | fix_idx = False 378 | ret_hist = True 379 | 380 | kwargs = dict(scene=scene, data_path=data_dir, 381 | transform=data_transform, target_transform=target_transform, 382 | df=args.df, ret_idx=ret_idx, fix_idx=fix_idx, 383 | ret_hist=ret_hist, hist_bin=args.hist_bin) 384 | 385 | if args.finetune_unlabel: # direct-pn + unlabel 386 | train_set = Cambridge2(train=False, testskip=args.trainskip, **kwargs) 387 | val_set = Cambridge2(train=False, testskip=args.testskip, **kwargs) 388 | 389 | # if not args.eval: 390 | # # remove overlap data in val_set that was already in train_set, 391 | # train_set, val_set = remove_overlap_data(train_set, val_set) 392 | else: 393 | train_set = Cambridge2(train=True, trainskip=args.trainskip, **kwargs) 394 | val_set = Cambridge2(train=False, testskip=args.testskip, **kwargs) 395 | L = len(train_set) 396 | 397 | i_train = train_set.gt_idx 398 | i_val = val_set.gt_idx 399 | i_test = val_set.gt_idx 400 | # use a pose average stats computed earlier to unify posenet and nerf training 401 | if args.save_pose_avg_stats or args.load_pose_avg_stats: 402 | 
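# per-scene pose average stats are stored alongside the scene data, e.g. data/Cambridge/ShopFacade/pose_avg_stats.txt or data/7Scenes/chess/pose_avg_stats.txt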
pose_avg_stats_file = osp.join(args.datadir, 'pose_avg_stats.txt') 403 | train_set, val_set, bounds = fix_coord(args, train_set, val_set, pose_avg_stats_file, rescale_coord=False) # only adjust coord. systems, rescale are done at training 404 | else: 405 | train_set, val_set, bounds = fix_coord(args, train_set, val_set, rescale_coord=False) 406 | 407 | train_shuffle=True 408 | if args.eval: 409 | train_shuffle=False 410 | 411 | train_dl = DataLoader(train_set, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=8) #num_workers=4 pin_memory=True 412 | val_dl = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, num_workers=2) 413 | test_dl = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=2) 414 | 415 | hwf = [train_set.H, train_set.W, train_set.focal] 416 | i_split = [i_train, i_val, i_test] 417 | 418 | return train_dl, val_dl, test_dl, hwf, i_split, bounds.min(), bounds.max() 419 | 420 | def load_Cambridge_dataloader_NeRF(args): 421 | ''' Data loader for NeRF ''' 422 | 423 | data_dir, scene = osp.split(args.datadir) # ../data/7Scenes, chess 424 | dataset_folder, dataset = osp.split(data_dir) # ../data, 7Scenes 425 | 426 | data_transform = transforms.Compose([ 427 | transforms.ToTensor()]) 428 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x)) 429 | 430 | ret_idx = False # return frame index 431 | fix_idx = False # return frame index=0 in training 432 | ret_hist = False 433 | 434 | if 'NeRFH' in args: 435 | ret_idx = True 436 | if args.fix_index: 437 | fix_idx = True 438 | 439 | # encode hist experiment 440 | if args.encode_hist: 441 | ret_idx = False 442 | fix_idx = False 443 | ret_hist = True 444 | 445 | kwargs = dict(scene=scene, data_path=data_dir, 446 | transform=data_transform, target_transform=target_transform, 447 | df=args.df, ret_idx=ret_idx, fix_idx=fix_idx, ret_hist=ret_hist, hist_bin=args.hist_bin) 448 | 449 | train_set = Cambridge2(train=True, trainskip=args.trainskip, **kwargs) 450 | val_set = Cambridge2(train=False, testskip=args.testskip, **kwargs) 451 | 452 | i_train = train_set.gt_idx 453 | i_val = val_set.gt_idx 454 | i_test = val_set.gt_idx 455 | 456 | # use a pose average stats computed earlier to unify posenet and nerf training 457 | if args.save_pose_avg_stats or args.load_pose_avg_stats: 458 | pose_avg_stats_file = osp.join(args.datadir, 'pose_avg_stats.txt') 459 | train_set, val_set, bounds = fix_coord(args, train_set, val_set, pose_avg_stats_file) 460 | else: 461 | train_set, val_set, bounds = fix_coord(args, train_set, val_set) 462 | 463 | render_poses = None 464 | render_img = None 465 | 466 | train_shuffle=True 467 | if args.render_video_train or args.render_test or args.dataset_type == 'Cambridge2': 468 | train_shuffle=False 469 | train_dl = DataLoader(train_set, batch_size=1, shuffle=train_shuffle) # default 470 | # train_dl = DataLoader(train_set, batch_size=1, shuffle=False) # for debug only 471 | val_dl = DataLoader(val_set, batch_size=1, shuffle=False) 472 | 473 | hwf = [train_set.H, train_set.W, train_set.focal] 474 | 475 | i_split = [i_train, i_val, i_test] 476 | 477 | return train_dl, val_dl, hwf, i_split, bounds, render_poses, render_img -------------------------------------------------------------------------------- /dataset_loaders/seven_scenes.py: -------------------------------------------------------------------------------- 1 | """ 2 | pytorch data loader for the 7-scenes dataset 3 | """ 4 | import os 5 | import os.path as osp 6 | import numpy as np 7 | from PIL import Image 8 | import torch 
9 | from torch.utils import data 10 | import sys 11 | import pickle 12 | import pdb,copy 13 | import cv2 14 | 15 | sys.path.insert(0, '../') 16 | import transforms3d.quaternions as txq 17 | # see for formulas: 18 | # https://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-801-machine-vision-fall-2004/readings/quaternions.pdf 19 | # and "Quaternion and Rotation" - Yan-Bin Jia, September 18, 2016 20 | from dataset_loaders.utils.color import rgb_to_yuv 21 | import json 22 | 23 | def RT2QT(poses_in, mean_t, std_t): 24 | """ 25 | processes the 1x12 raw pose from dataset by aligning and then normalizing 26 | :param poses_in: N x 12 27 | :param mean_t: 3 28 | :param std_t: 3 29 | :return: processed poses (translation + quaternion) N x 7 30 | """ 31 | poses_out = np.zeros((len(poses_in), 7)) 32 | poses_out[:, 0:3] = poses_in[:, [3, 7, 11]] 33 | 34 | # align 35 | for i in range(len(poses_out)): 36 | R = poses_in[i].reshape((3, 4))[:3, :3] 37 | q = txq.mat2quat(R) 38 | q = q/(np.linalg.norm(q) + 1e-12) # normalize 39 | q *= np.sign(q[0]) # constrain to hemisphere 40 | poses_out[i, 3:] = q 41 | 42 | # normalize translation 43 | poses_out[:, :3] -= mean_t 44 | poses_out[:, :3] /= std_t 45 | return poses_out 46 | 47 | def qlog(q): 48 | """ 49 | Applies logarithm map to q 50 | :param q: (4,) 51 | :return: (3,) 52 | """ 53 | if all(q[1:] == 0): 54 | q = np.zeros(3) 55 | else: 56 | q = np.arccos(q[0]) * q[1:] / np.linalg.norm(q[1:]) 57 | return q 58 | 59 | import transforms3d.quaternions as txq # Warning: outdated package 60 | 61 | def process_poses_rotmat(poses_in, mean_t, std_t, align_R, align_t, align_s): 62 | """ 63 | processes the 1x12 raw pose from dataset by aligning and then normalizing 64 | produce logq 65 | :param poses_in: N x 12 66 | :return: processed poses N x 12 67 | """ 68 | return poses_in 69 | 70 | def process_poses_q(poses_in, mean_t, std_t, align_R, align_t, align_s): 71 | """ 72 | processes the 1x12 raw pose from dataset by aligning and then normalizing 73 | produce logq 74 | :param poses_in: N x 12 75 | :param mean_t: 3 76 | :param std_t: 3 77 | :param align_R: 3 x 3 78 | :param align_t: 3 79 | :param align_s: 1 80 | :return: processed poses (translation + log quaternion) N x 6 81 | """ 82 | poses_out = np.zeros((len(poses_in), 6)) # (1000,6) 83 | poses_out[:, 0:3] = poses_in[:, [3, 7, 11]] # x,y,z position 84 | # align 85 | for i in range(len(poses_out)): 86 | R = poses_in[i].reshape((3, 4))[:3, :3] # rotation 87 | q = txq.mat2quat(np.dot(align_R, R)) 88 | q *= np.sign(q[0]) # constrain to hemisphere, first number, +1/-1, q.shape (1,4) 89 | poses_out[i, 3:] = q # logq rotation 90 | t = poses_out[i, :3] - align_t 91 | poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze() 92 | 93 | # normalize translation 94 | poses_out[:, :3] -= mean_t #(1000, 6) 95 | poses_out[:, :3] /= std_t 96 | return poses_out 97 | 98 | def process_poses_logq(poses_in, mean_t, std_t, align_R, align_t, align_s): 99 | """ 100 | processes the 1x12 raw pose from dataset by aligning and then normalizing 101 | produce logq 102 | :param poses_in: N x 12 103 | :param mean_t: 3 104 | :param std_t: 3 105 | :param align_R: 3 x 3 106 | :param align_t: 3 107 | :param align_s: 1 108 | :return: processed poses (translation + log quaternion) N x 6 109 | """ 110 | poses_out = np.zeros((len(poses_in), 6)) # (1000,6) 111 | poses_out[:, 0:3] = poses_in[:, [3, 7, 11]] # x,y,z position 112 | # align 113 | for i in range(len(poses_out)): 114 | R = poses_in[i].reshape((3, 4))[:3, :3] # rotation 115 
| q = txq.mat2quat(np.dot(align_R, R)) 116 | q *= np.sign(q[0]) # constrain to hemisphere, first number, +1/-1, q.shape (1,4) 117 | q = qlog(q) # (1,3) 118 | poses_out[i, 3:] = q # logq rotation 119 | t = poses_out[i, :3] - align_t 120 | poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze() 121 | 122 | # normalize translation 123 | poses_out[:, :3] -= mean_t #(1000, 6) 124 | poses_out[:, :3] /= std_t 125 | return poses_out 126 | 127 | from torchvision.datasets.folder import default_loader 128 | def load_image(filename, loader=default_loader): 129 | try: 130 | img = loader(filename) 131 | except IOError as e: 132 | print('Could not load image {:s}, IOError: {:s}'.format(filename, e)) 133 | return None 134 | except: 135 | print('Could not load image {:s}, unexpected error'.format(filename)) 136 | return None 137 | return img 138 | 139 | def load_depth_image(filename): 140 | try: 141 | img_depth = Image.fromarray(np.array(Image.open(filename)).astype("uint16")) 142 | except IOError as e: 143 | print('Could not load image {:s}, IOError: {:s}'.format(filename, e)) 144 | return None 145 | return img_depth 146 | 147 | def normalize(x): 148 | return x / np.linalg.norm(x) 149 | 150 | def viewmatrix(z, up, pos): 151 | vec2 = normalize(z) 152 | vec1_avg = up 153 | vec0 = normalize(np.cross(vec1_avg, vec2)) 154 | vec1 = normalize(np.cross(vec2, vec0)) 155 | m = np.stack([vec0, vec1, vec2, pos], 1) 156 | return m 157 | 158 | def normalize_recenter_pose(poses, sc, hwf): 159 | ''' normalize xyz into [-1, 1], and recenter pose ''' 160 | target_pose = poses.reshape(poses.shape[0],3,4) 161 | target_pose[:,:3,3] = target_pose[:,:3,3] * sc 162 | 163 | x_norm = target_pose[:,0,3] 164 | y_norm = target_pose[:,1,3] 165 | z_norm = target_pose[:,2,3] 166 | 167 | tpose_ = target_pose+0 168 | 169 | # find the center of pose 170 | center = np.array([x_norm.mean(), y_norm.mean(), z_norm.mean()]) 171 | bottom = np.reshape([0,0,0,1.], [1,4]) 172 | 173 | # pose avg 174 | vec2 = normalize(tpose_[:, :3, 2].sum(0)) 175 | up = tpose_[:, :3, 1].sum(0) 176 | hwf=np.array(hwf).transpose() 177 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 178 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 179 | 180 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [tpose_.shape[0],1,1]) 181 | poses = np.concatenate([tpose_[:,:3,:4], bottom], -2) 182 | poses = np.linalg.inv(c2w) @ poses 183 | return poses[:,:3,:].reshape(poses.shape[0],12) 184 | 185 | class SevenScenes(data.Dataset): 186 | def __init__(self, scene, data_path, train, transform=None, 187 | target_transform=None, mode=0, seed=7, 188 | df=1., trainskip=1, testskip=1, hwf=[480,640,585.], 189 | ret_idx=False, fix_idx=False, ret_hist=False, hist_bin=10): 190 | """ 191 | :param scene: scene name ['chess', 'pumpkin', ...] 192 | :param data_path: root 7scenes data directory. 193 | Usually '../data/deepslam_data/7Scenes' 194 | :param train: if True, return the training images. If False, returns the 195 | testing images 196 | :param transform: transform to apply to the images 197 | :param target_transform: transform to apply to the poses 198 | :param mode: (Obsolete) 0: just color image, 1: color image in NeRF 0-1 and resized. 
199 | :param df: downscale factor 200 | :param trainskip: 7-Scenes sequences are large, so the training set can be subsampled; # of trainset = 1/trainskip 201 | :param testskip: skip part of testset, # of testset = 1/testskip 202 | :param hwf: H,W,Focal from COLMAP 203 | :param ret_idx: bool, currently only used by NeRF-W 204 | """ 205 | 206 | self.transform = transform 207 | self.target_transform = target_transform 208 | self.df = df 209 | 210 | self.H, self.W, self.focal = hwf 211 | self.H = int(self.H) 212 | self.W = int(self.W) 213 | np.random.seed(seed) 214 | 215 | self.train = train 216 | self.ret_idx = ret_idx 217 | self.fix_idx = fix_idx 218 | self.ret_hist = ret_hist 219 | self.hist_bin = hist_bin # histogram bin size 220 | 221 | # directories 222 | base_dir = osp.join(osp.expanduser(data_path), scene) # '../data/deepslam_data/7Scenes' 223 | data_dir = osp.join('..', 'data', '7Scenes', scene) # '../data/7Scenes/chess' 224 | world_setup_fn = data_dir + '/world_setup.json' 225 | 226 | # read json file 227 | with open(world_setup_fn, 'r') as myfile: 228 | data=myfile.read() 229 | 230 | # parse json file 231 | obj = json.loads(data) 232 | self.near = obj['near'] 233 | self.far = obj['far'] 234 | self.pose_scale = obj['pose_scale'] 235 | self.pose_scale2 = obj['pose_scale2'] 236 | self.move_all_cam_vec = obj['move_all_cam_vec'] 237 | 238 | # decide which sequences to use 239 | if train: 240 | split_file = osp.join(base_dir, 'TrainSplit.txt') 241 | else: 242 | split_file = osp.join(base_dir, 'TestSplit.txt') 243 | with open(split_file, 'r') as f: 244 | seqs = [int(l.split('sequence')[-1]) for l in f if not l.startswith('#')] # parsing 245 | 246 | # read poses and collect image names 247 | self.c_imgs = [] 248 | self.d_imgs = [] 249 | self.gt_idx = np.empty((0,), dtype=int) # note: np.int is deprecated, use the builtin int 250 | ps = {} 251 | vo_stats = {} 252 | gt_offset = int(0) 253 | for seq in seqs: 254 | seq_dir = osp.join(base_dir, 'seq-{:02d}'.format(seq)) 255 | seq_data_dir = osp.join(data_dir, 'seq-{:02d}'.format(seq)) 256 | 257 | p_filenames = [n for n in os.listdir(osp.join(seq_dir, '.')) if n.find('pose') >= 0] 258 | idxes = [int(n[6:12]) for n in p_filenames] 259 | 260 | frame_idx = np.array(sorted(idxes)) 261 | 262 | 263 | # trainskip and testskip 264 | if train and trainskip > 1: 265 | frame_idx_tmp = frame_idx[::trainskip] 266 | frame_idx = frame_idx_tmp 267 | elif not train and testskip > 1: 268 | frame_idx_tmp = frame_idx[::testskip] 269 | frame_idx = frame_idx_tmp 270 | 271 | pss = [np.loadtxt(osp.join(seq_dir, 'frame-{:06d}.pose.txt'. 272 | format(i))).flatten()[:12] for i in frame_idx] # all the 3x4 pose matrices 273 | ps[seq] = np.asarray(pss) # list of all poses in file No.
seq 274 | vo_stats[seq] = {'R': np.eye(3), 't': np.zeros(3), 's': 1} 275 | 276 | self.gt_idx = np.hstack((self.gt_idx, gt_offset+frame_idx)) 277 | gt_offset += len(p_filenames) 278 | c_imgs = [osp.join(seq_dir, 'frame-{:06d}.color.png'.format(i)) for i in frame_idx] 279 | d_imgs = [osp.join(seq_dir, 'frame-{:06d}.depth.png'.format(i)) for i in frame_idx] 280 | self.c_imgs.extend(c_imgs) 281 | self.d_imgs.extend(d_imgs) 282 | 283 | pose_stats_filename = osp.join(data_dir, 'pose_stats.txt') 284 | if train: 285 | mean_t = np.zeros(3) # optionally, use the ps dictionary to calc stats 286 | std_t = np.ones(3) 287 | np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f') 288 | else: 289 | mean_t, std_t = np.loadtxt(pose_stats_filename) 290 | 291 | # convert pose to translation + log quaternion 292 | logq = False 293 | quat = False 294 | if logq: # (batch_num, 6) 295 | self.poses = np.empty((0, 6)) 296 | elif quat: # (batch_num, 7) 297 | self.poses = np.empty((0, 7)) 298 | else: # (batch_num, 12) 299 | self.poses = np.empty((0, 12)) 300 | 301 | for seq in seqs: 302 | if logq: 303 | pss = process_poses_logq(poses_in=ps[seq], mean_t=mean_t, std_t=std_t, align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'], align_s=vo_stats[seq]['s']) # here returns t + logQed R 304 | self.poses = np.vstack((self.poses, pss)) 305 | elif quat: 306 | pss = RT2QT(poses_in=ps[seq], mean_t=mean_t, std_t=std_t) # here returns t + quaternion R 307 | self.poses = np.vstack((self.poses, pss)) 308 | else: 309 | pss = process_poses_rotmat(poses_in=ps[seq], mean_t=mean_t, std_t=std_t, align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'], align_s=vo_stats[seq]['s']) 310 | self.poses = np.vstack((self.poses, pss)) 311 | 312 | # debug read one img and get the shape of the img 313 | img = load_image(self.c_imgs[0]) 314 | img_np = (np.array(img) / 255.).astype(np.float32) # (480,640,3) 315 | self.H, self.W = img_np.shape[:2] 316 | if self.df != 1.: 317 | self.H = int(self.H//self.df) 318 | self.W = int(self.W//self.df) 319 | self.focal = self.focal/self.df 320 | 321 | def __len__(self): 322 | return self.poses.shape[0] 323 | 324 | def __getitem__(self, index): 325 | # print("index:", index) 326 | img = load_image(self.c_imgs[index]) # chess img.size = (640,480) 327 | pose = self.poses[index] 328 | if self.df != 1.: 329 | img_np = (np.array(img) / 255.).astype(np.float32) 330 | dims = (self.W, self.H) 331 | img_half_res = cv2.resize(img_np, dims, interpolation=cv2.INTER_AREA) # (H, W, 3) 332 | img = img_half_res 333 | 334 | if self.target_transform is not None: 335 | pose = self.target_transform(pose) 336 | 337 | if self.transform is not None: 338 | img = self.transform(img) 339 | 340 | if self.ret_idx: 341 | if self.train and self.fix_idx==False: 342 | return img, pose, index 343 | else: 344 | return img, pose, 0 345 | 346 | if self.ret_hist: 347 | yuv = rgb_to_yuv(img) 348 | y_img = yuv[0] # extract y channel only 349 | hist = torch.histc(y_img, bins=self.hist_bin, min=0., max=1.) 
# compute intensity histogram 350 | hist = hist/(hist.sum())*100 # convert to histogram density, in terms of percentage per bin 351 | hist = torch.round(hist) 352 | return img, pose, hist 353 | 354 | return img, pose 355 | 356 | 357 | def main(): 358 | """ 359 | visualizes the dataset 360 | """ 361 | # from common.vis_utils import show_batch, show_stereo_batch 362 | from torchvision.utils import make_grid 363 | import torchvision.transforms as transforms 364 | seq = 'heads' 365 | mode = 1 366 | num_workers = 6 367 | transform = transforms.Compose([ 368 | transforms.Resize(256), # transforms.Scale was deprecated and removed from torchvision; Resize is the replacement 369 | transforms.CenterCrop(224), 370 | transforms.ToTensor(), 371 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 372 | ]) 373 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x)) 374 | dset = SevenScenes(seq, '../data/deepslam_data/7Scenes', True, transform, target_transform=target_transform, mode=mode) 375 | print('Loaded 7Scenes sequence {:s}, length = {:d}'.format(seq, len(dset))) 376 | pdb.set_trace() 377 | 378 | data_loader = data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=num_workers) 379 | 380 | batch_count = 0 381 | N = 2 382 | for batch in data_loader: 383 | print('Minibatch {:d}'.format(batch_count)) 384 | pdb.set_trace() 385 | # if mode < 2: 386 | # show_batch(make_grid(batch[0], nrow=1, padding=25, normalize=True)) 387 | # elif mode == 2: 388 | # lb = make_grid(batch[0][0], nrow=1, padding=25, normalize=True) 389 | # rb = make_grid(batch[0][1], nrow=1, padding=25, normalize=True) 390 | # show_stereo_batch(lb, rb) 391 | 392 | batch_count += 1 393 | if batch_count >= N: 394 | break 395 | 396 | if __name__ == '__main__': 397 | main() 398 | -------------------------------------------------------------------------------- /dataset_loaders/utils/color.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def rgb_to_yuv(image: torch.Tensor) -> torch.Tensor: 5 | r""" 6 | From Kornia. 7 | Convert an RGB image to YUV. 8 | 9 | .. image:: _static/img/rgb_to_yuv.png 10 | 11 | The image data is assumed to be in the range of (0, 1). 12 | 13 | Args: 14 | image: RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`. 15 | 16 | Returns: 17 | YUV version of the image with shape :math:`(*, 3, H, W)`. 18 | 19 | Example: 20 | >>> input = torch.rand(2, 3, 4, 5) 21 | >>> output = rgb_to_yuv(input) # 2x3x4x5 22 | """ 23 | if not isinstance(image, torch.Tensor): 24 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") 25 | 26 | if len(image.shape) < 3 or image.shape[-3] != 3: 27 | raise ValueError(f"Input size must have a shape of (*, 3, H, W).
Got {image.shape}") 28 | 29 | r: torch.Tensor = image[..., 0, :, :] 30 | g: torch.Tensor = image[..., 1, :, :] 31 | b: torch.Tensor = image[..., 2, :, :] 32 | 33 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b 34 | u: torch.Tensor = -0.147 * r - 0.289 * g + 0.436 * b 35 | v: torch.Tensor = 0.615 * r - 0.515 * g - 0.100 * b 36 | 37 | out: torch.Tensor = torch.stack([y, u, v], -3) 38 | 39 | return out 40 | -------------------------------------------------------------------------------- /imgs/DFNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/imgs/DFNet.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Package Version 2 | ----------------------- ------------------- 3 | absl-py 0.14.0 4 | antlr4-python3-runtime 4.8 5 | backcall 0.2.0 6 | beautifulsoup4 4.9.3 7 | brotlipy 0.7.0 8 | cached-property 1.5.2 9 | cachetools 4.2.4 10 | certifi 2021.5.30 11 | cffi 1.14.5 12 | chardet 3.0.4 13 | conda 4.10.1 14 | conda-build 3.21.4 15 | conda-package-handling 1.7.3 16 | ConfigArgParse 1.5.2 17 | cryptography 3.4.7 18 | cycler 0.10.0 19 | decorator 5.0.9 20 | dnspython 2.1.0 21 | efficientnet-pytorch 0.7.1 22 | einops 0.3.2 23 | filelock 3.0.12 24 | fvcore 0.1.5.post20210924 25 | glob2 0.7 26 | google-auth 1.35.0 27 | google-auth-oauthlib 0.4.6 28 | grpcio 1.41.0 29 | h5py 3.5.0 30 | idna 2.10 31 | imageio 2.9.0 32 | imageio-ffmpeg 0.4.5 33 | importlib-metadata 4.8.1 34 | iopath 0.1.9 35 | ipython 7.22.0 36 | ipython-genutils 0.2.0 37 | jedi 0.17.0 38 | Jinja2 3.0.0 39 | kiwisolver 1.3.2 40 | kornia 0.6.2 41 | libarchive-c 2.9 42 | Markdown 3.3.4 43 | MarkupSafe 2.0.1 44 | matplotlib 3.3.2 45 | mkl-fft 1.3.0 46 | mkl-random 1.2.1 47 | mkl-service 2.3.0 48 | numpy 1.20.2 49 | oauthlib 3.1.1 50 | olefile 0.46 51 | omegaconf 2.1.1 52 | opencv-python 4.4.0.46 53 | packaging 21.3 54 | parso 0.8.2 55 | pexpect 4.8.0 56 | pickleshare 0.7.5 57 | Pillow 8.2.0 58 | pip 21.1.2 59 | pkginfo 1.7.0 60 | portalocker 2.3.2 61 | prompt-toolkit 3.0.17 62 | protobuf 3.18.0 63 | psutil 5.8.0 64 | ptyprocess 0.7.0 65 | pyasn1 0.4.8 66 | pyasn1-modules 0.2.8 67 | pycosat 0.6.3 68 | pycparser 2.20 69 | Pygments 2.9.0 70 | pykalman 0.9.5 71 | pyOpenSSL 19.1.0 72 | pyparsing 2.4.7 73 | PySocks 1.7.1 74 | python-dateutil 2.8.2 75 | python-etcd 0.4.5 76 | pytorch3d 0.3.0 77 | pytz 2021.1 78 | PyYAML 5.4.1 79 | requests 2.24.0 80 | requests-oauthlib 1.3.0 81 | rsa 4.7.2 82 | ruamel-yaml-conda 0.15.100 83 | scipy 1.7.3 84 | setuptools 52.0.0.post20210125 85 | six 1.16.0 86 | soupsieve 2.2.1 87 | tabulate 0.8.9 88 | tensorboard 2.6.0 89 | tensorboard-data-server 0.6.1 90 | tensorboard-plugin-wit 1.8.0 91 | termcolor 1.1.0 92 | torch 1.11.0+cu113 93 | torchelastic 0.2.0 94 | torchsummary 1.5.1 95 | torchtext 0.10.0 96 | torchvision 0.10.0 97 | tqdm 4.51.0 98 | traitlets 5.0.5 99 | transforms3d 0.3.1 100 | trimesh 3.9.32 101 | typing-extensions 3.7.4.3 102 | urllib3 1.25.11 103 | wcwidth 0.2.5 104 | Werkzeug 2.0.1 105 | wheel 0.35.1 106 | yacs 0.1.8 107 | zipp 3.6.0 108 | -------------------------------------------------------------------------------- /script/config_dfnet.txt: -------------------------------------------------------------------------------- 1 | ############################################### NeRF-Hist training example Cambridge 
############################################### 2 | model_name=dfnet 3 | basedir=../logs/kings 4 | expname=nerfh 5 | datadir=../data/Cambridge/KingsCollege 6 | dataset_type=Cambridge 7 | trainskip=2 # train 8 | testskip=1 # train 9 | df=2 10 | load_pose_avg_stats=True 11 | NeRFH=True 12 | epochs=2000 13 | encode_hist=True 14 | tinyimg=True 15 | DFNet=True 16 | tripletloss=True 17 | featurenet_batch_size=4 # batch size, 4 or 8 18 | random_view_synthesis=True 19 | rvs_refresh_rate=20 20 | rvs_trans=3 21 | rvs_rotation=7.5 22 | d_max=1 23 | # pretrain_model_path = ../logs/kings/dfnet/checkpoint-0604-0.2688.pt # add your trained model for eval 24 | # eval=True # add this for eval 25 | 26 | ############################################### NeRF-Hist training example 7-Scenes ############################################### 27 | # model_name=dfnet 28 | # basedir=../logs/heads 29 | # expname=nerfh 30 | # datadir=../data/7Scenes/heads 31 | # dataset_type=7Scenes 32 | # trainskip=5 # train 33 | # testskip=1 #train 34 | # df=2 # train 35 | # load_pose_avg_stats=True 36 | # NeRFH=True 37 | # epochs=2000 38 | # encode_hist=True 39 | # batch_size=1 # NeRF loader batch size 40 | # tinyimg=True 41 | # DFNet=True 42 | # tripletloss=True 43 | # featurenet_batch_size=4 44 | # val_batch_size=8 # new 45 | # random_view_synthesis=True 46 | # rvs_refresh_rate=20 47 | # rvs_trans=0.2 48 | # rvs_rotation=10 49 | # d_max=0.2 50 | # # pretrain_model_path = ../logs/heads/dfnet/checkpoint-0888-0.0025.pt # add your trained model for eval 51 | # # eval=True # add this for eval -------------------------------------------------------------------------------- /script/config_dfnetdm.txt: -------------------------------------------------------------------------------- 1 | ############################################### NeRF-Hist training example Cambridge ############################################### 2 | model_name=dfnetdm 3 | expname=nerfh 4 | basedir=../logs/kings # change this if change scenes 5 | datadir=../data/Cambridge/KingsCollege # change this if change scenes 6 | dataset_type=Cambridge 7 | pretrain_model_path=../logs/kings/dfnet/checkpoint-0604-0.2688.pt # this is your trained dfnet model for pose regression 8 | pretrain_featurenet_path=../logs/kings/dfnet/checkpoint-0604-0.2688.pt # this is your trained dfnet model for feature extraction 9 | trainskip=2 # train 10 | testskip=1 # train 11 | df=2 12 | load_pose_avg_stats=True 13 | NeRFH=True 14 | encode_hist=True 15 | freezeBN=True 16 | featuremetric=True 17 | pose_only=3 18 | svd_reg=True 19 | combine_loss = True 20 | combine_loss_w = [0., 0., 1.] 
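# pose_only=3 selects the featurenet training mode, and combine_loss_w lists the per-term weights of the combined loss (see the --pose_only and --combine_loss_w help strings in script/dm/options.py); with [0., 0., 1.] only the last term contributes.
# Illustration only (the exact ordering of the three terms is defined in the training code, not in this config): a weighted combination of this form is evaluated as total_loss = w[0]*term_0 + w[1]*term_1 + w[2]*term_2, so a zero weight switches the corresponding term off.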
21 | finetune_unlabel=True 22 | i_eval=20 23 | DFNet=True 24 | val_on_psnr=True 25 | feature_matching_lvl = [0] 26 | # eval=True # add this for eval 27 | # pretrain_model_path=../logs/kings/dfnetdm/checkpoint-0267-17.1446.pt # add the trained model for eval 28 | 29 | 30 | ############################################### NeRF-Hist training example 7-Scenes ############################################### 31 | # model_name=dfnetdm 32 | # expname=nerfh 33 | # basedir=../logs/heads 34 | # datadir=../data/7Scenes/heads 35 | # dataset_type=7Scenes 36 | # pretrain_model_path=../logs/heads/dfnet/checkpoint-0888-0.0025.pt # this is your trained dfnet model for pose regression 37 | # pretrain_featurenet_path=../logs/heads/dfnet/checkpoint-0888-0.0025.pt # this is your trained dfnet model for feature extraction 38 | # trainskip=5 # train 39 | # testskip=1 #train 40 | # df=2 41 | # load_pose_avg_stats=True 42 | # NeRFH=True 43 | # encode_hist=True 44 | # freezeBN=True 45 | # featuremetric=True 46 | # pose_only=3 47 | # svd_reg=True 48 | # combine_loss = True 49 | # combine_loss_w = [0., 0., 1.] 50 | # finetune_unlabel=True 51 | # i_eval=20 52 | # DFNet=True 53 | # val_on_psnr=True 54 | # feature_matching_lvl = [0] 55 | # # eval=True # add this for eval 56 | # # pretrain_model_path=../logs/heads/dfnetdm/checkpoint-0317-17.5881.pt # add the trained model for eval -------------------------------------------------------------------------------- /script/config_nerfh.txt: -------------------------------------------------------------------------------- 1 | ############################################### NeRF-Hist training example Cambridge ############################################### 2 | expname=nerfh 3 | basedir=../logs/kings 4 | datadir=../data/Cambridge/KingsCollege 5 | dataset_type=Cambridge 6 | lrate_decay=5 7 | trainskip=2 8 | testskip=1 9 | df=4 10 | load_pose_avg_stats=True 11 | NeRFH=True 12 | encode_hist=True 13 | # render_test=True # add this for eval 14 | 15 | ############################################### NeRF-Hist training example 7-Scenes ############################################### 16 | # expname=nerfh 17 | # basedir=../logs/heads 18 | # datadir=../data/7Scenes/heads 19 | # dataset_type=7Scenes 20 | # lrate_decay=0.754 21 | # trainskip=5 22 | # testskip=50 23 | # df=4 24 | # load_pose_avg_stats=True 25 | # NeRFH=True 26 | # encode_hist=True 27 | # # testskip=1 # add this for eval 28 | # # render_test=True # add this for eval -------------------------------------------------------------------------------- /script/dm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/script/dm/__init__.py -------------------------------------------------------------------------------- /script/dm/callbacks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import os, pdb 5 | 6 | class Callback: 7 | #TODO: Callback class stub.
https://dzlab.github.io/dl/2019/03/16/pytorch-training-loop/ 8 | def __init__(self): pass 9 | def on_train_begin(self): pass 10 | def on_train_end(self): pass 11 | def on_epoch_begin(self): pass 12 | def on_epoch_end(self): pass 13 | def on_batch_begin(self): pass 14 | def on_batch_end(self): pass 15 | def on_loss_begin(self): pass 16 | def on_loss_end(self): pass 17 | def on_step_begin(self): pass 18 | def on_step_end(self): pass 19 | 20 | class EarlyStopping: 21 | """Early stops the training if validation loss doesn't improve after a given patience.""" 22 | # source https://blog.csdn.net/qq_37430422/article/details/103638681 23 | def __init__(self, args, patience=50, verbose=False, delta=0): 24 | """ 25 | Args: 26 | patience (int): How long to wait after last time validation loss improved. 27 | Default: 50 28 | verbose (bool): If True, prints a message for each validation loss improvement. 29 | Default: False 30 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 31 | Default: 0 32 | """ 33 | self.val_on_psnr = args.val_on_psnr 34 | self.patience = patience 35 | self.verbose = verbose 36 | self.counter = 0 37 | self.best_score = None 38 | self.early_stop = False 39 | self.val_loss_min = np.Inf 40 | self.delta = delta 41 | 42 | self.basedir = args.basedir 43 | self.model_name = args.model_name 44 | 45 | self.out_folder = os.path.join(self.basedir, self.model_name) 46 | self.ckpt_save_path = os.path.join(self.out_folder, 'checkpoint.pt') 47 | if not os.path.isdir(self.out_folder): 48 | os.mkdir(self.out_folder) 49 | 50 | def __call__(self, val_loss, model, epoch=-1, save_multiple=False, save_all=False, val_psnr=None): 51 | 52 | # find maximum psnr 53 | if self.val_on_psnr: 54 | score = val_psnr 55 | if self.best_score is None: 56 | self.best_score = score 57 | self.save_checkpoint(val_psnr, model, epoch=epoch, save_multiple=save_multiple) 58 | elif score < self.best_score + self.delta: 59 | self.counter += 1 60 | 61 | if self.counter >= self.patience: 62 | self.early_stop = True 63 | 64 | if save_all: # save all ckpt 65 | self.save_checkpoint(val_psnr, model, epoch=epoch, save_multiple=True, update_best=False) 66 | else: # save best ckpt only 67 | self.best_score = score 68 | self.save_checkpoint(val_psnr, model, epoch=epoch, save_multiple=save_multiple) 69 | self.counter = 0 70 | 71 | # find minimum loss 72 | else: 73 | score = -val_loss 74 | if self.best_score is None: 75 | self.best_score = score 76 | self.save_checkpoint(val_loss, model, epoch=epoch, save_multiple=save_multiple) 77 | elif score < self.best_score + self.delta: 78 | self.counter += 1 79 | 80 | if self.counter >= self.patience: 81 | self.early_stop = True 82 | 83 | if save_all: # save all ckpt 84 | self.save_checkpoint(val_loss, model, epoch=epoch, save_multiple=True, update_best=False) 85 | else: # save best ckpt only 86 | self.best_score = score 87 | self.save_checkpoint(val_loss, model, epoch=epoch, save_multiple=save_multiple) 88 | self.counter = 0 89 | 90 | def save_checkpoint(self, val_loss, model, epoch=-1, save_multiple=False, update_best=True): 91 | '''Saves model when validation loss decrease.''' 92 | if self.verbose: 93 | tqdm.write(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 94 | ckpt_save_path = self.ckpt_save_path 95 | if save_multiple: 96 | ckpt_save_path = ckpt_save_path[:-3]+f'-{epoch:04d}-{val_loss:.4f}.pt' 97 | 98 | torch.save(model.state_dict(), ckpt_save_path) 99 | if update_best: 100 | self.val_loss_min = val_loss 101 | 102 | def isBestModel(self): 103 | ''' Check if current model the best one. 104 | get early stop counter, if counter==0: it means current model has the best validation loss 105 | ''' 106 | return self.counter==0 -------------------------------------------------------------------------------- /script/dm/options.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | def config_parser(): 3 | parser = configargparse.ArgumentParser() 4 | parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1") 5 | parser.add_argument("--multi_gpu", action='store_true', help='use multiple gpu on the server') 6 | 7 | # 7Scenes 8 | parser.add_argument("--trainskip", type=int, default=1, help='will load 1/N images from train sets, useful for large datasets like 7 Scenes') 9 | parser.add_argument("--df", type=float, default=1., help='image downscale factor') 10 | parser.add_argument("--reduce_embedding", type=int, default=-1, help='fourier embedding mode: -1: paper default, \ 11 | 0: reduce by half, 1: remove embedding, 2: DNeRF embedding') 12 | parser.add_argument("--epochToMaxFreq", type=int, default=-1, help='DNeRF embedding mode: (based on DNeRF paper): \ 13 | hyper-parameter for when α should reach the maximum number of frequencies m') 14 | parser.add_argument("--render_pose_only", action='store_true', help='render a spiral video for 7 Scene') 15 | parser.add_argument("--save_pose_avg_stats", action='store_true', help='save a pose avg stats to unify NeRF, posenet, nerf tracking training') 16 | parser.add_argument("--load_pose_avg_stats", action='store_true', help='load precomputed pose avg stats to unify NeRF, posenet, nerf tracking training') 17 | parser.add_argument("--finetune_unlabel", action='store_true', help='finetune unlabeled sequence like MapNet') 18 | parser.add_argument("--i_eval", type=int, default=50, help='frequency of eval posenet result') 19 | parser.add_argument("--save_all_ckpt", action='store_true', help='save all ckpts for each epoch') 20 | parser.add_argument("--train_local_nerf", type=int, default=-1, help='train local NeRF with ith training sequence only, ie. 
Stairs can pick 0~3') 21 | parser.add_argument("--render_video_train", action='store_true', help='render train set NeRF and save as video, make sure i_eval is True') 22 | parser.add_argument("--render_video_test", action='store_true', help='render val set NeRF and save as video, make sure i_eval is True') 23 | parser.add_argument("--no_DNeRF_viewdir", action='store_true', default=False, help='will not use DNeRF in viewdir encoding') 24 | parser.add_argument("--val_on_psnr", action='store_true', default=False, help='EarlyStopping with max validation psnr') 25 | parser.add_argument("--feature_matching_lvl", nargs='+', type=int, default=[0,1,2], 26 | help='lvl of features used for feature matching, default use lvl 0, 1, 2') 27 | 28 | ##################### PoseNet Settings ######################## 29 | parser.add_argument("--pose_only", type=int, default=0, help='posenet type to train, \ 30 | 1: train baseline posenet, 2: posenet+nerf manual optimize, \ 31 | 3: featurenet,') 32 | parser.add_argument("--learning_rate", type=float, default=0.00001, help='learning rate') 33 | parser.add_argument("--batch_size", type=int, default=1, help='train posenet only') 34 | parser.add_argument("--pretrain_model_path", type=str, default='', help='model path of pretrained pose regrssion model') 35 | parser.add_argument("--pretrain_featurenet_path", type=str, default='', help='model path of pretrained featurenet model') 36 | parser.add_argument("--model_name", type=str, help='pose model output folder name') 37 | parser.add_argument("--combine_loss", action='store_true', 38 | help='combined l2 pose loss + rgb mse loss') 39 | parser.add_argument("--combine_loss_w", nargs='+', type=float, default=[0.5, 0.5], 40 | help='weights of combined loss ex, [0.5 0.5], \ 41 | default None, only use when combine_loss is True') 42 | parser.add_argument("--patience", nargs='+', type=int, default=[200, 50], help='set training schedule for patience [EarlyStopping, reduceLR]') 43 | parser.add_argument("--resize_factor", type=int, default=2, help='image resize downsample ratio') 44 | parser.add_argument("--freezeBN", action='store_true', help='Freeze the Batch Norm layer at training PoseNet') 45 | parser.add_argument("--preprocess_ImgNet", action='store_true', help='Normalize input data for PoseNet') 46 | parser.add_argument("--eval", action='store_true', help='eval model') 47 | parser.add_argument("--no_save_multiple", action='store_true', help='default, save multiple posenet model, if true, save only one posenet model') 48 | parser.add_argument("--resnet34", action='store_true', default=False, help='use resnet34 backbone instead of mobilenetV2') 49 | parser.add_argument("--efficientnet", action='store_true', default=False, help='use efficientnet-b3 backbone instead of mobilenetV2') 50 | parser.add_argument("--efficientnet_block", type=int, default=6, help='choose which features from feature block (0-6) of efficientnet to use') 51 | parser.add_argument("--dropout", type=float, default=0.5, help='dropout rate for resnet34 backbone') 52 | parser.add_argument("--DFNet", action='store_true', default=False, help='use DFNet') 53 | parser.add_argument("--DFNet_s", action='store_true', default=False, help='use accelerated DFNet, performance is similar to DFNet but slightly faster') 54 | parser.add_argument("--val_batch_size", type=int, default=1, help='batch_size for validation, higher number leads to faster speed') 55 | 56 | ##################### NeRF Settings ######################## 57 | parser.add_argument('--config', 
is_config_file=True, help='config file path') 58 | parser.add_argument("--expname", type=str, help='experiment name') 59 | parser.add_argument("--basedir", type=str, default='../logs/', help='where to store ckpts and logs') 60 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory') 61 | 62 | # training options 63 | parser.add_argument("--netdepth", type=int, default=8, help='layers in network') 64 | parser.add_argument("--netwidth", type=int, default=128, help='channels per layer') 65 | parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network') 66 | parser.add_argument("--netwidth_fine", type=int, default=128, help='channels per layer in fine network') 67 | parser.add_argument("--N_rand", type=int, default=1536, help='batch size (number of random rays per gradient step)') 68 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') 69 | parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)') 70 | parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory') 71 | parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory') 72 | parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time') 73 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') 74 | parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network') 75 | parser.add_argument("--no_grad_update", default=True, help='do not update nerf in training') 76 | parser.add_argument("--per_channel", default=False, action='store_true', help='using per channel cosine similarity loss instead of per pixel, defualt False') 77 | 78 | # NeRF-Hist training options 79 | parser.add_argument("--NeRFH", action='store_true', help='new implementation for NeRFH') 80 | parser.add_argument("--N_vocab", type=int, default=1000, 81 | help='''number of vocabulary (number of images) 82 | in the dataset for nn.Embedding''') 83 | parser.add_argument("--fix_index", action='store_true', help='fix training frame index as 0') 84 | parser.add_argument("--encode_hist", default=False, action='store_true', help='encode histogram instead of frame index') 85 | parser.add_argument("--hist_bin", type=int, default=10, help='image histogram bin size') 86 | parser.add_argument("--in_channels_a", type=int, default=50, help='appearance embedding dimension, hist_bin*N_a when embedding histogram') 87 | parser.add_argument("--in_channels_t", type=int, default=20, help='transient embedding dimension, hist_bin*N_tau when embedding histogram') 88 | parser.add_argument("--svd_reg", default=False, action='store_true', help='use svd regularize output at training') 89 | 90 | # rendering options 91 | parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray') 92 | parser.add_argument("--N_importance", type=int, default=64,help='number of additional fine samples per ray') 93 | parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. 
for jitter') 94 | parser.add_argument("--use_viewdirs", action='store_true', default=True, help='use full 5D input instead of 3D') 95 | parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none') 96 | parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)') 97 | parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)') 98 | parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 99 | 100 | parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path') 101 | parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path') 102 | parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 103 | 104 | # legacy mesh options 105 | parser.add_argument("--mesh_only", action='store_true', help='do not optimize, reload weights and save mesh to a file') 106 | parser.add_argument("--mesh_grid_size", type=int, default=80,help='number of grid points to sample in each dimension for marching cubes') 107 | 108 | # training options 109 | parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops') 110 | parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops') 111 | 112 | # dataset options 113 | parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels') 114 | parser.add_argument("--testskip", type=int, default=1, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 115 | 116 | ## legacy blender flags 117 | parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)') 118 | parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800') 119 | 120 | ## llff flags 121 | parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images') 122 | parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)') 123 | parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth') 124 | parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes') 125 | parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8') 126 | parser.add_argument("--no_bd_factor", action='store_true', default=False, help='do not use bd factor') 127 | 128 | # featruremetric supervision 129 | parser.add_argument("--featuremetric", action='store_true', help='use featuremetric supervision if true') 130 | 131 | # logging/saving options 132 | parser.add_argument("--i_print", type=int, default=1, help='frequency of console printout and metric loggin') 133 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') 134 | parser.add_argument("--i_weights", type=int, default=200, help='frequency of weight ckpt saving') 135 | parser.add_argument("--i_testset", type=int, default=200, help='frequency of testset 
saving') 136 | parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving') 137 | 138 | return parser -------------------------------------------------------------------------------- /script/dm/pose_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init 5 | import numpy as np 6 | from torchvision import models 7 | from efficientnet_pytorch import EfficientNet 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | import pdb 11 | import matplotlib.pyplot as plt 12 | 13 | import math 14 | import time 15 | 16 | import pytorch3d.transforms as transforms 17 | 18 | def preprocess_data(inputs, device): 19 | # normalize inputs according to https://pytorch.org/hub/pytorch_vision_mobilenet_v2/ 20 | mean = torch.Tensor([0.485, 0.456, 0.406]).to(device) # per channel subtraction 21 | std = torch.Tensor([0.229, 0.224, 0.225]).to(device) # per channel division 22 | inputs = (inputs - mean[None,:,None,None])/std[None,:,None,None] 23 | return inputs 24 | 25 | def filter_hook(m, g_in, g_out): 26 | g_filtered = [] 27 | for g in g_in: 28 | g = g.clone() 29 | g[g != g] = 0 30 | g_filtered.append(g) 31 | return tuple(g_filtered) 32 | 33 | def vis_pose(vis_info): 34 | ''' 35 | visualize predicted pose result vs. gt pose 36 | ''' 37 | pdb.set_trace() 38 | pose = vis_info['pose'] 39 | pose_gt = vis_info['pose_gt'] 40 | theta = vis_info['theta'] 41 | ang_threshold=10 42 | seq_num = theta.shape[0] 43 | # # create figure object 44 | # plot translation traj. 45 | fig = plt.figure(figsize = (8,6)) 46 | plt.subplots_adjust(left=0, bottom=0, right=1, top=1) 47 | ax1 = fig.add_axes([0, 0.2, 0.9, 0.85], projection='3d') 48 | ax1.scatter(pose[10:,0],pose[10:,1],zs=pose[10:,2], c='r', s=3**2,depthshade=0) # predict 49 | ax1.scatter(pose_gt[:,0], pose_gt[:,1], zs=pose_gt[:,2], c='g', s=3**2,depthshade=0) # GT 50 | ax1.scatter(pose[0:10,0],pose[0:10,1],zs=pose[0:10,2], c='k', s=3**2,depthshade=0) # predict 51 | # ax1.plot(pose[:,0],pose[:,1],zs=pose[:,2], c='r') # predict 52 | # ax1.plot(pose_gt[:,0], pose_gt[:,1], zs=pose_gt[:,2], c='g') # GT 53 | ax1.view_init(30, 120) 54 | ax1.set_xlabel('x (m)') 55 | ax1.set_ylabel('y (m)') 56 | ax1.set_zlabel('z (m)') 57 | # ax1.set_xlim(-10, 10) 58 | # ax1.set_ylim(-10, 10) 59 | # ax1.set_zlim(-10, 10) 60 | 61 | ax1.set_xlim(-1, 1) 62 | ax1.set_ylim(-1, 1) 63 | ax1.set_zlim(-1, 1) 64 | 65 | # ax1.set_xlim(-3, 3) 66 | # ax1.set_ylim(-3, 3) 67 | # ax1.set_zlim(-3, 3) 68 | 69 | # plot angular error 70 | ax2 = fig.add_axes([0.1, 0.05, 0.75, 0.2]) 71 | err = theta.reshape(1, seq_num) 72 | err = np.tile(err, (20, 1)) 73 | ax2.imshow(err, vmin=0,vmax=ang_threshold, aspect=3) 74 | ax2.set_yticks([]) 75 | ax2.set_xticks([0, seq_num*1/5, seq_num*2/5, seq_num*3/5, seq_num*4/5, seq_num]) 76 | fname = './vis_pose.png' 77 | plt.savefig(fname, dpi=50) 78 | 79 | def compute_error_in_q(args, dl, model, device, results, batch_size=1): 80 | use_SVD=True # Turn on for Direct-PN and Direct-PN+U reported result, despite it makes minuscule differences 81 | time_spent = [] 82 | predict_pose_list = [] 83 | gt_pose_list = [] 84 | ang_error_list = [] 85 | pose_result_raw = [] 86 | pose_GT = [] 87 | i = 0 88 | 89 | for batch in dl: 90 | if args.NeRFH: 91 | data, pose, img_idx = batch 92 | else: 93 | data, pose = batch 94 | data = data.to(device) # input 95 | pose = pose.reshape((batch_size,3,4)).numpy() 
# label 96 | 97 | if args.preprocess_ImgNet: 98 | data = preprocess_data(data, device) 99 | 100 | if use_SVD: 101 | # using SVD to make sure predict rotation is normalized rotation matrix 102 | with torch.no_grad(): 103 | if args.featuremetric: 104 | _, predict_pose = model(data) 105 | else: 106 | predict_pose = model(data) 107 | 108 | R_torch = predict_pose.reshape((batch_size, 3, 4))[:,:3,:3] # debug 109 | predict_pose = predict_pose.reshape((batch_size, 3, 4)).cpu().numpy() 110 | 111 | R = predict_pose[:,:3,:3] 112 | res = R@np.linalg.inv(R) 113 | # print('R@np.linalg.inv(R):', res) 114 | 115 | u,s,v=torch.svd(R_torch) 116 | Rs = torch.matmul(u, v.transpose(-2,-1)) 117 | predict_pose[:,:3,:3] = Rs[:,:3,:3].cpu().numpy() 118 | else: 119 | start_time = time.time() 120 | # inference NN 121 | with torch.no_grad(): 122 | predict_pose = model(data) 123 | predict_pose = predict_pose.reshape((batch_size, 3, 4)).cpu().numpy() 124 | time_spent.append(time.time() - start_time) 125 | 126 | pose_q = transforms.matrix_to_quaternion(torch.Tensor(pose[:,:3,:3]))#.cpu().numpy() # gnd truth in quaternion 127 | pose_x = pose[:, :3, 3] # gnd truth position 128 | predicted_q = transforms.matrix_to_quaternion(torch.Tensor(predict_pose[:,:3,:3]))#.cpu().numpy() # predict in quaternion 129 | predicted_x = predict_pose[:, :3, 3] # predict position 130 | pose_q = pose_q.squeeze() 131 | pose_x = pose_x.squeeze() 132 | predicted_q = predicted_q.squeeze() 133 | predicted_x = predicted_x.squeeze() 134 | 135 | #Compute Individual Sample Error 136 | q1 = pose_q / torch.linalg.norm(pose_q) 137 | q2 = predicted_q / torch.linalg.norm(predicted_q) 138 | d = torch.abs(torch.sum(torch.matmul(q1,q2))) 139 | d = torch.clamp(d, -1., 1.) # acos can only input [-1~1] 140 | theta = (2 * torch.acos(d) * 180/math.pi).numpy() 141 | error_x = torch.linalg.norm(torch.Tensor(pose_x-predicted_x)).numpy() 142 | results[i,:] = [error_x, theta] 143 | #print ('Iteration: {} Error XYZ (m): {} Error Q (degrees): {}'.format(i, error_x, theta)) 144 | 145 | # save results for visualization 146 | predict_pose_list.append(predicted_x) 147 | gt_pose_list.append(pose_x) 148 | ang_error_list.append(theta) 149 | pose_result_raw.append(predict_pose) 150 | pose_GT.append(pose) 151 | i += 1 152 | # pdb.set_trace() 153 | predict_pose_list = np.array(predict_pose_list) 154 | gt_pose_list = np.array(gt_pose_list) 155 | ang_error_list = np.array(ang_error_list) 156 | pose_result_raw = np.asarray(pose_result_raw)[:,0,:,:] 157 | pose_GT = np.asarray(pose_GT)[:,0,:,:] 158 | vis_info_ret = {"pose": predict_pose_list, "pose_gt": gt_pose_list, "theta": ang_error_list, "pose_result_raw": pose_result_raw, "pose_GT": pose_GT} 159 | return results, vis_info_ret 160 | 161 | # # pytorch 162 | def get_error_in_q(args, dl, model, sample_size, device, batch_size=1): 163 | ''' Convert Rotation matrix to quaternion, then calculate the location errors. 
original from PoseNet Paper ''' 164 | model.eval() 165 | 166 | results = np.zeros((sample_size, 2)) 167 | results, vis_info = compute_error_in_q(args, dl, model, device, results, batch_size) 168 | median_result = np.median(results,axis=0) 169 | mean_result = np.mean(results,axis=0) 170 | 171 | # standard log 172 | print ('Median error {}m and {} degrees.'.format(median_result[0], median_result[1])) 173 | print ('Mean error {}m and {} degrees.'.format(mean_result[0], mean_result[1])) 174 | 175 | # timing log 176 | #print ('Avg execution time (sec): {:.3f}'.format(np.mean(time_spent))) 177 | 178 | # standard log2 179 | # num_translation_less_5cm = np.asarray(np.where(results[:,0]<0.05))[0] 180 | # num_rotation_less_5 = np.asarray(np.where(results[:,1]<5))[0] 181 | # print ('translation error less than 5cm {}/{}.'.format(num_translation_less_5cm.shape[0], results.shape[0])) 182 | # print ('rotation error less than 5 degree {}/{}.'.format(num_rotation_less_5.shape[0], results.shape[0])) 183 | # print ('results:', results) 184 | 185 | # save for direct-pn paper log 186 | # if 0: 187 | # filename='Direct-PN+U_' + args.datadir.split('/')[-1] + '_result.txt' 188 | # np.savetxt(filename, predict_pose) 189 | 190 | # visualize results 191 | # vis_pose(vis_info) 192 | 193 | class EfficientNetB3(nn.Module): 194 | ''' EfficientNet-B3 backbone, 195 | model ref: https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py 196 | ''' 197 | def __init__(self, feat_dim=12): 198 | super(EfficientNetB3, self).__init__() 199 | self.backbone_net = EfficientNet.from_pretrained('efficientnet-b3') 200 | self.feature_extractor = self.backbone_net.extract_features 201 | self.avgpool = nn.AdaptiveAvgPool2d(1) 202 | self.fc_pose = nn.Linear(1536, feat_dim) # 1280 for efficientnet-b0, 1536 for efficientnet-b3 203 | 204 | def forward(self, Input): 205 | x = self.feature_extractor(Input) 206 | x = self.avgpool(x) 207 | x = x.reshape(x.size(0), -1) 208 | predict = self.fc_pose(x) 209 | return predict 210 | 211 | # PoseNet (SE(3)) w/ mobilev2 backbone 212 | class PoseNetV2(nn.Module): 213 | def __init__(self, feat_dim=12): 214 | super(PoseNetV2, self).__init__() 215 | self.backbone_net = models.mobilenet_v2(pretrained=True) 216 | self.feature_extractor = self.backbone_net.features 217 | self.avgpool = nn.AdaptiveAvgPool2d(1) 218 | self.fc_pose = nn.Linear(1280, feat_dim) 219 | 220 | def forward(self, Input): 221 | x = self.feature_extractor(Input) 222 | x = self.avgpool(x) 223 | x = x.reshape(x.size(0), -1) 224 | predict = self.fc_pose(x) 225 | # pdb.set_trace() 226 | return predict 227 | 228 | # PoseNet (SE(3)) w/ resnet34 backnone. We found dropout layer is unnecessary, so we set droprate as 0 in reported results. 
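# Illustrative sketch only (not part of the original file; the helper name is hypothetical):
# the regressors below output a flattened 12-dim [R|t]. This shows how such a vector can be
# reshaped into a 3x4 pose with R projected back onto a rotation by SVD, mirroring the
# use_SVD branch of compute_error_in_q above (which, likewise, does not force det(R)=+1).
import torch

def example_vec12_to_Rt(vec12):
    """vec12: (12,) tensor -> (R, t) with R orthonormalized via SVD."""
    P = vec12.reshape(3, 4)        # rows of the 3x4 camera pose [R|t]
    R, t = P[:, :3], P[:, 3]
    u, s, v = torch.svd(R)         # torch.svd returns V, not V^T
    R = u @ v.transpose(-2, -1)    # nearest orthogonal matrix in the Frobenius sense
    return R, t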
229 | class PoseNet_res34(nn.Module): 230 | def __init__(self, droprate=0.5, pretrained=True, 231 | feat_dim=2048): 232 | super(PoseNet_res34, self).__init__() 233 | self.droprate = droprate 234 | 235 | # replace the last FC layer in feature extractor 236 | self.feature_extractor = models.resnet34(pretrained=True) 237 | self.feature_extractor.avgpool = nn.AdaptiveAvgPool2d(1) 238 | fe_out_planes = self.feature_extractor.fc.in_features 239 | self.feature_extractor.fc = nn.Linear(fe_out_planes, feat_dim) 240 | self.fc_pose = nn.Linear(feat_dim, 12) 241 | 242 | # initialize 243 | if pretrained: 244 | init_modules = [self.feature_extractor.fc] 245 | else: 246 | init_modules = self.modules() 247 | 248 | for m in init_modules: 249 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 250 | nn.init.kaiming_normal_(m.weight.data) 251 | if m.bias is not None: 252 | nn.init.constant_(m.bias.data, 0) 253 | 254 | def forward(self, x): 255 | x = self.feature_extractor(x) 256 | x = F.relu(x) 257 | if self.droprate > 0: 258 | x = F.dropout(x, p=self.droprate) 259 | predict = self.fc_pose(x) 260 | return predict 261 | 262 | 263 | # from MapNet paper CVPR 2018 264 | class PoseNet(nn.Module): 265 | def __init__(self, feature_extractor, droprate=0.5, pretrained=True, 266 | feat_dim=2048, filter_nans=False): 267 | super(PoseNet, self).__init__() 268 | self.droprate = droprate 269 | 270 | # replace the last FC layer in feature extractor 271 | self.feature_extractor = models.resnet34(pretrained=True) 272 | self.feature_extractor.avgpool = nn.AdaptiveAvgPool2d(1) 273 | fe_out_planes = self.feature_extractor.fc.in_features 274 | self.feature_extractor.fc = nn.Linear(fe_out_planes, feat_dim) 275 | 276 | self.fc_xyz = nn.Linear(feat_dim, 3) 277 | self.fc_wpqr = nn.Linear(feat_dim, 3) 278 | if filter_nans: 279 | self.fc_wpqr.register_backward_hook(hook=filter_hook) 280 | # initialize 281 | if pretrained: 282 | init_modules = [self.feature_extractor.fc, self.fc_xyz, self.fc_wpqr] 283 | else: 284 | init_modules = self.modules() 285 | 286 | for m in init_modules: 287 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 288 | nn.init.kaiming_normal_(m.weight.data) 289 | if m.bias is not None: 290 | nn.init.constant_(m.bias.data, 0) 291 | 292 | def forward(self, x): 293 | x = self.feature_extractor(x) 294 | x = F.relu(x) 295 | if self.droprate > 0: 296 | x = F.dropout(x, p=self.droprate) 297 | 298 | xyz = self.fc_xyz(x) 299 | wpqr = self.fc_wpqr(x) 300 | return torch.cat((xyz, wpqr), 1) 301 | 302 | class MapNet(nn.Module): 303 | """ 304 | Implements the MapNet model (green block in Fig. 2 of paper) 305 | """ 306 | def __init__(self, mapnet): 307 | """ 308 | :param mapnet: the MapNet (two CNN blocks inside the green block in Fig. 2 309 | of paper). Not to be confused with MapNet, the model! 
310 | """ 311 | super(MapNet, self).__init__() 312 | self.mapnet = mapnet 313 | 314 | def forward(self, x): 315 | """ 316 | :param x: image blob (N x T x C x H x W) 317 | :return: pose outputs 318 | (N x T x 6) 319 | """ 320 | s = x.size() 321 | x = x.view(-1, *s[2:]) 322 | poses = self.mapnet(x) 323 | poses = poses.view(s[0], s[1], -1) 324 | return poses 325 | 326 | def eval_on_epoch(args, dl, model, optimizer, loss_func, device): 327 | model.eval() 328 | val_loss_epoch = [] 329 | for data, pose in dl: 330 | inputs = data.to(device) 331 | labels = pose.to(device) 332 | if args.preprocess_ImgNet: 333 | inputs = preprocess_data(inputs, device) 334 | predict = model(inputs) 335 | loss = loss_func(predict, labels) 336 | val_loss_epoch.append(loss.item()) 337 | val_loss_epoch_mean = np.mean(val_loss_epoch) 338 | return val_loss_epoch_mean 339 | 340 | 341 | def train_on_epoch(args, dl, model, optimizer, loss_func, device): 342 | model.train() 343 | train_loss_epoch = [] 344 | for data, pose in dl: 345 | inputs = data.to(device) # (N, Ch, H, W) ~ (4,3,200,200), 7scenes [4, 3, 256, 341] wierd shape... 346 | labels = pose.to(device) 347 | if args.preprocess_ImgNet: 348 | inputs = preprocess_data(inputs, device) 349 | 350 | predict = model(inputs) 351 | loss = loss_func(predict, labels) 352 | loss.backward() 353 | optimizer.step() 354 | optimizer.zero_grad() 355 | train_loss_epoch.append(loss.item()) 356 | train_loss_epoch_mean = np.mean(train_loss_epoch) 357 | return train_loss_epoch_mean 358 | 359 | def train_posenet(args, train_dl, val_dl, model, epochs, optimizer, loss_func, scheduler, device, early_stopping): 360 | writer = SummaryWriter() 361 | model_log = tqdm(total=0, position=1, bar_format='{desc}') 362 | for epoch in tqdm(range(epochs), desc='epochs'): 363 | 364 | # train 1 epoch 365 | train_loss = train_on_epoch(args, train_dl, model, optimizer, loss_func, device) 366 | writer.add_scalar("Loss/train", train_loss, epoch) 367 | 368 | # validate every epoch 369 | val_loss = eval_on_epoch(args, val_dl, model, optimizer, loss_func, device) 370 | writer.add_scalar("Loss/val", val_loss, epoch) 371 | 372 | # reduce LR on plateau 373 | scheduler.step(val_loss) 374 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch) 375 | 376 | # logging 377 | tqdm.write('At epoch {0:6d} : train loss: {1:.4f}, val loss: {2:.4f}'.format(epoch, train_loss, val_loss)) 378 | 379 | # check wether to early stop 380 | early_stopping(val_loss, model, epoch=epoch, save_multiple=(not args.no_save_multiple), save_all=args.save_all_ckpt) 381 | if early_stopping.early_stop: 382 | print("Early stopping") 383 | break 384 | 385 | model_log.set_description_str(f'Best val loss: {early_stopping.val_loss_min:.4f}') 386 | 387 | if epoch % args.i_eval == 0: 388 | get_error_in_q(args, val_dl, model, len(val_dl.dataset), device, batch_size=1) 389 | 390 | 391 | writer.flush() 392 | -------------------------------------------------------------------------------- /script/dm/prepare_data.py: -------------------------------------------------------------------------------- 1 | import utils.set_sys_path 2 | import torch 3 | from torch.utils.data import TensorDataset, DataLoader 4 | from torchvision import transforms 5 | import numpy as np 6 | 7 | from dataset_loaders.load_llff import load_llff_data 8 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 9 | 10 | def prepare_data(args, images, poses_train, i_split): 11 | ''' prepare data for ready to train posenet, return dataloaders ''' 12 | #TODO: Convert GPU friendly data 
generator later: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel 13 | #TODO: Probably a better implementation style here: https://github.com/PyTorchLightning/pytorch-lightning 14 | 15 | i_train, i_val, i_test = i_split 16 | 17 | img_train = torch.Tensor(images[i_train]).permute(0, 3, 1, 2) # now shape is [N, CH, H, W] 18 | pose_train = torch.Tensor(poses_train[i_train]) 19 | 20 | trainset = TensorDataset(img_train, pose_train) 21 | train_dl = DataLoader(trainset, batch_size=args.batch_size, shuffle=True) 22 | 23 | img_val = torch.Tensor(images[i_val]).permute(0, 3, 1, 2) # now shape is [N, CH, H, W] 24 | pose_val = torch.Tensor(poses_train[i_val]) 25 | 26 | valset = TensorDataset(img_val, pose_val) 27 | val_dl = DataLoader(valset) 28 | 29 | img_test = torch.Tensor(images[i_test]).permute(0, 3, 1, 2) # now shape is [N, CH, H, W] 30 | pose_test = torch.Tensor(poses_train[i_test]) 31 | 32 | testset = TensorDataset(img_test, pose_test) 33 | test_dl = DataLoader(testset) 34 | 35 | return train_dl, val_dl, test_dl 36 | 37 | def load_dataset(args): 38 | ''' load posenet training data ''' 39 | if args.dataset_type == 'llff': 40 | if args.no_bd_factor: 41 | bd_factor = None 42 | else: 43 | bd_factor = 0.75 44 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 45 | recenter=True, bd_factor=bd_factor, 46 | spherify=args.spherify) 47 | 48 | hwf = poses[0,:3,-1] 49 | poses = poses[:,:3,:4] 50 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 51 | if not isinstance(i_test, list): 52 | i_test = [i_test] 53 | 54 | if args.llffhold > 0: 55 | print('Auto LLFF holdout,', args.llffhold) 56 | i_test = np.arange(images.shape[0])[::args.llffhold] 57 | 58 | i_val = i_test 59 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 60 | (i not in i_test and i not in i_val)]) 61 | 62 | print('DEFINING BOUNDS') 63 | if args.no_ndc: 64 | near = np.ndarray.min(bds) * .9 65 | far = np.ndarray.max(bds) * 1. 66 | 67 | else: 68 | near = 0. 69 | far = 1. 
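# Note on the bounds above: with NDC (the default LLFF forward-facing setup) depth is remapped so near = 0. and far = 1.;
# with --no_ndc they come from the scene bounds as near = bds.min()*0.9 and far = bds.max()*1.0.
# Hypothetical numbers for illustration: bds.min()=1.2 and bds.max()=8.0 would give near=1.08 and far=8.0 when --no_ndc is set.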
70 | 71 | if args.finetune_unlabel: 72 | i_train = i_test 73 | i_split = [i_train, i_val, i_test] 74 | else: 75 | print('Unknown dataset type', args.dataset_type, 'exiting') 76 | return 77 | 78 | poses_train = poses[:,:3,:].reshape((poses.shape[0],12)) # get rid of last row [0,0,0,1] 79 | print("images.shape {}, poses_train.shape {}".format(images.shape, poses_train.shape)) 80 | 81 | INPUT_SHAPE = images[0].shape 82 | print("=====================================================================") 83 | print("INPUT_SHAPE:", INPUT_SHAPE) 84 | return images, poses_train, render_poses, hwf, i_split, near, far -------------------------------------------------------------------------------- /script/feature/dfnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from typing import List 6 | 7 | # VGG-16 Layer Names and Channels 8 | vgg16_layers = { 9 | "conv1_1": 64, 10 | "relu1_1": 64, 11 | "conv1_2": 64, 12 | "relu1_2": 64, 13 | "pool1": 64, 14 | "conv2_1": 128, 15 | "relu2_1": 128, 16 | "conv2_2": 128, 17 | "relu2_2": 128, 18 | "pool2": 128, 19 | "conv3_1": 256, 20 | "relu3_1": 256, 21 | "conv3_2": 256, 22 | "relu3_2": 256, 23 | "conv3_3": 256, 24 | "relu3_3": 256, 25 | "pool3": 256, 26 | "conv4_1": 512, 27 | "relu4_1": 512, 28 | "conv4_2": 512, 29 | "relu4_2": 512, 30 | "conv4_3": 512, 31 | "relu4_3": 512, 32 | "pool4": 512, 33 | "conv5_1": 512, 34 | "relu5_1": 512, 35 | "conv5_2": 512, 36 | "relu5_2": 512, 37 | "conv5_3": 512, 38 | "relu5_3": 512, 39 | "pool5": 512, 40 | } 41 | 42 | class AdaptLayers(nn.Module): 43 | """Small adaptation layers. 44 | """ 45 | 46 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128): 47 | """Initialize one adaptation layer for every extraction point. 48 | 49 | Args: 50 | hypercolumn_layers: The list of the hypercolumn layer names. 51 | output_dim: The output channel dimension. 52 | """ 53 | super(AdaptLayers, self).__init__() 54 | self.layers = [] 55 | channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers] 56 | for i, l in enumerate(channel_sizes): 57 | layer = nn.Sequential( 58 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0), 59 | nn.ReLU(), 60 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2), 61 | nn.BatchNorm2d(output_dim), 62 | ) 63 | self.layers.append(layer) 64 | self.add_module("adapt_layer_{}".format(i), layer) # ex: adapt_layer_0 65 | 66 | def forward(self, features: List[torch.tensor]): 67 | """Apply adaptation layers. 
# here is list of three levels of features 68 | """ 69 | 70 | for i, _ in enumerate(features): 71 | features[i] = getattr(self, "adapt_layer_{}".format(i))(features[i]) 72 | return features 73 | 74 | class DFNet(nn.Module): 75 | ''' DFNet implementation ''' 76 | default_conf = { 77 | 'hypercolumn_layers': ["conv1_2", "conv3_3", "conv5_3"], 78 | 'output_dim': 128, 79 | } 80 | mean = [0.485, 0.456, 0.406] 81 | std = [0.229, 0.224, 0.225] 82 | 83 | def __init__(self, feat_dim=12, places365_model_path=''): 84 | super(DFNet, self).__init__() 85 | 86 | self.layer_to_index = {k: v for v, k in enumerate(vgg16_layers.keys())} 87 | self.hypercolumn_indices = [self.layer_to_index[n] for n in self.default_conf['hypercolumn_layers']] # [2, 14, 28] 88 | 89 | # Initialize architecture 90 | vgg16 = models.vgg16(pretrained=True) 91 | 92 | self.encoder = nn.Sequential(*list(vgg16.features.children())) 93 | 94 | self.scales = [] 95 | current_scale = 0 96 | for i, layer in enumerate(self.encoder): 97 | if isinstance(layer, torch.nn.MaxPool2d): 98 | current_scale += 1 99 | if i in self.hypercolumn_indices: 100 | self.scales.append(2**current_scale) 101 | 102 | ## adaptation layers, see off branches from fig.3 in S2DNet paper 103 | self.adaptation_layers = AdaptLayers(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim']) 104 | 105 | # pose regression layers 106 | self.avgpool = nn.AdaptiveAvgPool2d(1) 107 | self.fc_pose = nn.Linear(512, feat_dim) 108 | 109 | def forward(self, x, return_feature=False, isSingleStream=False, return_pose=True, upsampleH=240, upsampleW=427): 110 | ''' 111 | inference DFNet. It can regress camera pose as well as extract intermediate layer features. 112 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream 113 | :param return_feature: whether to return features as output 114 | :param isSingleStream: whether it's an single stream inference or siamese network inference 115 | :param upsampleH: feature upsample size H 116 | :param upsampleW: feature upsample size W 117 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None 118 | :return predict: [2B, 12] or [B, 12] 119 | ''' 120 | # normalize input data 121 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std) 122 | x = (x - mean[:, None, None]) / std[:, None, None] 123 | 124 | ### encoder ### 125 | feature_maps = [] 126 | for i in range(len(self.encoder)): 127 | x = self.encoder[i](x) 128 | 129 | if i in self.hypercolumn_indices: 130 | feature = x.clone() 131 | feature_maps.append(feature) 132 | 133 | if i==self.hypercolumn_indices[-1]: 134 | if return_pose==False: 135 | predict = None 136 | break 137 | 138 | ### extract and process intermediate features ### 139 | if return_feature: 140 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer 141 | 142 | if isSingleStream: # not siamese network style inference 143 | feature_stacks = [] 144 | for f in feature_maps: 145 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f)) 146 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W]) 147 | else: # siamese network style inference 148 | feature_stacks_t = [] 149 | feature_stacks_r = [] 150 | for f in feature_maps: 151 | # split real and nerf batches 152 | batch = f.shape[0] # should be target batch_size + rgb batch_size 153 | feature_t = f[:batch//2] 154 | feature_r = f[batch//2:] 155 | 156 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, 
upsampleW))(feature_t)) # GT img 157 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img 158 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W] 159 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W] 160 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W]) 161 | else: 162 | feature_maps = None 163 | 164 | if return_pose==False: 165 | return feature_maps, predict 166 | 167 | ### pose regression head ### 168 | x = self.avgpool(x) 169 | x = x.reshape(x.size(0), -1) 170 | predict = self.fc_pose(x) 171 | 172 | return feature_maps, predict 173 | 174 | class DFNet_s(nn.Module): 175 | ''' A slight accelerated version of DFNet, we experimentally found this version's performance is similar to original DFNet but inferences faster ''' 176 | default_conf = { 177 | 'hypercolumn_layers': ["conv1_2"], 178 | 'output_dim': 128, 179 | } 180 | mean = [0.485, 0.456, 0.406] 181 | std = [0.229, 0.224, 0.225] 182 | 183 | def __init__(self, feat_dim=12, places365_model_path=''): 184 | super(DFNet_s, self).__init__() 185 | 186 | self.layer_to_index = {k: v for v, k in enumerate(vgg16_layers.keys())} 187 | self.hypercolumn_indices = [self.layer_to_index[n] for n in self.default_conf['hypercolumn_layers']] # [2, 14, 28] 188 | 189 | # Initialize architecture 190 | vgg16 = models.vgg16(pretrained=True) 191 | 192 | self.encoder = nn.Sequential(*list(vgg16.features.children())) 193 | 194 | self.scales = [] 195 | current_scale = 0 196 | for i, layer in enumerate(self.encoder): 197 | if isinstance(layer, torch.nn.MaxPool2d): 198 | current_scale += 1 199 | if i in self.hypercolumn_indices: 200 | self.scales.append(2**current_scale) 201 | 202 | ## adaptation layers, see off branches from fig.3 in S2DNet paper 203 | self.adaptation_layers = AdaptLayers(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim']) 204 | 205 | # pose regression layers 206 | self.avgpool = nn.AdaptiveAvgPool2d(1) 207 | self.fc_pose = nn.Linear(512, feat_dim) 208 | 209 | def forward(self, x, return_feature=False, isSingleStream=False, return_pose=True, upsampleH=240, upsampleW=427): 210 | ''' 211 | inference DFNet_s. It can regress camera pose as well as extract intermediate layer features. 
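# --- [editor's note] minimal usage sketch for the DFNet class defined above.
# Illustrative tensors only (not this repo's data pipeline); assumes RGB input in
# [0, 1], since forward() applies ImageNet mean/std normalisation internally.
import torch
model = DFNet(feat_dim=12).eval()
target = torch.rand(2, 3, 240, 427)        # B ground-truth images
render = torch.rand(2, 3, 240, 427)        # B NeRF-rendered images
x = torch.cat([target, render], dim=0)     # (2B, 3, H, W) two-stream batch
with torch.no_grad():
    feats, pose = model(x, return_feature=True, isSingleStream=False)
# feats == [feat_target, feat_render], each of shape (3, B, 128, 240, 427):
# the three adapted hypercolumn levels upsampled to a common resolution.
# pose has shape (2B, 12), i.e. a flattened 3x4 camera pose per image.
# --- end of editor's note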
212 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream 213 | :param return_feature: whether to return features as output 214 | :param isSingleStream: whether it's an single stream inference or siamese network inference 215 | :param upsampleH: feature upsample size H 216 | :param upsampleW: feature upsample size W 217 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None 218 | :return predict: [2B, 12] or [B, 12] 219 | ''' 220 | 221 | # normalize input data 222 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std) 223 | x = (x - mean[:, None, None]) / std[:, None, None] 224 | 225 | ### encoder ### 226 | feature_maps = [] 227 | for i in range(len(self.encoder)): 228 | x = self.encoder[i](x) 229 | 230 | if i in self.hypercolumn_indices: 231 | feature = x.clone() 232 | feature_maps.append(feature) 233 | 234 | if i==self.hypercolumn_indices[-1]: 235 | if return_pose==False: 236 | predict = None 237 | break 238 | 239 | ### extract and process intermediate features ### 240 | if return_feature: 241 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer 242 | 243 | if isSingleStream: # not siamese network style inference 244 | feature_stacks = [] 245 | for f in feature_maps: 246 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f)) 247 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W]) 248 | else: # siamese network style inference 249 | feature_stacks_t = [] 250 | feature_stacks_r = [] 251 | for f in feature_maps: 252 | # split real and nerf batches 253 | batch = f.shape[0] # should be target batch_size + rgb batch_size 254 | feature_t = f[:batch//2] 255 | feature_r = f[batch//2:] 256 | 257 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img 258 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img 259 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W] 260 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W] 261 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W]) 262 | else: 263 | feature_maps = None 264 | 265 | if return_pose==False: 266 | return feature_maps, predict 267 | 268 | ### pose regression head ### 269 | x = self.avgpool(x) 270 | x = x.reshape(x.size(0), -1) 271 | predict = self.fc_pose(x) 272 | 273 | return feature_maps, predict -------------------------------------------------------------------------------- /script/feature/efficientnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import copy, pdb 5 | from typing import List 6 | from efficientnet_pytorch import EfficientNet 7 | 8 | # efficientnet-B3 Layer Names and Channels 9 | EB3_layers = { 10 | "reduction_1": 24, # torch.Size([2, 24, 120, 213]) 11 | "reduction_2": 32, # torch.Size([2, 32, 60, 106]) 12 | "reduction_3": 48, # torch.Size([2, 48, 30, 53]) 13 | "reduction_4": 136, # torch.Size([2, 136, 15, 26]) 14 | "reduction_5": 384, # torch.Size([2, 384, 8, 13]) 15 | "reduction_6": 1536, # torch.Size([2, 1536, 8, 13]) 16 | } 17 | 18 | # efficientnet-B0 Layer Names and Channels 19 | EB0_layers = { 20 | "reduction_1": 16, # torch.Size([2, 16, 120, 213]) 21 | "reduction_2": 24, # torch.Size([2, 24, 60, 106]) 22 | "reduction_3": 40, # torch.Size([2, 40, 30, 53]) 23 | "reduction_4": 112, # 
torch.Size([2, 112, 15, 26]) 24 | "reduction_5": 320, # torch.Size([2, 320, 8, 13]) 25 | "reduction_6": 1280, # torch.Size([2, 1280, 8, 13]) 26 | } 27 | 28 | class AdaptLayers(nn.Module): 29 | """Small adaptation layers. 30 | """ 31 | 32 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128): 33 | """Initialize one adaptation layer for every extraction point. 34 | 35 | Args: 36 | hypercolumn_layers: The list of the hypercolumn layer names. 37 | output_dim: The output channel dimension. 38 | """ 39 | super(AdaptLayers, self).__init__() 40 | self.layers = [] 41 | channel_sizes = [EB3_layers[name] for name in hypercolumn_layers] 42 | for i, l in enumerate(channel_sizes): 43 | layer = nn.Sequential( 44 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0), 45 | nn.ReLU(), 46 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2), 47 | nn.BatchNorm2d(output_dim), 48 | ) 49 | self.layers.append(layer) 50 | self.add_module("adapt_layer_{}".format(i), layer) # ex: adapt_layer_0 51 | 52 | def forward(self, features: List[torch.tensor]): 53 | """Apply adaptation layers. # here is list of three levels of features 54 | """ 55 | 56 | for i, _ in enumerate(features): 57 | features[i] = getattr(self, "adapt_layer_{}".format(i))(features[i]) 58 | return features 59 | 60 | class EfficientNetB3(nn.Module): 61 | ''' DFNet with EB3 backbone ''' 62 | default_conf = { 63 | # 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_6"], 64 | 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_5"], 65 | # 'hypercolumn_layers': ["reduction_2", "reduction_4", "reduction_6"], 66 | 'output_dim': 128, 67 | } 68 | mean = [0.485, 0.456, 0.406] 69 | std = [0.229, 0.224, 0.225] 70 | 71 | def __init__(self, feat_dim=12, places365_model_path=''): 72 | super(EfficientNetB3, self).__init__() 73 | # Initialize architecture 74 | self.backbone_net = EfficientNet.from_pretrained('efficientnet-b3') 75 | self.feature_extractor = self.backbone_net.extract_endpoints 76 | 77 | # self.feature_block_index = [1, 3, 6] # same as the 'hypercolumn_layers' 78 | self.feature_block_index = [1, 3, 5] # same as the 'hypercolumn_layers' 79 | # self.feature_block_index = [2, 4, 6] # same as the 'hypercolumn_layers' 80 | 81 | ## adaptation layers, see off branches from fig.3 in S2DNet paper 82 | self.adaptation_layers = AdaptLayers(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim']) 83 | 84 | # pose regression layers 85 | self.avgpool = nn.AdaptiveAvgPool2d(1) 86 | self.fc_pose = nn.Linear(1536, feat_dim) 87 | 88 | def forward(self, x, return_feature=False, isSingleStream=False, upsampleH=120, upsampleW=213): 89 | ''' 90 | inference DFNet. It can regress camera pose as well as extract intermediate layer features. 
91 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream 92 | :param return_feature: whether to return features as output 93 | :param isSingleStream: whether it's an single stream inference or siamese network inference 94 | :param upsampleH: feature upsample size H 95 | :param upsampleW: feature upsample size W 96 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None 97 | :return predict: [2B, 12] or [B, 12] 98 | ''' 99 | # normalize input data 100 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std) 101 | x = (x - mean[:, None, None]) / std[:, None, None] 102 | 103 | ### encoder ### 104 | feature_maps = [] 105 | list_x = self.feature_extractor(x) 106 | 107 | x = list_x['reduction_6'] # features to save 108 | for i in self.feature_block_index: 109 | fe = list_x['reduction_'+str(i)].clone() 110 | feature_maps.append(fe) 111 | 112 | ### extract and process intermediate features ### 113 | if return_feature: 114 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer 115 | 116 | pdb.set_trace() 117 | if isSingleStream: # not siamese network style inference 118 | feature_stacks = [] 119 | 120 | for f in feature_maps: 121 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f)) 122 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W]) 123 | else: # siamese network style inference 124 | feature_stacks_t = [] 125 | feature_stacks_r = [] 126 | 127 | for f in feature_maps: 128 | # split real and nerf batches 129 | batch = f.shape[0] # should be target batch_size + rgb batch_size 130 | feature_t = f[:batch//2] 131 | feature_r = f[batch//2:] 132 | 133 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img 134 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img 135 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W] 136 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W] 137 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W]) 138 | 139 | else: 140 | feature_maps = None 141 | 142 | ### pose regression head ### 143 | x = self.avgpool(x) 144 | x = x.reshape(x.size(0), -1) 145 | predict = self.fc_pose(x) 146 | 147 | return feature_maps, predict 148 | 149 | class AdaptLayers2(nn.Module): 150 | """Small adaptation layers. 151 | """ 152 | 153 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128): 154 | """Initialize one adaptation layer for every extraction point. 155 | 156 | Args: 157 | hypercolumn_layers: The list of the hypercolumn layer names. 158 | output_dim: The output channel dimension. 159 | """ 160 | super(AdaptLayers2, self).__init__() 161 | self.layers = [] 162 | channel_sizes = [EB0_layers[name] for name in hypercolumn_layers] 163 | for i, l in enumerate(channel_sizes): 164 | layer = nn.Sequential( 165 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0), 166 | nn.ReLU(), 167 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2), 168 | nn.BatchNorm2d(output_dim), 169 | ) 170 | self.layers.append(layer) 171 | self.add_module("adapt_layer_{}".format(i), layer) # ex: adapt_layer_0 172 | 173 | def forward(self, features: List[torch.tensor]): 174 | """Apply adaptation layers. 
# here is list of three levels of features 175 | """ 176 | 177 | for i, _ in enumerate(features): 178 | features[i] = getattr(self, "adapt_layer_{}".format(i))(features[i]) 179 | return features 180 | 181 | class EfficientNetB0(nn.Module): 182 | ''' DFNet with EB0 backbone, feature levels can be customized ''' 183 | default_conf = { 184 | # 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_6"], 185 | 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_5"], 186 | # 'hypercolumn_layers': ["reduction_2", "reduction_4", "reduction_6"], 187 | # 'hypercolumn_layers': ["reduction_1"], 188 | 'output_dim': 128, 189 | } 190 | mean = [0.485, 0.456, 0.406] 191 | std = [0.229, 0.224, 0.225] 192 | 193 | def __init__(self, feat_dim=12, places365_model_path=''): 194 | super(EfficientNetB0, self).__init__() 195 | # Initialize architecture 196 | self.backbone_net = EfficientNet.from_pretrained('efficientnet-b0') 197 | self.feature_extractor = self.backbone_net.extract_endpoints 198 | 199 | # self.feature_block_index = [1, 3, 6] # same as the 'hypercolumn_layers' 200 | self.feature_block_index = [1, 3, 5] # same as the 'hypercolumn_layers' 201 | # self.feature_block_index = [2, 4, 6] # same as the 'hypercolumn_layers' 202 | # self.feature_block_index = [1] 203 | 204 | ## adaptation layers, see off branches from fig.3 in S2DNet paper 205 | self.adaptation_layers = AdaptLayers2(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim']) 206 | 207 | # pose regression layers 208 | self.avgpool = nn.AdaptiveAvgPool2d(1) 209 | self.fc_pose = nn.Linear(1280, feat_dim) 210 | 211 | def forward(self, x, return_feature=False, isSingleStream=False, return_pose=False, upsampleH=120, upsampleW=213): 212 | ''' 213 | inference DFNet. It can regress camera pose as well as extract intermediate layer features. 
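# --- [editor's note] the commented-out configurations above come in matched pairs:
# default_conf['hypercolumn_layers'] names the EfficientNet endpoints that
# AdaptLayers2 is built for, and self.feature_block_index must list the same
# reduction levels, e.g. ["reduction_2", "reduction_4", "reduction_6"] pairs with
# [2, 4, 6]. Changing one without the other makes the per-layer Conv2d channel
# sizes (looked up in EB0_layers) disagree with the features actually extracted,
# which fails at runtime.
# --- end of editor's note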
214 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream 215 | :param return_feature: whether to return features as output 216 | :param isSingleStream: whether it's an single stream inference or siamese network inference 217 | :param return_pose: TODO: if only return_pose, we don't need to compute return_feature part 218 | :param upsampleH: feature upsample size H 219 | :param upsampleW: feature upsample size W 220 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None 221 | :return predict: [2B, 12] or [B, 12] 222 | ''' 223 | # normalize input data 224 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std) 225 | x = (x - mean[:, None, None]) / std[:, None, None] 226 | 227 | ### encoder ### 228 | feature_maps = [] 229 | list_x = self.feature_extractor(x) 230 | 231 | x = list_x['reduction_6'] # features to save 232 | for i in self.feature_block_index: 233 | fe = list_x['reduction_'+str(i)].clone() 234 | feature_maps.append(fe) 235 | 236 | ### extract and process intermediate features ### 237 | if return_feature: 238 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer 239 | 240 | if isSingleStream: # not siamese network style inference 241 | feature_stacks = [] 242 | 243 | for f in feature_maps: 244 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f)) 245 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W]) 246 | else: # siamese network style inference 247 | feature_stacks_t = [] 248 | feature_stacks_r = [] 249 | 250 | for f in feature_maps: 251 | 252 | # split real and nerf batches 253 | batch = f.shape[0] # should be target batch_size + rgb batch_size 254 | feature_t = f[:batch//2] 255 | feature_r = f[batch//2:] 256 | 257 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img 258 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img 259 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W] 260 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W] 261 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W]) 262 | 263 | else: 264 | feature_maps = None 265 | 266 | ### pose regression head ### 267 | x = self.avgpool(x) 268 | x = x.reshape(x.size(0), -1) 269 | predict = self.fc_pose(x) 270 | 271 | return feature_maps, predict 272 | 273 | def main(): 274 | """ 275 | test model 276 | """ 277 | from torchsummary import summary 278 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 279 | feat_model = EfficientNetB3() 280 | # feat_model.load_state_dict(torch.load('')) 281 | feat_model.to(device) 282 | summary(feat_model, (3, 240, 427)) 283 | 284 | if __name__ == '__main__': 285 | main() 286 | -------------------------------------------------------------------------------- /script/feature/options.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | def config_parser(): 3 | parser = configargparse.ArgumentParser() 4 | parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1") 5 | parser.add_argument("--multi_gpu", action='store_true', help='use multiple gpu on the server') 6 | parser.add_argument('--config', is_config_file=True, help='config file path') 7 | parser.add_argument("--expname", type=str, help='experiment name') 8 | parser.add_argument("--basedir", type=str, 
default='../logs', help='where to store ckpts and logs') 9 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory') 10 | parser.add_argument("--places365_model_path", type=str, default='', help='ckpt path of places365 pretrained model') 11 | 12 | # 7Scenes 13 | parser.add_argument("--trainskip", type=int, default=1, help='will load 1/N images from train sets, useful for large datasets like 7 Scenes') 14 | parser.add_argument("--df", type=float, default=1., help='image downscale factor') 15 | parser.add_argument("--reduce_embedding", type=int, default=-1, help='fourier embedding mode: -1: paper default, \ 16 | 0: reduce by half, 1: remove embedding, 2: DNeRF embedding') 17 | parser.add_argument("--epochToMaxFreq", type=int, default=-1, help='DNeRF embedding mode: (based on Nerfie paper): \ 18 | hyper-parameter for when α should reach the maximum number of frequencies m') 19 | parser.add_argument("--render_pose_only", action='store_true', help='render a spiral video for 7 Scene') 20 | parser.add_argument("--save_pose_avg_stats", action='store_true', help='save a pose avg stats to unify NeRF, posenet, direct-pn training') 21 | parser.add_argument("--load_pose_avg_stats", action='store_true', help='load precomputed pose avg stats to unify NeRF, posenet, nerf tracking training') 22 | parser.add_argument("--train_local_nerf", type=int, default=-1, help='train local NeRF with ith training sequence only, ie. Stairs can pick 0~3') 23 | parser.add_argument("--render_video_train", action='store_true', help='render train set NeRF and save as video, make sure render_test is True') 24 | parser.add_argument("--render_video_test", action='store_true', help='render val set NeRF and save as video, make sure render_test is True') 25 | parser.add_argument("--frustum_overlap_th", type=float, help='frustsum overlap threshold') 26 | parser.add_argument("--no_DNeRF_viewdir", action='store_true', default=False, help='will not use DNeRF in viewdir encoding') 27 | parser.add_argument("--load_unique_view_stats", action='store_true', help='load unique views frame index') 28 | parser.add_argument("--finetune_unlabel", action='store_true', help='finetune unlabeled sequence like MapNet') 29 | parser.add_argument("--i_eval", type=int, default=20, help='frequency of eval posenet result') 30 | parser.add_argument("--save_all_ckpt", action='store_true', help='save all ckpts for each epoch') 31 | parser.add_argument("--val_on_psnr", action='store_true', default=False, help='EarlyStopping with max validation psnr') 32 | 33 | # NeRF training options 34 | parser.add_argument("--netdepth", type=int, default=8, help='layers in network') 35 | parser.add_argument("--netwidth", type=int, default=128, help='channels per layer') 36 | parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network') 37 | parser.add_argument("--netwidth_fine", type=int, default=128, help='channels per layer in fine network') 38 | parser.add_argument("--N_rand", type=int, default=1536, help='batch size (number of random rays per gradient step)') 39 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') 40 | parser.add_argument("--lrate_decay", type=float, default=250, help='exponential learning rate decay (in 1000 steps)') 41 | parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory') 42 | parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through 
network in parallel, decrease if running out of memory') 43 | parser.add_argument("--no_batching", action='store_true', default=True, help='only take random rays from 1 image at a time') 44 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') 45 | parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network') 46 | 47 | # NeRF-Hist training options 48 | parser.add_argument("--NeRFH", action='store_true', default=True, help='new implementation for NeRFH, please add --encode_hist') 49 | parser.add_argument("--N_vocab", type=int, default=1000, 50 | help='''number of vocabulary (number of images) 51 | in the dataset for nn.Embedding''') 52 | parser.add_argument("--fix_index", action='store_true', help='fix training frame index as 0') 53 | parser.add_argument("--encode_hist", default=False, action='store_true', help='encode histogram instead of frame index') 54 | parser.add_argument("--hist_bin", type=int, default=10, help='image histogram bin size') 55 | parser.add_argument("--in_channels_a", type=int, default=50, help='appearance embedding dimension, hist_bin*N_a when embedding histogram') 56 | parser.add_argument("--in_channels_t", type=int, default=20, help='transient embedding dimension, hist_bin*N_tau when embedding histogram') 57 | parser.add_argument("--svd_reg", default=False, action='store_true', help='use svd regularize output at training') 58 | 59 | # NeRF rendering options 60 | parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray') 61 | parser.add_argument("--N_importance", type=int, default=64,help='number of additional fine samples per ray') 62 | parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. 
for jitter') 63 | parser.add_argument("--use_viewdirs", action='store_true', default=True, help='use full 5D input instead of 3D') 64 | parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none') 65 | parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)') 66 | parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)') 67 | parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 68 | parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path') 69 | parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path') 70 | parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 71 | parser.add_argument("--no_grad_update", action='store_true', default=False, help='do not update nerf in training') 72 | parser.add_argument("--tinyimg", action='store_true', default=False, help='render nerf img in a tiny scale image, this is a temporal compromise for direct feature matching, must FIX later') 73 | parser.add_argument("--tinyscale", type=float, default=4., help='determine the scale of downsizing nerf rendring, must FIX later') 74 | 75 | ##################### PoseNet Settings ######################## 76 | parser.add_argument("--pose_only", type=int, default=1, help='posenet type to train, \ 77 | 1: train baseline posenet, 2: posenet+nerf manual optimize, \ 78 | 3: VLocNet, 4: DGRNet') 79 | parser.add_argument("--learning_rate", type=float, default=0.0001, help='learning rate') 80 | parser.add_argument("--batch_size", type=int, default=1, help='dataloader batch size, Attention: this is NOT the actual training batch size, \ 81 | please use --featurenet_batch_size for training') 82 | parser.add_argument("--featurenet_batch_size", type=int, default=8, help='featurenet training batch size, choose smaller batch size') 83 | parser.add_argument("--pretrain_model_path", type=str, default='', help='model path of pretrained model') 84 | parser.add_argument("--model_name", type=str, help='pose model output folder name') 85 | parser.add_argument("--combine_loss_w", nargs='+', type=float, default=[1, 1, 1], 86 | help='weights of combined loss ex, [1, 1, 1], \ 87 | default None, only use when combine_loss is True') 88 | parser.add_argument("--patience", nargs='+', type=int, default=[200, 50], help='set training schedule for patience [EarlyStopping, reduceLR]') 89 | parser.add_argument("--resize_factor", type=int, default=2, help='image resize downsample ratio') 90 | parser.add_argument("--freezeBN", action='store_true', help='Freeze the Batch Norm layer at training PoseNet') 91 | parser.add_argument("--preprocess_ImgNet", action='store_true', help='Normalize input data for PoseNet') 92 | parser.add_argument("--eval", action='store_true', help='eval model') 93 | parser.add_argument("--no_save_multiple", action='store_true', help='default, save multiple posenet model, if true, save only one posenet model') 94 | parser.add_argument("--resnet34", action='store_true', default=False, help='use resnet34 backbone instead of mobilenetV2') 95 | parser.add_argument("--efficientnet", action='store_true', default=False, help='use efficientnet-b3 backbone instead of 
mobilenetV2') 96 | parser.add_argument("--dropout", type=float, default=0.5, help='dropout rate for resnet34 backbone') 97 | parser.add_argument("--DFNet", action='store_true', default=False, help='use DFNet') 98 | parser.add_argument("--DFNet_s", action='store_true', default=False, help='use accelerated DFNet, performance is similar to DFNet but slightly faster') 99 | parser.add_argument("--featurelossonly", action='store_true', default=False, help='only use feature loss to train feature extraction') 100 | parser.add_argument("--random_view_synthesis", action='store_true', default=False, help='add random view synthesis') 101 | parser.add_argument("--rvs_refresh_rate", type=int, default=2, help='re-synthesis new views per X epochs') 102 | parser.add_argument("--rvs_trans", type=float, default=5, help='jitter range for rvs on translation') 103 | parser.add_argument("--rvs_rotation", type=float, default=1.2, help='jitter range for rvs on rotation, this is in log_10 uniform range, log(15) = 1.2') 104 | parser.add_argument("--d_max", type=float, default=1, help='rvs bounds d_max') 105 | parser.add_argument("--val_batch_size", type=int, default=1, help='batch_size for validation, higher number leads to faster speed') 106 | 107 | # legacy mesh options 108 | parser.add_argument("--mesh_only", action='store_true', help='do not optimize, reload weights and save mesh to a file') 109 | parser.add_argument("--mesh_grid_size", type=int, default=80,help='number of grid points to sample in each dimension for marching cubes') 110 | 111 | # training options 112 | parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops') 113 | parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops') 114 | parser.add_argument("--epochs", type=int, default=2000,help='number of epochs to train') 115 | parser.add_argument("--poselossonly", action='store_true', help='eval model') 116 | parser.add_argument("--tripletloss", action='store_true', help='use triplet loss at training featurenet, this is to prevent catastophic failing') 117 | parser.add_argument("--triplet_margin", type=float,default=1., help='triplet loss margin hyperparameter') 118 | parser.add_argument("--render_feature_only", action='store_true', default=False, help='render features and save to a path') 119 | 120 | # dataset options 121 | parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / 7Scenes') 122 | parser.add_argument("--testskip", type=int, default=1, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 123 | 124 | ## legacy blender flags 125 | parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)') 126 | parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800') 127 | 128 | ## llff flags 129 | parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images') 130 | parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)') 131 | parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth') 132 | parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes') 133 | parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test 
set, paper uses 8') 134 | parser.add_argument("--no_bd_factor", action='store_true', default=False, help='do not use bd factor') 135 | 136 | # logging/saving options 137 | parser.add_argument("--i_print", type=int, default=1, help='frequency of console printout and metric loggin') 138 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') 139 | parser.add_argument("--i_weights", type=int, default=200, help='frequency of weight ckpt saving') 140 | parser.add_argument("--i_testset", type=int, default=200, help='frequency of testset saving') 141 | parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving') 142 | 143 | return parser -------------------------------------------------------------------------------- /script/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/script/models/__init__.py -------------------------------------------------------------------------------- /script/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pdb 4 | 5 | class ColorLoss(nn.Module): 6 | def __init__(self, coef=1): 7 | super().__init__() 8 | self.coef = coef 9 | self.loss = nn.MSELoss(reduction='mean') 10 | 11 | def forward(self, inputs, targets): 12 | loss = self.loss(inputs['rgb_coarse'], targets) 13 | if 'rgb_fine' in inputs: 14 | loss += self.loss(inputs['rgb_fine'], targets) 15 | 16 | return self.coef * loss 17 | 18 | 19 | class NerfWLoss(nn.Module): 20 | """ 21 | Equation 13 in the NeRF-W paper. 22 | Name abbreviations: 23 | c_l: coarse color loss 24 | f_l: fine color loss (1st term in equation 13) 25 | b_l: beta loss (2nd term in equation 13) 26 | s_l: sigma loss (3rd term in equation 13) 27 | targets # [N, 3] 28 | inputs['rgb_coarse'] # [N, 3] 29 | inputs['rgb_fine'] # [N, 3] 30 | inputs['beta'] # [N] 31 | inputs['transient_sigmas'] # [N, 2*N_Samples] 32 | :return: 33 | """ 34 | def __init__(self, coef=1, lambda_u=0.01): 35 | """ 36 | lambda_u: in equation 13 37 | """ 38 | super().__init__() 39 | self.coef = coef 40 | self.lambda_u = lambda_u 41 | 42 | def forward(self, inputs, targets, use_hier_rgbs=False, rgb_h=None, rgb_w=None): 43 | 44 | ret = {} 45 | ret['c_l'] = 0.5 * ((inputs['rgb_coarse']-targets)**2).mean() 46 | if 'rgb_fine' in inputs: 47 | if 'beta' not in inputs: # no transient head, normal MSE loss 48 | ret['f_l'] = 0.5 * ((inputs['rgb_fine']-targets)**2).mean() 49 | else: 50 | ret['f_l'] = ((inputs['rgb_fine']-targets)**2/(2*inputs['beta'].unsqueeze(1)**2)).mean() 51 | ret['b_l'] = 3 + torch.log(inputs['beta']).mean() # +3 to make it positive 52 | ret['s_l'] = self.lambda_u * inputs['transient_sigmas'].mean() 53 | 54 | for k, v in ret.items(): 55 | ret[k] = self.coef * v 56 | 57 | return ret 58 | 59 | loss_dict = {'color': ColorLoss, 60 | 'nerfw': NerfWLoss} -------------------------------------------------------------------------------- /script/models/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim as dssim 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value = (image_pred-image_gt)**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 
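# --- [editor's note] quick numeric check relating mse() above to psnr() below
# (illustrative values only): an MSE of 0.01 on [0, 1] images is ~20 dB PSNR.
import torch
pred = torch.full((1, 3, 4, 4), 0.6)
gt = torch.full((1, 3, 4, 4), 0.5)
err = mse(pred, gt)                  # ~0.01, mean squared error
print(-10 * torch.log10(err))        # ~tensor(20.), same value psnr(pred, gt) returns
# --- end of editor's note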
| 12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 13 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 14 | 15 | def ssim(image_pred, image_gt, reduction='mean'): 16 | """ 17 | image_pred and image_gt: (1, 3, H, W) 18 | """ 19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1] 20 | return 1-2*dssim_ # in [-1, 1] -------------------------------------------------------------------------------- /script/models/nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | torch.autograd.set_detect_anomaly(True) 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | # Misc 9 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 10 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 11 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 12 | 13 | def batchify(fn, chunk): 14 | """Constructs a version of 'fn' that applies to smaller batches. 15 | """ 16 | if chunk is None: 17 | return fn 18 | def ret(inputs): 19 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 20 | return ret 21 | 22 | 23 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 24 | """Prepares inputs and applies network 'fn'. 25 | """ 26 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 27 | embedded = embed_fn(inputs_flat) 28 | 29 | if viewdirs is not None: 30 | input_dirs = viewdirs[:,None].expand(inputs.shape) 31 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 32 | embedded_dirs = embeddirs_fn(input_dirs_flat) 33 | embedded = torch.cat([embedded, embedded_dirs], -1) 34 | 35 | outputs_flat = batchify(fn, netchunk)(embedded) 36 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 37 | return outputs 38 | 39 | def run_network_DNeRF(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64, epoch=None, no_DNeRF_viewdir=False): 40 | """Prepares inputs and applies network 'fn'. 
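# --- [editor's note] minimal sketch of how batchify() above bounds peak memory:
# it runs `fn` on row-chunks of the flattened input and concatenates the results,
# which is why the --netchunk option advises decreasing the chunk size when
# running out of memory. `fn` below is a stand-in for the NeRF MLP, purely illustrative.
import torch
fn = lambda pts: pts * 2.0                       # pretend network
embedded = torch.randn(100_000, 63)              # flattened, positionally-encoded points
out_full = fn(embedded)                          # single large forward pass
out_chunked = batchify(fn, 1024 * 64)(embedded)  # same result, 65536 rows at a time
assert torch.allclose(out_full, out_chunked)
# --- end of editor's note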
41 | """ 42 | if epoch<0 or epoch==None: 43 | print("Error: run_network_DNeRF(): Invalid epoch") 44 | sys.exit() 45 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 46 | embedded = embed_fn(inputs_flat, epoch) 47 | # add weighted function here 48 | if viewdirs is not None: 49 | input_dirs = viewdirs[:,None].expand(inputs.shape) 50 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 51 | 52 | if no_DNeRF_viewdir: 53 | embedded_dirs = embeddirs_fn(input_dirs_flat) 54 | else: 55 | embedded_dirs = embeddirs_fn(input_dirs_flat, epoch) 56 | embedded = torch.cat([embedded, embedded_dirs], -1) 57 | 58 | outputs_flat = batchify(fn, netchunk)(embedded) 59 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 60 | return outputs 61 | 62 | 63 | # Positional encoding (section 5.1) 64 | class Embedder: 65 | def __init__(self, **kwargs): 66 | self.kwargs = kwargs 67 | self.N_freqs = 0 68 | self.N = -1 # epoch to max frequency, for Nerfie embedding only 69 | self.create_embedding_fn() 70 | 71 | def create_embedding_fn(self): 72 | embed_fns = [] 73 | d = self.kwargs['input_dims'] 74 | out_dim = 0 75 | if self.kwargs['include_input']: 76 | embed_fns.append(lambda x : x) 77 | out_dim += d 78 | 79 | max_freq = self.kwargs['max_freq_log2'] 80 | self.N_freqs = self.kwargs['num_freqs'] 81 | 82 | if self.kwargs['log_sampling']: 83 | freq_bands = 2.**torch.linspace(0., max_freq, steps=self.N_freqs) # tensor([ 1., 2., 4., 8., 16., 32., 64., 128., 256., 512.]) 84 | else: 85 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=self.N_freqs) 86 | 87 | for freq in freq_bands: # 10 iters for 3D location, 4 iters for 2D direction 88 | for p_fn in self.kwargs['periodic_fns']: 89 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 90 | out_dim += d 91 | self.embed_fns = embed_fns 92 | self.out_dim = out_dim 93 | 94 | def embed(self, inputs): 95 | # inputs [65536, 3] 96 | if self.kwargs['max_freq_log2'] != 0: 97 | ret = torch.cat([fn(inputs) for fn in self.embed_fns], -1) # cos, sin embedding # ret.shape [65536, 63] 98 | else: 99 | ret = inputs 100 | return ret 101 | 102 | def get_embed_weight(self, epoch, num_freqs, N): 103 | ''' Nerfie Paper Eq.(8) ''' 104 | alpha = num_freqs * epoch / N 105 | W_j = [] 106 | for i in range(num_freqs): 107 | tmp = torch.clamp(torch.Tensor([alpha - i]), 0, 1) 108 | tmp2 = (1 - torch.cos(torch.Tensor([np.pi]) * tmp)) / 2 109 | W_j.append(tmp2) 110 | return W_j 111 | 112 | def embed_DNeRF(self, inputs, epoch): 113 | ''' Nerfie paper section 3.5 Coarse-to-Fine Deformation Regularization ''' 114 | # get weight for each frequency band j 115 | W_j = self.get_embed_weight(epoch, self.N_freqs, self.N) # W_j: [W_0, W_1, W_2, ..., W_{m-1}] 116 | 117 | # Fourier embedding 118 | out = [] 119 | for fn in self.embed_fns: # 17, embed_fns:[input, cos, sin, cos, sin, ..., cos, sin] 120 | out.append(fn(inputs)) 121 | 122 | # apply weighted positional encoding, only to cos&sins 123 | for i in range(len(W_j)): 124 | out[2*i+1] = W_j[i] * out[2*i+1] 125 | out[2*i+2] = W_j[i] * out[2*i+2] 126 | ret = torch.cat(out, -1) 127 | return ret 128 | 129 | def update_N(self, N): 130 | self.N=N 131 | 132 | 133 | def get_embedder(multires, i=0, reduce_mode=-1, epochToMaxFreq=-1): 134 | if i == -1: 135 | return nn.Identity(), 3 136 | 137 | if reduce_mode == 0: 138 | # reduce embedding 139 | embed_kwargs = { 140 | 'include_input' : True, 141 | 'input_dims' : 3, 142 | 'max_freq_log2' : (multires-1)//2, 143 | 'num_freqs' : 
multires//2, 144 | 'log_sampling' : True, 145 | 'periodic_fns' : [torch.sin, torch.cos], 146 | } 147 | elif reduce_mode == 1: 148 | # remove embedding 149 | embed_kwargs = { 150 | 'include_input' : True, 151 | 'input_dims' : 3, 152 | 'max_freq_log2' : 0, 153 | 'num_freqs' : 0, 154 | 'log_sampling' : True, 155 | 'periodic_fns' : [torch.sin, torch.cos], 156 | } 157 | elif reduce_mode == 2: 158 | # DNeRF embedding 159 | embed_kwargs = { 160 | 'include_input' : True, 161 | 'input_dims' : 3, 162 | 'max_freq_log2' : multires-1, 163 | 'num_freqs' : multires, 164 | 'log_sampling' : True, 165 | 'periodic_fns' : [torch.sin, torch.cos], 166 | } 167 | else: 168 | # paper default 169 | embed_kwargs = { 170 | 'include_input' : True, 171 | 'input_dims' : 3, 172 | 'max_freq_log2' : multires-1, 173 | 'num_freqs' : multires, 174 | 'log_sampling' : True, 175 | 'periodic_fns' : [torch.sin, torch.cos], 176 | } 177 | 178 | embedder_obj = Embedder(**embed_kwargs) 179 | if reduce_mode == 2: 180 | embedder_obj.update_N(epochToMaxFreq) 181 | embed = lambda x, epoch, eo=embedder_obj: eo.embed_DNeRF(x, epoch) 182 | else: 183 | embed = lambda x, eo=embedder_obj : eo.embed(x) 184 | return embed, embedder_obj.out_dim, embedder_obj# 63 for pos, 27 for view dir 185 | 186 | # Model 187 | class NeRF(nn.Module): 188 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): 189 | """ 190 | """ 191 | super(NeRF, self).__init__() 192 | self.D = D 193 | self.W = W 194 | self.input_ch = input_ch 195 | self.input_ch_views = input_ch_views 196 | self.skips = skips 197 | self.use_viewdirs = use_viewdirs 198 | 199 | self.pts_linears = nn.ModuleList( 200 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) 201 | 202 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 203 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) 204 | 205 | ### Implementation according to the NeRF paper 206 | # self.views_linears = nn.ModuleList( 207 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 208 | 209 | if use_viewdirs: 210 | self.feature_linear = nn.Linear(W, W) 211 | self.alpha_linear = nn.Linear(W, 1) 212 | self.rgb_linear = nn.Linear(W//2, 3) 213 | else: 214 | self.output_linear = nn.Linear(W, output_ch) 215 | 216 | def forward(self, x): 217 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 218 | h = input_pts 219 | for i, l in enumerate(self.pts_linears): 220 | h = self.pts_linears[i](h) 221 | h = F.relu(h) 222 | if i in self.skips: 223 | h = torch.cat([input_pts, h], -1) 224 | 225 | if self.use_viewdirs: 226 | alpha = self.alpha_linear(h) 227 | feature = self.feature_linear(h) 228 | h = torch.cat([feature, input_views], -1) 229 | 230 | for i, l in enumerate(self.views_linears): 231 | h = self.views_linears[i](h) 232 | h = F.relu(h) 233 | 234 | rgb = self.rgb_linear(h) 235 | outputs = torch.cat([rgb, alpha], -1) 236 | else: 237 | outputs = self.output_linear(h) 238 | 239 | return outputs 240 | 241 | def load_weights_from_keras(self, weights): 242 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 243 | 244 | # Load pts_linears 245 | for i in range(self.D): 246 | idx_pts_linears = 2 * i 247 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 248 | self.pts_linears[i].bias.data = 
torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) 249 | 250 | # Load feature_linear 251 | idx_feature_linear = 2 * self.D 252 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 253 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) 254 | 255 | # Load views_linears 256 | idx_views_linears = 2 * self.D + 2 257 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 258 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) 259 | 260 | # Load rgb_linear 261 | idx_rbg_linear = 2 * self.D + 4 262 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 263 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) 264 | 265 | # Load alpha_linear 266 | idx_alpha_linear = 2 * self.D + 6 267 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 268 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) 269 | 270 | def create_nerf(args): 271 | """Instantiate NeRF's MLP model. 272 | """ 273 | if args.reduce_embedding==2: # use DNeRF embedding 274 | embed_fn, input_ch, embedder_obj = get_embedder(args.multires, args.i_embed, args.reduce_embedding, args.epochToMaxFreq) # input_ch.shape=63 275 | else: 276 | embed_fn, input_ch, _ = get_embedder(args.multires, args.i_embed, args.reduce_embedding) # input_ch.shape=63 277 | 278 | input_ch_views = 0 279 | embeddirs_fn = None 280 | if args.use_viewdirs: 281 | if args.reduce_embedding==2: # use DNeRF embedding 282 | if args.no_DNeRF_viewdir: # no DNeRF embedding for viewdir 283 | embeddirs_fn, input_ch_views, _ = get_embedder(args.multires_views, args.i_embed) 284 | else: 285 | embeddirs_fn, input_ch_views, embedddirs_obj = get_embedder(args.multires_views, args.i_embed, args.reduce_embedding, args.epochToMaxFreq) 286 | else: 287 | embeddirs_fn, input_ch_views, _ = get_embedder(args.multires_views, args.i_embed, args.reduce_embedding) # input_ch_views.shape=27 288 | output_ch = 5 if args.N_importance > 0 else 4 289 | skips = [4] 290 | model = NeRF(D=args.netdepth, W=args.netwidth, input_ch=input_ch, output_ch=output_ch, skips=skips, 291 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs) 292 | device = torch.device("cuda") 293 | if args.multi_gpu: 294 | model = torch.nn.DataParallel(model).to(device) 295 | else: 296 | model = model.to(device) 297 | grad_vars = list(model.parameters()) 298 | 299 | model_fine = None 300 | if args.N_importance > 0: 301 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, input_ch=input_ch, output_ch=output_ch, skips=skips, 302 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs) 303 | if args.multi_gpu: 304 | model_fine = torch.nn.DataParallel(model_fine).to(device) 305 | else: 306 | model_fine = model_fine.to(device) 307 | grad_vars += list(model_fine.parameters()) 308 | 309 | if args.reduce_embedding==2: # use DNeRF embedding 310 | network_query_fn = lambda inputs, viewdirs, network_fn, epoch: run_network_DNeRF(inputs, viewdirs, network_fn, 311 | embed_fn=embed_fn, 312 | embeddirs_fn=embeddirs_fn, 313 | netchunk=args.netchunk, 314 | epoch=epoch, no_DNeRF_viewdir=args.no_DNeRF_viewdir) 315 | else: 316 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 317 | embed_fn=embed_fn, 318 | embeddirs_fn=embeddirs_fn, 319 | 
netchunk=args.netchunk) 320 | 321 | # Create optimizer 322 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 323 | 324 | start = 0 325 | basedir = args.basedir 326 | expname = args.expname 327 | 328 | # Load checkpoints 329 | if args.ft_path is not None and args.ft_path!='None': 330 | ckpts = [args.ft_path] 331 | else: 332 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] 333 | 334 | print('Found ckpts', ckpts) 335 | if len(ckpts) > 0 and not args.no_reload: 336 | ckpt_path = ckpts[-1] 337 | print('Reloading from', ckpt_path) 338 | ckpt = torch.load(ckpt_path) 339 | 340 | start = ckpt['global_step'] 341 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 342 | # Load model 343 | model.load_state_dict(ckpt['network_fn_state_dict']) 344 | 345 | if model_fine is not None: 346 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 347 | 348 | ########################## 349 | 350 | render_kwargs_train = { 351 | 'network_query_fn' : network_query_fn, 352 | 'perturb' : args.perturb, 353 | 'N_importance' : args.N_importance, 354 | 'network_fine' : model_fine, 355 | 'N_samples' : args.N_samples, 356 | 'network_fn' : model, 357 | 'use_viewdirs' : args.use_viewdirs, 358 | 'white_bkgd' : args.white_bkgd, 359 | 'raw_noise_std' : args.raw_noise_std, 360 | } 361 | 362 | # NDC only good for LLFF-style forward facing data 363 | if args.dataset_type != 'llff' or args.no_ndc: 364 | print('Not ndc!') 365 | render_kwargs_train['ndc'] = False 366 | render_kwargs_train['lindisp'] = args.lindisp 367 | 368 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 369 | render_kwargs_test['perturb'] = False 370 | render_kwargs_test['raw_noise_std'] = 0. 
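# --- [editor's note] sketch of a checkpoint that the reload logic above expects;
# the keys mirror the load code, but the file name format is illustrative only
# (the reload scan simply picks files whose name contains 'tar').
path = os.path.join(basedir, expname, '{:06d}.tar'.format(start))
torch.save({
    'global_step': start,
    'network_fn_state_dict': model.state_dict(),
    'network_fine_state_dict': model_fine.state_dict() if model_fine is not None else None,
    'optimizer_state_dict': optimizer.state_dict(),
}, path)
# --- end of editor's note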
371 | 372 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer -------------------------------------------------------------------------------- /script/models/options.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | def config_parser(): 3 | parser = configargparse.ArgumentParser() 4 | parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1") 5 | parser.add_argument("--device", type=int, default=-1, help='CUDA_VISIBLE_DEVICES') 6 | parser.add_argument("--multi_gpu", action='store_true', help='use multiple gpu on the server') 7 | parser.add_argument('--config', is_config_file=True, help='config file path') 8 | parser.add_argument("--expname", type=str, help='experiment name') 9 | parser.add_argument("--basedir", type=str, default='../logs', help='where to store ckpts and logs') 10 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory') 11 | 12 | # 7Scenes 13 | parser.add_argument("--trainskip", type=int, default=1, help='will load 1/N images from train sets, useful for large datasets like 7 Scenes') 14 | parser.add_argument("--df", type=float, default=1., help='image downscale factor') 15 | parser.add_argument("--reduce_embedding", type=int, default=-1, help='fourier embedding mode: -1: paper default, \ 16 | 0: reduce by half, 1: remove embedding, 2: DNeRF embedding') 17 | parser.add_argument("--epochToMaxFreq", type=int, default=-1, help='DNeRF embedding mode: (based on Nerfie paper): \ 18 | hyper-parameter for when α should reach the maximum number of frequencies m') 19 | parser.add_argument("--render_pose_only", action='store_true', help='render a spiral video for 7 Scene') 20 | parser.add_argument("--save_pose_avg_stats", action='store_true', help='save a pose avg stats to unify NeRF, posenet, direct-pn training') 21 | parser.add_argument("--load_pose_avg_stats", action='store_true', help='load precomputed pose avg stats to unify NeRF, posenet, nerf tracking training') 22 | parser.add_argument("--train_local_nerf", type=int, default=-1, help='train local NeRF with ith training sequence only, ie. 
Stairs can pick 0~3') 23 | parser.add_argument("--render_video_train", action='store_true', help='render train set NeRF and save as video, make sure render_test is True') 24 | parser.add_argument("--render_video_test", action='store_true', help='render val set NeRF and save as video, make sure render_test is True') 25 | parser.add_argument("--frustum_overlap_th", type=float, help='frustsum overlap threshold') 26 | parser.add_argument("--no_DNeRF_viewdir", action='store_true', default=False, help='will not use DNeRF in viewdir encoding') 27 | parser.add_argument("--load_unique_view_stats", action='store_true', help='load unique views frame index') 28 | 29 | # NeRF training options 30 | parser.add_argument("--netdepth", type=int, default=8, help='layers in network') 31 | parser.add_argument("--netwidth", type=int, default=128, help='channels per layer') 32 | parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network') 33 | parser.add_argument("--netwidth_fine", type=int, default=128, help='channels per layer in fine network') 34 | parser.add_argument("--N_rand", type=int, default=1536, help='batch size (number of random rays per gradient step)') 35 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') 36 | parser.add_argument("--lrate_decay", type=float, default=250, help='exponential learning rate decay (in 1000 steps)') 37 | parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory') 38 | parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory') 39 | parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time') 40 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') 41 | parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network') 42 | parser.add_argument("--no_grad_update", action='store_true', default=False, help='do not update nerf in training') 43 | 44 | # NeRF-Hist training options 45 | parser.add_argument("--NeRFH", action='store_true', help='my implementation for NeRFH, to enable NeRF-Hist training, please make sure to add --encode_hist, otherwise it is similar to NeRFW') 46 | parser.add_argument("--N_vocab", type=int, default=1000, 47 | help='''number of vocabulary (number of images) 48 | in the dataset for nn.Embedding''') 49 | parser.add_argument("--fix_index", action='store_true', help='fix training frame index as 0') 50 | parser.add_argument("--encode_hist", default=False, action='store_true', help='encode histogram instead of frame index') 51 | parser.add_argument("--hist_bin", type=int, default=10, help='image histogram bin size') 52 | parser.add_argument("--in_channels_a", type=int, default=50, help='appearance embedding dimension, hist_bin*N_a when embedding histogram') 53 | parser.add_argument("--in_channels_t", type=int, default=20, help='transient embedding dimension, hist_bin*N_tau when embedding histogram') 54 | 55 | # NeRF rendering options 56 | parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray') 57 | parser.add_argument("--N_importance", type=int, default=64,help='number of additional fine samples per ray') 58 | parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. 
for jitter') 59 | parser.add_argument("--use_viewdirs", default=True, action='store_true', help='use full 5D input instead of 3D') 60 | parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none') 61 | parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)') 62 | parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)') 63 | parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 64 | parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path') 65 | parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path') 66 | parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 67 | 68 | # legacy mesh options 69 | parser.add_argument("--mesh_only", action='store_true', help='do not optimize, reload weights and save mesh to a file') 70 | parser.add_argument("--mesh_grid_size", type=int, default=80,help='number of grid points to sample in each dimension for marching cubes') 71 | 72 | # training options 73 | parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops') 74 | parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops') 75 | parser.add_argument("--epochs", type=int, default=600,help='number of epochs to train') 76 | 77 | # dataset options 78 | parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / 7Scenes') 79 | parser.add_argument("--testskip", type=int, default=1, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 80 | 81 | ## legacy blender flags 82 | parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)') 83 | # parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800') 84 | 85 | ## llff flags 86 | parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images') 87 | parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)') 88 | parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth') 89 | parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes') 90 | parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8') 91 | parser.add_argument("--no_bd_factor", action='store_true', default=False, help='do not use bd factor') 92 | 93 | # logging/saving options 94 | parser.add_argument("--i_print", type=int, default=1, help='frequency of console printout and metric loggin') 95 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') 96 | parser.add_argument("--i_weights", type=int, default=200, help='frequency of weight ckpt saving') 97 | parser.add_argument("--i_testset", type=int, default=200, help='frequency of testset saving') 98 | parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video 
saving') 99 | 100 | return parser -------------------------------------------------------------------------------- /script/models/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # Ray helpers 5 | def get_rays(H, W, focal, c2w): 6 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H), indexing='ij') # pytorch's meshgrid has indexing='ij' 7 | i = i.t() 8 | j = j.t() 9 | dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1) 10 | 11 | # Rotate ray directions from camera frame to the world frame 12 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 13 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 14 | rays_o = c2w[:3,-1].expand(rays_d.shape) 15 | return rays_o, rays_d # rays_o (100,100,3), rays_d (100,100,3) 16 | 17 | 18 | def get_rays_np(H, W, focal, c2w): 19 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 20 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1) 21 | # Rotate ray directions from camera frame to the world frame 22 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 23 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 24 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 25 | return rays_o, rays_d 26 | 27 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 28 | # Shift ray origins to near plane 29 | t = -(near + rays_o[...,2]) / rays_d[...,2] # t_n = −(n + o_z)/d_z move o to the ray's intersection with near plane 30 | rays_o = rays_o + t[...,None] * rays_d 31 | 32 | # Projection Formular (20) 33 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 34 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 35 | o2 = 1. + 2. * near / rays_o[...,2] 36 | # Formular (21) 37 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 38 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 39 | d2 = -2. 
* near / rays_o[...,2] 40 | 41 | rays_o = torch.stack([o0,o1,o2], -1) # o' 42 | rays_d = torch.stack([d0,d1,d2], -1) # d' 43 | 44 | return rays_o, rays_d -------------------------------------------------------------------------------- /script/run_nerf.py: -------------------------------------------------------------------------------- 1 | import utils.set_sys_path 2 | import os, sys 3 | import numpy as np 4 | import imageio 5 | import json 6 | import random 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | # from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm, trange 12 | 13 | from models.ray_utils import * 14 | from models.nerfw import * # NeRF-w and NeRF-hist 15 | from models.options import config_parser 16 | from models.rendering import * 17 | from dataset_loaders.load_7Scenes import load_7Scenes_dataloader_NeRF 18 | from dataset_loaders.load_Cambridge import load_Cambridge_dataloader_NeRF 19 | 20 | # losses 21 | from models.losses import loss_dict 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | np.random.seed(0) 25 | torch.manual_seed(0) 26 | import random 27 | random.seed(0) 28 | 29 | parser = config_parser() 30 | args = parser.parse_args() 31 | 32 | def train_on_epoch_nerfw(args, train_dl, H, W, focal, N_rand, optimizer, loss_func, global_step, render_kwargs_train): 33 | for batch_idx, (target, pose, img_idx) in enumerate(train_dl): 34 | target = target[0].permute(1,2,0).to(device) 35 | pose = pose.reshape(3,4).to(device) # reshape to 3x4 rot matrix 36 | img_idx = img_idx.to(device) 37 | 38 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 39 | if N_rand is not None: 40 | rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) 41 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W), indexing='ij'), -1) # (H, W, 2) 42 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2) 43 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 44 | select_coords = coords[select_inds].long() # (N_rand, 2) 45 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 46 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 47 | batch_rays = torch.stack([rays_o, rays_d], 0) 48 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 49 | 50 | # ##### Core optimization loop ##### 51 | rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays, retraw=True, img_idx=img_idx, **render_kwargs_train) 52 | optimizer.zero_grad() 53 | 54 | # compute loss 55 | results = {} 56 | results['rgb_fine'] = rgb 57 | results['rgb_coarse'] = extras['rgb0'] 58 | results['beta'] = extras['beta'] 59 | results['transient_sigmas'] = extras['transient_sigmas'] 60 | 61 | loss_d = loss_func(results, target_s) 62 | loss = sum(l for l in loss_d.values()) 63 | 64 | with torch.no_grad(): 65 | img_loss = img2mse(rgb, target_s) 66 | psnr = mse2psnr(img_loss) 67 | loss.backward() 68 | optimizer.step() 69 | 70 | # NOTE: IMPORTANT! 
71 | ### update learning rate ### 72 | decay_rate = 0.1 73 | decay_steps = args.lrate_decay * 1000 74 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 75 | for param_group in optimizer.param_groups: 76 | param_group['lr'] = new_lrate 77 | ################################ 78 | 79 | torch.set_default_tensor_type('torch.FloatTensor') 80 | return loss, psnr 81 | 82 | def train_nerf(args, train_dl, val_dl, hwf, i_split, near, far, render_poses=None, render_img=None): 83 | 84 | i_train, i_val, i_test = i_split 85 | # Cast intrinsics to right types 86 | H, W, focal = hwf 87 | H, W = int(H), int(W) 88 | hwf = [H, W, focal] 89 | 90 | # Create log dir and copy the config file 91 | basedir = args.basedir 92 | expname = args.expname 93 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 94 | f = os.path.join(basedir, expname, 'args.txt') 95 | with open(f, 'w') as file: 96 | for arg in sorted(vars(args)): 97 | attr = getattr(args, arg) 98 | file.write('{} = {}\n'.format(arg, attr)) 99 | if args.config is not None: 100 | f = os.path.join(basedir, expname, 'config.txt') 101 | with open(f, 'w') as file: 102 | file.write(open(args.config, 'r').read()) 103 | 104 | # Create nerf model 105 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 106 | global_step = start 107 | 108 | bds_dict = { 109 | 'near' : near, 110 | 'far' : far, 111 | } 112 | render_kwargs_train.update(bds_dict) 113 | render_kwargs_test.update(bds_dict) 114 | if args.reduce_embedding==2: 115 | render_kwargs_train['i_epoch'] = -1 116 | render_kwargs_test['i_epoch'] = -1 117 | 118 | if args.render_test: 119 | print('TRAIN views are', i_train) 120 | print('TEST views are', i_test) 121 | print('VAL views are', i_val) 122 | if args.reduce_embedding==2: 123 | render_kwargs_test['i_epoch'] = global_step 124 | render_test(args, train_dl, val_dl, hwf, start, render_kwargs_test) 125 | return 126 | 127 | # Prepare raybatch tensor if batching random rays 128 | N_rand = args.N_rand 129 | # use_batching = not args.no_batching 130 | 131 | N_epoch = args.epochs + 1 # epoch 132 | print('Begin') 133 | print('TRAIN views are', i_train) 134 | print('TEST views are', i_test) 135 | print('VAL views are', i_val) 136 | 137 | 138 | # loss function 139 | loss_func = loss_dict['nerfw'](coef=1) 140 | 141 | for i in trange(start, N_epoch): 142 | time0 = time.time() 143 | if args.reduce_embedding==2: 144 | render_kwargs_train['i_epoch'] = i 145 | loss, psnr = train_on_epoch_nerfw(args, train_dl, H, W, focal, N_rand, optimizer, loss_func, global_step, render_kwargs_train) 146 | dt = time.time()-time0 147 | ##### end ##### 148 | 149 | # Rest is logging 150 | if i%args.i_weights==0 and i!=0: 151 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 152 | if args.N_importance > 0: # have fine sample network 153 | torch.save({ 154 | 'global_step': global_step, 155 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 156 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 157 | 'embedding_a_state_dict': render_kwargs_train['embedding_a'].state_dict(), 158 | 'embedding_t_state_dict': render_kwargs_train['embedding_t'].state_dict(), 159 | 'optimizer_state_dict': optimizer.state_dict(), 160 | }, path) 161 | else: 162 | torch.save({ 163 | 'global_step': global_step, 164 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 165 | 'optimizer_state_dict': optimizer.state_dict(), 166 | }, path) 167 | print('Saved checkpoints at', path) 
168 | 169 | if i%args.i_testset==0 and i > 0: # run thru all validation set 170 | 171 | # clean GPU memory before testing, try to avoid OOM 172 | torch.cuda.empty_cache() 173 | 174 | if args.reduce_embedding==2: 175 | render_kwargs_test['i_epoch'] = i 176 | trainsavedir = os.path.join(basedir, expname, 'trainset_{:06d}'.format(i)) 177 | os.makedirs(trainsavedir, exist_ok=True) 178 | images_train = [] 179 | poses_train = [] 180 | index_train = [] 181 | j_skip = 10 # save holdout view render result Trainset/j_skip 182 | # randomly choose some holdout views from training set 183 | for batch_idx, (img, pose, img_idx) in enumerate(train_dl): 184 | if batch_idx % j_skip != 0: 185 | continue 186 | img_val = img.permute(0,2,3,1) # (1,H,W,3) 187 | pose_val = torch.zeros(1,4,4) 188 | pose_val[0,:3,:4] = pose.reshape(3,4)[:3,:4] # (1,3,4)) 189 | pose_val[0,3,3] = 1. 190 | images_train.append(img_val) 191 | poses_train.append(pose_val) 192 | index_train.append(img_idx) 193 | images_train = torch.cat(images_train, dim=0).numpy() 194 | poses_train = torch.cat(poses_train, dim=0).to(device) 195 | index_train = torch.cat(index_train, dim=0).to(device) 196 | print('train poses shape', poses_train.shape) 197 | 198 | with torch.no_grad(): 199 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 200 | render_path(args, poses_train, hwf, args.chunk, render_kwargs_test, gt_imgs=images_train, savedir=trainsavedir, img_ids=index_train) 201 | torch.set_default_tensor_type('torch.FloatTensor') 202 | print('Saved train set') 203 | del images_train 204 | del poses_train 205 | 206 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 207 | os.makedirs(testsavedir, exist_ok=True) 208 | images_val = [] 209 | poses_val = [] 210 | index_val = [] 211 | # views from validation set 212 | for img, pose, img_idx in val_dl: 213 | img_val = img.permute(0,2,3,1) # (1,H,W,3) 214 | pose_val = torch.zeros(1,4,4) 215 | pose_val[0,:3,:4] = pose.reshape(3,4)[:3,:4] # (1,3,4)) 216 | pose_val[0,3,3] = 1. 
217 | images_val.append(img_val) 218 | poses_val.append(pose_val) 219 | index_val.append(img_idx) 220 | 221 | images_val = torch.cat(images_val, dim=0).numpy() 222 | poses_val = torch.cat(poses_val, dim=0).to(device) 223 | index_val = torch.cat(index_val, dim=0).to(device) 224 | print('test poses shape', poses_val.shape) 225 | 226 | with torch.no_grad(): 227 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 228 | render_path(args, poses_val, hwf, args.chunk, render_kwargs_test, gt_imgs=images_val, savedir=testsavedir, img_ids=index_val) 229 | torch.set_default_tensor_type('torch.FloatTensor') 230 | print('Saved test set') 231 | 232 | # clean GPU memory after testing 233 | torch.cuda.empty_cache() 234 | del images_val 235 | del poses_val 236 | 237 | if i%args.i_print==0: 238 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 239 | 240 | global_step += 1 241 | 242 | def train(): 243 | 244 | print(parser.format_values()) 245 | 246 | # Load data 247 | if args.dataset_type == '7Scenes': 248 | 249 | train_dl, val_dl, hwf, i_split, bds, render_poses, render_img = load_7Scenes_dataloader_NeRF(args) 250 | near = bds[0] 251 | far = bds[1] 252 | 253 | print('NEAR FAR', near, far) 254 | train_nerf(args, train_dl, val_dl, hwf, i_split, near, far, render_poses, render_img) 255 | return 256 | 257 | elif args.dataset_type == 'Cambridge': 258 | 259 | train_dl, val_dl, hwf, i_split, bds, render_poses, render_img = load_Cambridge_dataloader_NeRF(args) 260 | near = bds[0] 261 | far = bds[1] 262 | 263 | print('NEAR FAR', near, far) 264 | train_nerf(args, train_dl, val_dl, hwf, i_split, near, far, render_poses, render_img) 265 | return 266 | 267 | else: 268 | print('Unknown dataset type', args.dataset_type, 'exiting') 269 | return 270 | 271 | if __name__=='__main__': 272 | 273 | train() -------------------------------------------------------------------------------- /script/train.py: -------------------------------------------------------------------------------- 1 | import utils.set_sys_path 2 | import os, sys 3 | import numpy as np 4 | import imageio 5 | import random 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from tqdm import tqdm, trange 12 | from torchsummary import summary 13 | # from torchinfo import summary 14 | import matplotlib.pyplot as plt 15 | 16 | from dm.pose_model import * 17 | from dm.direct_pose_model import * 18 | from dm.callbacks import EarlyStopping 19 | # from dm.prepare_data import prepare_data, load_dataset 20 | from dm.options import config_parser 21 | from models.rendering import render_path 22 | from models.nerfw import to8b 23 | from dataset_loaders.load_7Scenes import load_7Scenes_dataloader 24 | from dataset_loaders.load_Cambridge import load_Cambridge_dataloader 25 | from utils.utils import freeze_bn_layer 26 | from feature.direct_feature_matching import train_feature_matching 27 | # import torch.onnx 28 | 29 | parser = config_parser() 30 | args = parser.parse_args() 31 | device = torch.device('cuda:0') # this is really controlled in train.sh 32 | 33 | def render_test(args, train_dl, val_dl, hwf, start, model, device, render_kwargs_test): 34 | model.eval() 35 | 36 | # ### Eval Training set result 37 | if args.render_video_train: 38 | images_train = [] 39 | poses_train = [] 40 | # views from train set 41 | for img, pose in train_dl: 42 | predict_pose = inference_pose_regression(args, img, device, model) 43 | device_cpu = torch.device('cpu') 44 | predict_pose = 
predict_pose.to(device_cpu) # put predict pose back to cpu 45 | 46 | img_val = img.permute(0,2,3,1) # (1,240,320,3) 47 | pose_val = torch.zeros(1,4,4) 48 | pose_val[0,:3,:4] = predict_pose.reshape(3,4)[:3,:4] # (1,3,4)) 49 | pose_val[0,3,3] = 1. 50 | images_train.append(img_val) 51 | poses_train.append(pose_val) 52 | 53 | images_train = torch.cat(images_train, dim=0).numpy() 54 | poses_train = torch.cat(poses_train, dim=0) 55 | print('train poses shape', poses_train.shape) 56 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 57 | with torch.no_grad(): 58 | rgbs, disps = render_path(poses_train.to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images_train, savedir=None) 59 | torch.set_default_tensor_type('torch.FloatTensor') 60 | print('Saving trainset as video', rgbs.shape, disps.shape) 61 | moviebase = os.path.join(args.basedir, args.model_name, '{}_trainset_{:06d}_'.format(args.model_name, start)) 62 | imageio.mimwrite(moviebase + 'train_rgb.mp4', to8b(rgbs), fps=15, quality=8) 63 | imageio.mimwrite(moviebase + 'train_disp.mp4', to8b(disps / np.max(disps)), fps=15, quality=8) 64 | 65 | ### Eval Validation set result 66 | if args.render_video_test: 67 | images_val = [] 68 | poses_val = [] 69 | # views from val set 70 | for img, pose in val_dl: 71 | predict_pose = inference_pose_regression(args, img, device, model) 72 | device_cpu = torch.device('cpu') 73 | predict_pose = predict_pose.to(device_cpu) # put predict pose back to cpu 74 | 75 | img_val = img.permute(0,2,3,1) # (1,240,360,3) 76 | pose_val = torch.zeros(1,4,4) 77 | pose_val[0,:3,:4] = predict_pose.reshape(3,4)[:3,:4] # (1,3,4)) 78 | pose_val[0,3,3] = 1. 79 | images_val.append(img_val) 80 | poses_val.append(pose_val) 81 | 82 | images_val = torch.cat(images_val, dim=0).numpy() 83 | poses_val = torch.cat(poses_val, dim=0) 84 | print('test poses shape', poses_val.shape) 85 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 86 | with torch.no_grad(): 87 | rgbs, disps = render_path(poses_val.to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images_val, savedir=None) 88 | torch.set_default_tensor_type('torch.FloatTensor') 89 | print('Saving testset as video', rgbs.shape, disps.shape) 90 | moviebase = os.path.join(args.basedir, args.model_name, '{}_test_{:06d}_'.format(args.model_name, start)) 91 | imageio.mimwrite(moviebase + 'test_rgb.mp4', to8b(rgbs), fps=15, quality=8) 92 | imageio.mimwrite(moviebase + 'test_disp.mp4', to8b(disps / np.max(disps)), fps=15, quality=8) 93 | return 94 | 95 | def train(): 96 | print(parser.format_values()) 97 | # Load data 98 | if args.dataset_type == '7Scenes': 99 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_7Scenes_dataloader(args) 100 | elif args.dataset_type == 'Cambridge': 101 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_Cambridge_dataloader(args) 102 | else: 103 | print("please choose dataset_type: 7Scenes or Cambridge, exiting...") 104 | sys.exit() 105 | 106 | ### pose regression module, here requires a pretrained DFNet for Pose Estimator F 107 | assert(args.pretrain_model_path != '') # make sure to add a valid PATH using --pretrain_model_path 108 | # load pretrained DFNet model 109 | model = load_exisiting_model(args) 110 | 111 | if args.freezeBN: 112 | model = freeze_bn_layer(model) 113 | model.to(device) 114 | 115 | ### feature extraction module, here requires a pretrained DFNet for Feature Extractor G using --pretrain_featurenet_path 116 | if args.pretrain_featurenet_path == '': 117 | print('Use the same DFNet for Feature Extraction 
and Pose Regression') 118 | feat_model = load_exisiting_model(args) 119 | else: 120 | # you can optionally load different pretrained DFNet for feature extractor and pose estimator 121 | feat_model = load_exisiting_model(args, isFeatureNet=True) 122 | 123 | feat_model.eval() 124 | feat_model.to(device) 125 | 126 | # set optimizer 127 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) #weight_decay=weight_decay, **kwargs 128 | 129 | # set callbacks parameters 130 | early_stopping = EarlyStopping(args, patience=args.patience[0], verbose=False) 131 | 132 | # start training 133 | if args.dataset_type == '7Scenes': 134 | train_feature_matching(args, model, feat_model, optimizer, i_split, hwf, near, far, device, early_stopping, train_dl=train_dl, val_dl=val_dl, test_dl=test_dl) 135 | elif args.dataset_type == 'Cambridge': 136 | train_feature_matching(args, model, feat_model, optimizer, i_split, hwf, near, far, device, early_stopping, train_dl=train_dl, val_dl=val_dl, test_dl=test_dl) 137 | 138 | def eval(): 139 | print(parser.format_values()) 140 | # Load data 141 | if args.dataset_type == '7Scenes': 142 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_7Scenes_dataloader(args) 143 | elif args.dataset_type == 'Cambridge': 144 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_Cambridge_dataloader(args) 145 | else: 146 | print("please choose dataset_type: 7Scenes or Cambridge, exiting...") 147 | sys.exit() 148 | 149 | # load pretrained DFNet_dm model 150 | model = load_exisiting_model(args) 151 | if args.freezeBN: 152 | model = freeze_bn_layer(model) 153 | model.to(device) 154 | 155 | print(len(test_dl.dataset)) 156 | 157 | get_error_in_q(args, test_dl, model, len(test_dl.dataset), device, batch_size=1) 158 | 159 | if __name__ == '__main__': 160 | if args.eval: 161 | torch.manual_seed(0) 162 | random.seed(0) 163 | np.random.seed(0) 164 | eval() 165 | else: 166 | train() 167 | -------------------------------------------------------------------------------- /script/utils/set_sys_path.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') -------------------------------------------------------------------------------- /script/utils/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | helper functions to train robust feature extractors 3 | ''' 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import datetime 9 | from torchvision.utils import make_grid 10 | import matplotlib.pyplot as plt 11 | import pdb 12 | from PIL import Image 13 | from torchvision.utils import save_image 14 | from math import pi 15 | import cv2 16 | # from pykalman import KalmanFilter 17 | 18 | def freeze_bn_layer(model): 19 | ''' freeze bn layer by not require grad but still behave differently when model.train() vs. 
model.eval() ''' 20 | print("Freezing BatchNorm Layers...") 21 | for module in model.modules(): 22 | if isinstance(module, nn.BatchNorm2d): 23 | # print("this is a BN layer:", module) 24 | if hasattr(module, 'weight'): 25 | module.weight.requires_grad_(False) 26 | if hasattr(module, 'bias'): 27 | module.bias.requires_grad_(False) 28 | return model 29 | 30 | def freeze_bn_layer_train(model): 31 | ''' set batchnorm to eval() 32 | it is useful to align train and testing result 33 | ''' 34 | # model.train() 35 | # print("Freezing BatchNorm Layers...") 36 | for module in model.modules(): 37 | if isinstance(module, nn.BatchNorm2d): 38 | module.eval() 39 | return model 40 | 41 | def save_image_saliancy(tensor, path, normalize: bool = False, scale_each: bool = False,): 42 | """ 43 | Modification based on TORCHVISION.UTILS 44 | ::param: tensor (batch, channel, H, W) 45 | """ 46 | # grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=32) 47 | grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=6) 48 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 49 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 50 | fig = plt.figure() 51 | plt.imshow(ndarr[:,:,0], cmap='jet') # viridis, plasma 52 | plt.axis('off') 53 | fig.savefig(path, bbox_inches='tight',dpi=fig.dpi,pad_inches=0.0) 54 | plt.close() 55 | 56 | def save_image_saliancy_single(tensor, path, normalize: bool = False, scale_each: bool = False,): 57 | """ 58 | Modification based on TORCHVISION.UTILS, save single feature map 59 | ::param: tensor (batch, channel, H, W) 60 | """ 61 | # grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=32) 62 | grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=1) 63 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 64 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 65 | fig = plt.figure() 66 | # plt.imshow(ndarr[:,:,0], cmap='plasma') # viridis, jet 67 | plt.imshow(ndarr[:,:,0], cmap='jet') # viridis, jet 68 | plt.axis('off') 69 | fig.savefig(path, bbox_inches='tight',dpi=fig.dpi,pad_inches=0.0) 70 | plt.close() 71 | 72 | def print_feature_examples(features, path): 73 | """ 74 | print feature maps 75 | ::param: features 76 | """ 77 | kwargs = {'normalize' : True, } # 'scale_each' : True 78 | 79 | for i in range(len(features)): 80 | fn = path + '{}.png'.format(i) 81 | # save_image(features[i].permute(1,0,2,3), fn, **kwargs) 82 | save_image_saliancy(features[i].permute(1,0,2,3), fn, normalize=True) 83 | # pdb.set_trace() 84 | ### 85 | 86 | def plot_features(features, path='f', isList=True): 87 | """ 88 | print feature maps 89 | :param features: (3, [batch, H, W]) or [3, batch, H, W] 90 | :param path: save image path 91 | :param isList: wether the features is an list 92 | :return: 93 | """ 94 | kwargs = {'normalize' : True, } # 'scale_each' : True 95 | 96 | if isList: 97 | dim = features[0].dim() 98 | else: 99 | dim = features.dim() 100 | assert(dim==3 or dim==4) 101 | 102 | if dim==4 and isList: 103 | print_feature_examples(features, path) 104 | elif dim==4 and (isList==False): 105 | fn = path 106 | lvl, b, H, W = features.shape 107 | for i in range(features.shape[0]): 108 | fn = path + '{}.png'.format(i) 109 | save_image_saliancy(features[i][None,...].permute(1,0,2,3).cpu(), fn, normalize=True) 110 | 111 | # # concat everything 112 | # features = 
features.reshape([-1, H, W]) 113 | # # save_image(features[None,...].permute(1,0,2,3).cpu(), fn, **kwargs) 114 | # save_image_saliancy(features[None,...].permute(1,0,2,3).cpu(), fn, normalize=True) 115 | 116 | elif dim==3 and isList: # print all images in the list 117 | for i in range(len(features)): 118 | fn = path + '{}.png'.format(i) 119 | # save_image(features[i][None,...].permute(1,0,2,3).cpu(), fn, **kwargs) 120 | save_image_saliancy(features[i][None,...].permute(1,0,2,3).cpu(), fn, normalize=True) 121 | elif dim==3 and (isList==False): 122 | fn = path 123 | save_image_saliancy(features[None,...].permute(1,0,2,3).cpu(), fn, normalize=True) 124 | 125 | def sample_homography_np( 126 | shape, shift=0, perspective=True, scaling=True, rotation=True, translation=True, 127 | n_scales=5, n_angles=25, scaling_amplitude=0.1, perspective_amplitude_x=0.1, 128 | perspective_amplitude_y=0.1, patch_ratio=0.5, max_angle=pi/2, 129 | allow_artifacts=False, translation_overflow=0.): 130 | """Sample a random valid homography. 131 | 132 | Computes the homography transformation between a random patch in the original image 133 | and a warped projection with the same image size. 134 | As in `tf.contrib.image.transform`, it maps the output point (warped patch) to a 135 | transformed input point (original patch). 136 | The original patch, which is initialized with a simple half-size centered crop, is 137 | iteratively projected, scaled, rotated and translated. 138 | 139 | Arguments: 140 | shape: An array-like of length 2 giving the height and width of the original image. 141 | perspective: A boolean that enables the perspective and affine transformations. 142 | scaling: A boolean that enables the random scaling of the patch. 143 | rotation: A boolean that enables the random rotation of the patch. 144 | translation: A boolean that enables the random translation of the patch. 145 | n_scales: The number of tentative scales that are sampled when scaling. 146 | n_angles: The number of tentative angles that are sampled when rotating. 147 | scaling_amplitude: Controls the amount of scale. 148 | perspective_amplitude_x: Controls the perspective effect in x direction. 149 | perspective_amplitude_y: Controls the perspective effect in y direction. 150 | patch_ratio: Controls the size of the patches used to create the homography. (like crop size) 151 | max_angle: Maximum angle used in rotations. 152 | allow_artifacts: A boolean that enables artifacts when applying the homography. 153 | translation_overflow: Amount of border artifacts caused by translation. 154 | 155 | Returns: 156 | A 3x3 numpy homography matrix that maps points in the warped image back to points in the original image.
157 | """ 158 | 159 | # print("debugging") 160 | 161 | 162 | # Corners of the output image 163 | pts1 = np.stack([[0., 0.], [0., 1.], [1., 1.], [1., 0.]], axis=0) 164 | # Corners of the input patch 165 | margin = (1 - patch_ratio) / 2 166 | pts2 = margin + np.array([[0, 0], [0, patch_ratio], 167 | [patch_ratio, patch_ratio], [patch_ratio, 0]]) 168 | 169 | from numpy.random import normal 170 | from numpy.random import uniform 171 | from scipy.stats import truncnorm 172 | 173 | # Random perspective and affine perturbations 174 | # lower, upper = 0, 2 175 | std_trunc = 2 176 | # pdb.set_trace() 177 | if perspective: 178 | if not allow_artifacts: 179 | perspective_amplitude_x = min(perspective_amplitude_x, margin) 180 | perspective_amplitude_y = min(perspective_amplitude_y, margin) 181 | perspective_displacement = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_y/2).rvs(1) 182 | h_displacement_left = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1) 183 | h_displacement_right = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1) 184 | pts2 += np.array([[h_displacement_left, perspective_displacement], 185 | [h_displacement_left, -perspective_displacement], 186 | [h_displacement_right, perspective_displacement], 187 | [h_displacement_right, -perspective_displacement]]).squeeze() 188 | 189 | # Random scaling 190 | # sample several scales, check collision with borders, randomly pick a valid one 191 | if scaling: 192 | scales = truncnorm(-1*std_trunc, std_trunc, loc=1, scale=scaling_amplitude/2).rvs(n_scales) 193 | scales = np.concatenate((np.array([1]), scales), axis=0) 194 | 195 | center = np.mean(pts2, axis=0, keepdims=True) 196 | scaled = (pts2 - center)[np.newaxis, :, :] * scales[:, np.newaxis, np.newaxis] + center 197 | if allow_artifacts: 198 | valid = np.arange(n_scales) # all scales are valid except scale=1 199 | else: 200 | valid = (scaled >= 0.) * (scaled < 1.) 201 | valid = valid.prod(axis=1).prod(axis=1) 202 | valid = np.where(valid)[0] 203 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int) 204 | pts2 = scaled[idx,:,:] 205 | 206 | # Random translation 207 | if translation: 208 | # pdb.set_trace() 209 | t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0) 210 | if allow_artifacts: 211 | t_min += translation_overflow 212 | t_max += translation_overflow 213 | pts2 += np.array([uniform(-t_min[0], t_max[0],1), uniform(-t_min[1], t_max[1], 1)]).T 214 | 215 | # Random rotation 216 | # sample several rotations, check collision with borders, randomly pick a valid one 217 | if rotation: 218 | angles = np.linspace(-max_angle, max_angle, num=n_angles) 219 | angles = np.concatenate((angles, np.array([0.])), axis=0) # in case no rotation is valid 220 | center = np.mean(pts2, axis=0, keepdims=True) 221 | rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), 222 | np.cos(angles)], axis=1), [-1, 2, 2]) 223 | rotated = np.matmul( (pts2 - center)[np.newaxis,:,:], rot_mat) + center 224 | if allow_artifacts: 225 | valid = np.arange(n_angles) # all scales are valid except scale=1 226 | else: # find multiple valid option and choose the valid one 227 | valid = (rotated >= 0.) * (rotated < 1.) 
228 | valid = valid.prod(axis=1).prod(axis=1) 229 | valid = np.where(valid)[0] 230 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int) 231 | pts2 = rotated[idx,:,:] 232 | 233 | # Rescale to actual size 234 | shape = shape[::-1] # different convention [y, x] 235 | pts1 *= shape[np.newaxis,:] 236 | pts2 *= shape[np.newaxis,:] 237 | 238 | homography = cv2.getPerspectiveTransform(np.float32(pts1+shift), np.float32(pts2+shift)) 239 | return homography 240 | 241 | def warp_points(points, homographies, device='cpu'): 242 | """ 243 | Warp a list of points with the given homography. 244 | 245 | Arguments: 246 | points: list of N points, shape (N, 2) with (x, y) coordinates. 247 | homographies: batched or not (shapes (B, 3, 3) and (3, 3) respectively). 248 | 249 | Returns: a Tensor of shape (N, 2) or (B, N, 2) (depending on whether the homography 250 | is batched) containing the new coordinates of the warped points. 251 | 252 | """ 253 | # expand points len to (x, y, 1) 254 | no_batches = len(homographies.shape) == 2 255 | homographies = homographies.unsqueeze(0) if no_batches else homographies 256 | 257 | batch_size = homographies.shape[0] 258 | points = torch.cat((points.float(), torch.ones((points.shape[0], 1)).to(device)), dim=1) 259 | points = points.to(device) 260 | homographies = homographies.view(batch_size*3,3) 261 | 262 | warped_points = homographies@points.transpose(0,1) 263 | 264 | # normalize the points 265 | warped_points = warped_points.view([batch_size, 3, -1]) 266 | warped_points = warped_points.transpose(2, 1) 267 | warped_points = warped_points[:, :, :2] / warped_points[:, :, 2:] 268 | return warped_points[0,:,:] if no_batches else warped_points 269 | 270 | def inv_warp_image_batch(img, mat_homo_inv, device='cpu', mode='bilinear'): 271 | ''' 272 | Inverse warp images in batch 273 | 274 | :param img: 275 | batch of images 276 | tensor [batch_size, 1, H, W] 277 | :param mat_homo_inv: 278 | batch of homography matrices 279 | tensor [batch_size, 3, 3] 280 | :param device: 281 | GPU device or CPU 282 | :return: 283 | batch of warped images 284 | tensor [batch_size, 1, H, W] 285 | ''' 286 | # compute inverse warped points 287 | if len(img.shape) == 2 or len(img.shape) == 3: 288 | img = img.view(1,1,img.shape[0], img.shape[1]) 289 | if len(mat_homo_inv.shape) == 2: 290 | mat_homo_inv = mat_homo_inv.view(1,3,3) 291 | 292 | Batch, channel, H, W = img.shape 293 | coor_cells = torch.stack(torch.meshgrid(torch.linspace(-1, 1, W), torch.linspace(-1, 1, H), indexing='ij'), dim=2) 294 | coor_cells = coor_cells.transpose(0, 1) 295 | coor_cells = coor_cells.to(device) 296 | coor_cells = coor_cells.contiguous() 297 | 298 | src_pixel_coords = warp_points(coor_cells.view([-1, 2]), mat_homo_inv, device) 299 | src_pixel_coords = src_pixel_coords.view([Batch, H, W, 2]) 300 | src_pixel_coords = src_pixel_coords.float() 301 | 302 | warped_img = F.grid_sample(img, src_pixel_coords, mode=mode, align_corners=True) 303 | return warped_img 304 | 305 | def compute_valid_mask(image_shape, inv_homography, device='cpu', erosion_radius=0): 306 | """ 307 | Compute a boolean mask of the valid pixels resulting from a homography applied to 308 | an image of a given shape. Pixels that are False correspond to bordering artifacts. 309 | A margin can be discarded using erosion. 310 | 311 | Arguments: 312 | image_shape: the image shape, i.e. `[H, W]`. 313 | inv_homography: Tensor of shape (B, 3, 3) or (3, 3), where B is the batch size.
314 | erosion_radius: radius of the margin to be discarded. 315 | 316 | Returns: a `torch.Tensor` of shape (B, H, W). 317 | """ 318 | 319 | if inv_homography.dim() == 2: 320 | inv_homography = inv_homography.view(-1, 3, 3) 321 | batch_size = inv_homography.shape[0] 322 | mask = torch.ones(batch_size, 1, image_shape[0], image_shape[1]).to(device) 323 | mask = inv_warp_image_batch(mask, inv_homography, device=device, mode='nearest') 324 | mask = mask.view(batch_size, image_shape[0], image_shape[1]) 325 | mask = mask.cpu().numpy() 326 | if erosion_radius > 0: 327 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erosion_radius*2,)*2) 328 | for i in range(batch_size): 329 | mask[i, :, :] = cv2.erode(mask[i, :, :], kernel, iterations=1) 330 | 331 | return torch.tensor(mask).to(device) 332 | 333 | def Kalman1D(observations,damping=1): 334 | # To return the smoothed time series data 335 | observation_covariance = damping 336 | initial_value_guess = observations[0] 337 | transition_matrix = 1 338 | transition_covariance = 0.1 339 | 340 | kf = KalmanFilter( 341 | initial_state_mean=initial_value_guess, 342 | initial_state_covariance=observation_covariance, 343 | observation_covariance=observation_covariance, 344 | transition_covariance=transition_covariance, 345 | transition_matrices=transition_matrix 346 | ) 347 | pred_state, state_cov = kf.smooth(observations) 348 | return pred_state 349 | 350 | def Kalman3D(observations,damping=1): 351 | ''' 352 | In: 353 | observation: Nx3 354 | Out: 355 | pred_state: Nx3 356 | ''' 357 | # To return the smoothed time series data 358 | observation_covariance = damping 359 | transition_matrix = 1 360 | transition_covariance = 0.1 361 | initial_value_guess_x = observations[0,0] 362 | initial_value_guess_y = observations[0,1] 363 | initial_value_guess_z = observations[0,2] 364 | 365 | # perform 1D smooth for each axis 366 | kfx = KalmanFilter( 367 | initial_state_mean=initial_value_guess_x, 368 | initial_state_covariance=observation_covariance, 369 | observation_covariance=observation_covariance, 370 | transition_covariance=transition_covariance, 371 | transition_matrices=transition_matrix 372 | ) 373 | pred_state_x, state_cov_x = kfx.smooth(observations[:, 0]) 374 | 375 | kfy = KalmanFilter( 376 | initial_state_mean=initial_value_guess_y, 377 | initial_state_covariance=observation_covariance, 378 | observation_covariance=observation_covariance, 379 | transition_covariance=transition_covariance, 380 | transition_matrices=transition_matrix 381 | ) 382 | pred_state_y, state_cov_y = kfy.smooth(observations[:, 1]) 383 | 384 | kfz = KalmanFilter( 385 | initial_state_mean=initial_value_guess_z, 386 | initial_state_covariance=observation_covariance, 387 | observation_covariance=observation_covariance, 388 | transition_covariance=transition_covariance, 389 | transition_matrices=transition_matrix 390 | ) 391 | pred_state_z, state_cov_z = kfz.smooth(observations[:, 2]) 392 | 393 | pred_state = np.concatenate((pred_state_x, pred_state_y, pred_state_z), axis=1) 394 | return pred_state --------------------------------------------------------------------------------
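
For reference, the homography helpers in `script/utils/utils.py` can be chained to produce a warped view of an image together with its validity mask. The snippet below is a minimal sketch, not the DFNet training pipeline: it assumes it is run from the `script/` directory (so `utils.utils` is importable), uses a random tensor in place of a real image, picks 240x320 only as an example resolution, and samples the homography directly in `grid_sample`'s normalized [-1, 1] coordinates by calling `sample_homography_np(np.array([2, 2]), shift=-1)`.

```python
import numpy as np
import torch
from utils.utils import sample_homography_np, inv_warp_image_batch, compute_valid_mask

H, W = 240, 320               # example image size (assumption)
img = torch.rand(1, 1, H, W)  # stand-in for a real single-channel image batch

# Sample a homography whose corner points live in [-1, 1] x [-1, 1]:
# shape=[2, 2] with shift=-1 rescales the unit square to grid_sample's range.
# The returned 3x3 matrix maps warped-image points back to original-image points.
homo = torch.tensor(sample_homography_np(np.array([2, 2]), shift=-1), dtype=torch.float32)

# Warp the image; inv_warp_image_batch expects exactly this warped->original mapping.
warped = inv_warp_image_batch(img, homo, device='cpu', mode='bilinear')   # (1, 1, H, W)

# Mask of warped pixels that actually originate inside the source image,
# built with the same matrix used for the image warp; erosion trims border artifacts.
valid = compute_valid_mask((H, W), homo, device='cpu', erosion_radius=3)  # (1, H, W)
```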