├── LICENSE
├── README.md
├── data
│   ├── 7Scenes
│   │   ├── .DS_Store
│   │   ├── ._.DS_Store
│   │   ├── chess
│   │   │   ├── pose_avg_stats.txt
│   │   │   ├── pose_stats.txt
│   │   │   ├── stats.txt
│   │   │   └── world_setup.json
│   │   ├── fire
│   │   │   ├── pose_avg_stats.txt
│   │   │   ├── pose_stats.txt
│   │   │   ├── stats.txt
│   │   │   └── world_setup.json
│   │   ├── heads
│   │   │   ├── pose_avg_stats.txt
│   │   │   ├── pose_stats.txt
│   │   │   ├── stats.txt
│   │   │   └── world_setup.json
│   │   ├── office
│   │   │   ├── pose_avg_stats.txt
│   │   │   ├── pose_stats.txt
│   │   │   ├── stats.txt
│   │   │   └── world_setup.json
│   │   ├── pumpkin
│   │   │   ├── pose_avg_stats.txt
│   │   │   ├── pose_avg_stats_old.txt
│   │   │   ├── pose_stats.txt
│   │   │   ├── stats.txt
│   │   │   └── world_setup.json
│   │   ├── redkitchen
│   │   │   ├── pose_avg_stats.txt
│   │   │   ├── pose_stats.txt
│   │   │   ├── stats.txt
│   │   │   └── world_setup.json
│   │   └── stairs
│   │       ├── pose_avg_stats.txt
│   │       ├── pose_stats.txt
│   │       ├── stats.txt
│   │       └── world_setup.json
│   ├── Cambridge
│   │   ├── GreatCourt
│   │   │   ├── pose_avg_stats.txt
│   │   │   └── world_setup.json
│   │   ├── KingsCollege
│   │   │   ├── pose_avg_stats.txt
│   │   │   └── world_setup.json
│   │   ├── OldHospital
│   │   │   ├── pose_avg_stats.txt
│   │   │   └── world_setup.json
│   │   ├── ShopFacade
│   │   │   ├── pose_avg_stats.txt
│   │   │   └── world_setup.json
│   │   └── StMarysChurch
│   │       ├── pose_avg_stats.txt
│   │       └── world_setup.json
│   └── deepslam_data
│       └── .gitignore
├── dataset_loaders
│   ├── cambridge_scenes.py
│   ├── load_7Scenes.py
│   ├── load_Cambridge.py
│   ├── seven_scenes.py
│   └── utils
│       └── color.py
├── imgs
│   └── DFNet.png
├── requirements.txt
└── script
    ├── config_dfnet.txt
    ├── config_dfnetdm.txt
    ├── config_nerfh.txt
    ├── dm
    │   ├── __init__.py
    │   ├── callbacks.py
    │   ├── direct_pose_model.py
    │   ├── options.py
    │   ├── pose_model.py
    │   └── prepare_data.py
    ├── feature
    │   ├── dfnet.py
    │   ├── direct_feature_matching.py
    │   ├── efficientnet.py
    │   ├── misc.py
    │   ├── model.py
    │   └── options.py
    ├── models
    │   ├── __init__.py
    │   ├── losses.py
    │   ├── metrics.py
    │   ├── nerf.py
    │   ├── nerfw.py
    │   ├── options.py
    │   ├── ray_utils.py
    │   └── rendering.py
    ├── run_feature.py
    ├── run_nerf.py
    ├── train.py
    └── utils
        ├── set_sys_path.py
        └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Active Vision Laboratory
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DFNet: Enhance Absolute Pose Regression with Direct Feature Matching
2 |
3 | **[Shuai Chen](https://scholar.google.com/citations?user=c0xTh_YAAAAJ&hl=en), [Xinghui Li](https://scholar.google.com/citations?user=XLlgbBoAAAAJ&hl=en), [Zirui Wang](https://scholar.google.com/citations?user=zCBKqa8AAAAJ&hl=en), and [Victor Adrian Prisacariu](https://scholar.google.com/citations?user=GmWA-LoAAAAJ&hl=en) (ECCV 2022)**
4 |
5 | **[Project Page](https://dfnet.active.vision) | [Paper](https://arxiv.org/abs/2204.00559)**
6 |
7 | [](https://arxiv.org/abs/2204.00559)
8 |
9 | ## Setup
10 | ### Installing Requirements
11 | We tested our code with CUDA 11.3+, PyTorch 1.11.0+, and Python 3.7+ using [docker](https://docs.docker.com/engine/install/ubuntu/); one possible container setup is sketched below.
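   |
   | A sketch of one such containerized setup (the image tag, volume mount, and working directory below are illustrative assumptions, not part of this repo):
   |
   | ```sh
   | docker run --gpus all -it --rm -v "$(pwd)":/workspace -w /workspace \
   |     pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel
   | ```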
12 |
13 | The remaining dependencies are listed in requirements.txt.
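   |
   | They can be installed with pip, for example:
   |
   | ```sh
   | pip install -r requirements.txt
   | ```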
14 |
15 | ### Data Preparation
16 | - **7-Scenes**
17 |
18 | We use a data preparation procedure similar to [MapNet](https://github.com/NVlabs/geomapnet). You can download the [7-Scenes](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/) dataset into the `data/deepslam_data` directory.
19 |
20 | Alternatively, you can use a symlink:
21 |
22 | ```sh
23 | cd data/deepslam_data && ln -s 7SCENES_DIR 7Scenes
24 | ```
25 |
26 | Note that we additionally computed pose-averaging statistics (pose_avg_stats.txt) and manually tuned world_setup.json files in `data/7Scenes` to align the 7-Scenes coordinate system with NeRF's coordinate system. You can generate your own re-alignment into a new pose_avg_stats.txt using the `--save_pose_avg_stats` option, as sketched below.
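   |
   | A minimal sketch of how such a re-alignment could be triggered (assuming the flag is simply appended to an existing NeRF run; the exact entry point and config may differ in your setup):
   |
   | ```sh
   | python run_nerf.py --config config_nerfh.txt --save_pose_avg_stats
   | ```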
27 |
28 | - **Cambridge Landmarks**
29 |
30 | You can download the Cambridge Landmarks dataset using the setup script [here](https://github.com/vislearn/dsacstar/blob/master/datasets/setup_cambridge.py); one possible invocation is sketched below. Please also put the pose_avg_stats.txt and world_setup.json files into `data/Cambridge/CAMBRIDGE_SCENES`, as provided in the source code.
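   |
   | One possible way to fetch and run that script (a sketch only; the resulting scene folders may need to be moved/renamed to match the `data/Cambridge/SCENE/train` and `.../test` layout expected by the data loaders):
   |
   | ```sh
   | wget https://raw.githubusercontent.com/vislearn/dsacstar/master/datasets/setup_cambridge.py
   | python setup_cambridge.py
   | ```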
31 |
32 | ## Training
33 |
34 | As stated in the paper, our method relies on a pretrained Histogram-assisted NeRF model and a DFNet model. We provide example config files in the repo. The following commands show how to train each model.
35 |
36 | - NeRF model
37 |
38 | ```sh
39 | python run_nerf.py --config config_nerfh.txt
40 | ```
41 |
42 | - DFNet model
43 |
44 | ```sh
45 | python run_feature.py --config config_dfnet.txt
46 | ```
47 |
48 | - Direct Feature Matching (DFNetdm)
49 |
50 | ```sh
51 | python train.py --config config_dfnetdm.txt
52 | ```
53 |
54 | ## Evaluation
55 | We provide methods to evaluate our models.
56 | - To evaluate the NeRF model in PSNR, simply add the `--render_test` argument.
57 |
58 | ```sh
59 | python run_nerf.py --config config_nerfh.txt --render_test
60 | ```
61 |
62 | - To evaluate the APR performance of the DFNet model, add `--eval --testskip=1 --pretrain_model_path=../logs/PATH_TO_CHECKPOINT`. For example:
63 |
64 | ```sh
65 | python run_feature.py --config config_dfnet.txt --eval --testskip=1 --pretrain_model_path=../logs/heads/dfnet/checkpoint.pt
66 | ```
67 |
68 | - Similarly, to evaluate the APR performance of the DFNetdm model:
69 |
70 | ```sh
71 | python train.py --config config_dfnetdm.txt --eval --testskip=1 --pretrain_model_path=../logs/heads/dfnetdm/checkpoint.pt
72 | ```
73 |
74 | ## Pre-trained model
75 | We provide the 7-Scenes and Cambridge pre-trained models [here](https://www.robots.ox.ac.uk/~shuaic/DFNet2022/pretrain_models.tar.gz). Some models achieve slightly better results than reported in the paper. We suggest placing the models in a new `./logs/` directory at the project root, for example as shown below.
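   |
   | A minimal download-and-unpack sketch (assuming the archive extracts directly into per-scene model folders; the exact layout inside the tarball may differ):
   |
   | ```sh
   | wget https://www.robots.ox.ac.uk/~shuaic/DFNet2022/pretrain_models.tar.gz
   | mkdir -p logs && tar -xzf pretrain_models.tar.gz -C logs/
   | ```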
76 |
77 | Note that we additionally provide models for Cambridge's Great Court scene, although we did not include those results in the main paper, to keep comparisons with other works fair.
78 |
79 | Due to limited resources, the pre-trained models were trained on an RTX 3080 Ti or GTX 1080 Ti. We noticed that a model's performance may vary slightly (better or worse) when running inference on a different GPU type, even with exactly the same weights. Therefore, all experiments in the paper are reported on the same GPUs the models were trained on.
80 |
81 | ## Acknowledgement
82 | We thank Michael Hobley, Theo Costain, Lixiong Chen, and Kejie Li for their generous discussions on this work.
83 |
84 | Most of our code is built upon [Direct-PoseNet](https://github.com/ActiveVisionLab/direct-posenet). Part of our Histogram-assisted NeRF implementation is referenced from the reproduced NeRFW code [here](https://github.com/kwea123/nerf_pl/tree/nerfw). We thank [@kwea123](https://github.com/kwea123) for this excellent work!
85 |
86 | ## Citation
87 | Please cite our paper and star this repo if you find our work helpful. Thanks!
88 | ```
89 | @inproceedings{chen2022dfnet,
90 | title={DFNet: Enhance Absolute Pose Regression with Direct Feature Matching},
91 | author={Chen, Shuai and Li, Xinghui and Wang, Zirui and Prisacariu, Victor},
92 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
93 | year={2022}
94 | }
95 | ```
96 | This code builds on previous camera relocalization pipelines, namely Direct-PoseNet. Please consider citing:
97 | ```
98 | @inproceedings{chen2021direct,
99 | title={Direct-PoseNet: Absolute pose regression with photometric consistency},
100 | author={Chen, Shuai and Wang, Zirui and Prisacariu, Victor},
101 | booktitle={2021 International Conference on 3D Vision (3DV)},
102 | pages={1175--1185},
103 | year={2021},
104 | organization={IEEE}
105 | }
106 | ```
107 |
--------------------------------------------------------------------------------
/data/7Scenes/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/data/7Scenes/.DS_Store
--------------------------------------------------------------------------------
/data/7Scenes/._.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/data/7Scenes/._.DS_Store
--------------------------------------------------------------------------------
/data/7Scenes/chess/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.350092088047137207e-01 1.857172640673737662e-01 -3.021040835171091565e-01 4.217527911023845610e-01
2 | -3.778758022057916027e-02 8.992282578255548220e-01 4.358447419771065423e-01 -7.391797421190476891e-01
3 | 3.526044217412146464e-01 -3.961030650668416753e-01 8.478045078986057304e-01 2.995973868145236363e-01
4 |
--------------------------------------------------------------------------------
/data/7Scenes/chess/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/chess/stats.txt:
--------------------------------------------------------------------------------
1 | 5.009650708326967017e-01 4.413125411911532625e-01 4.458285283490354689e-01
2 | 4.329720281018845096e-02 5.278270383679337097e-02 4.760929057962018374e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/chess/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":2,
4 | "pose_scale": 0.5,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 1.0]
7 | }
--------------------------------------------------------------------------------
/data/7Scenes/fire/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.682675339923704216e-01 1.308464847914461437e-01 -2.129252921427031431e-01 4.483261509061373107e-02
2 | -9.619441241477342738e-03 8.708690325271444266e-01 4.914209952122894909e-01 -4.201599145869548413e-01
3 | 2.497307529451176511e-01 -4.737787728496888895e-01 8.444928806274848432e-01 7.115903476590906829e-01
4 |
--------------------------------------------------------------------------------
/data/7Scenes/fire/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/fire/stats.txt:
--------------------------------------------------------------------------------
1 | 5.222627479256024552e-01 4.620521564670138082e-01 4.212473626365915158e-01
2 | 5.550322239689903236e-02 5.943252514694064015e-02 5.525370066993806617e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/fire/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":3,
4 | "pose_scale": 1,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 1.0]
7 | }
--------------------------------------------------------------------------------
/data/7Scenes/heads/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.821477519358328134e-01 -8.872136230840418913e-02 1.658743899386847798e-01 -1.157979895409090576e-02
2 | 3.111216594703735891e-02 9.462589180611266082e-01 3.219100699261678855e-01 -1.317583753459090623e-01
3 | -1.855204207020725304e-01 -3.110025399573565497e-01 9.321263828701550347e-01 9.636789777181822836e-02
4 |
--------------------------------------------------------------------------------
/data/7Scenes/heads/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/heads/stats.txt:
--------------------------------------------------------------------------------
1 | 4.570619554738562518e-01 4.504317877348855137e-01 4.586057516467524908e-01
2 | 7.874170624948270691e-02 7.747845434384653673e-02 7.183367877515742239e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/heads/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":2.5,
4 | "pose_scale": 1,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 1.0]
7 | }
--------------------------------------------------------------------------------
/data/7Scenes/office/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.192216692444369341e-01 -2.380853847231172160e-01 3.136030490488191935e-01 -2.580896777083788868e-02
2 | 8.889970878427481960e-02 9.014018431780023155e-01 4.237588452095970570e-01 -8.784274026985845474e-01
3 | -3.835731541303978309e-01 -3.616490933163600818e-01 8.497538283137727744e-01 1.063082783627855354e+00
4 |
--------------------------------------------------------------------------------
/data/7Scenes/office/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/office/stats.txt:
--------------------------------------------------------------------------------
1 | 4.703657901067226921e-01 4.414751487847252132e-01 4.351020758221028628e-01
2 | 7.105139804377599844e-02 7.191485421006868495e-02 6.783299267371162289e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/office/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":2,
4 | "pose_scale": 0.5,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 0.5]
7 | }
--------------------------------------------------------------------------------
/data/7Scenes/pumpkin/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.994189485731389544e-01 7.232278172020349324e-03 -3.330854823320933411e-02 -6.206658357107126822e-02
2 | -8.310756492857687694e-03 9.994418989496620664e-01 -3.235462795969396704e-02 -7.690604500476190264e-01
3 | 3.305596102789840757e-02 3.261264749044923833e-02 9.989212775109888032e-01 4.472261112787878634e-01
4 |
--------------------------------------------------------------------------------
/data/7Scenes/pumpkin/pose_avg_stats_old.txt:
--------------------------------------------------------------------------------
1 | 9.867033112503645897e-01 8.544426416488330733e-02 -1.382600929006191914e-01 7.374091044342952206e-02
2 | -9.057380802494104099e-02 9.953998641700174677e-01 -3.123292669878471872e-02 -7.475368646794867677e-01
3 | 1.349554032539169723e-01 4.334037530562200036e-02 9.899033543740218821e-01 3.342737444938814195e-01
4 |
--------------------------------------------------------------------------------
/data/7Scenes/pumpkin/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/pumpkin/stats.txt:
--------------------------------------------------------------------------------
1 | 5.503370888799515859e-01 4.492568432042766124e-01 4.579284152018213705e-01
2 | 4.053158612557544727e-02 4.899782680513672939e-02 3.385843494567825074e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/pumpkin/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":2.5,
4 | "pose_scale": 0.5,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 1.0]
7 | }
--------------------------------------------------------------------------------
/data/7Scenes/redkitchen/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.397923675833245172e-01 1.720816216688166589e-01 -2.952595829367096747e-01 1.068215764548530455e-01
2 | -1.553194639452672166e-01 9.846599603919916621e-01 7.950236801879838333e-02 -4.759053293419227559e-01
3 | 3.044111856550024142e-01 -2.885615852243439416e-02 9.521035406737251572e-01 9.771192826975949597e-01
4 |
--------------------------------------------------------------------------------
/data/7Scenes/redkitchen/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/redkitchen/stats.txt:
--------------------------------------------------------------------------------
1 | 5.262172203420504291e-01 4.400453064527823366e-01 4.320846191351511711e-01
2 | 4.872459633076364760e-02 6.484063059696282272e-02 5.724255797232574716e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/redkitchen/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":2,
4 | "pose_scale": 0.5,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 0.5]
7 | }
--------------------------------------------------------------------------------
/data/7Scenes/stairs/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.981044025604641767e-01 6.017846147761087006e-02 -1.289008780444797844e-02 4.026755230123485463e-02
2 | -4.703056619089211743e-02 8.809109637765394352e-01 4.709394862846526530e-01 -9.428126168765132986e-01
3 | 3.969543340464730397e-02 -4.694405464725816546e-01 8.820713383249345618e-01 3.852607943350118691e-01
4 |
--------------------------------------------------------------------------------
/data/7Scenes/stairs/pose_stats.txt:
--------------------------------------------------------------------------------
1 | 0.0000000 0.0000000 0.0000000
2 | 1.0000000 1.0000000 1.0000000
3 |
--------------------------------------------------------------------------------
/data/7Scenes/stairs/stats.txt:
--------------------------------------------------------------------------------
1 | 4.472714732115506964e-01 4.312183359438830910e-01 4.291487246732026972e-01
2 | 3.258580609153208241e-02 2.618736971489385446e-02 1.208855922484347589e-02
3 |
--------------------------------------------------------------------------------
/data/7Scenes/stairs/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0,
3 | "far":4,
4 | "pose_scale": 1,
5 | "pose_scale2": 1,
6 | "move_all_cam_vec": [0.0, 0.0, 0.0]
7 | }
--------------------------------------------------------------------------------
/data/Cambridge/GreatCourt/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 3.405292754461299309e-01 4.953070871398960184e-01 7.991937825040464904e-01 4.704508373014760281e+01
2 | -9.402316354416121458e-01 1.812586354202151695e-01 2.882876667504056800e-01 3.467785281210451842e+01
3 | -2.069849976503225410e-03 -8.495976674371187309e-01 5.274272643753655787e-01 1.101080132352710184e+00
4 |
--------------------------------------------------------------------------------
/data/Cambridge/GreatCourt/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0.0,
3 | "far":10.0,
4 | "pose_scale": 0.3027,
5 | "pose_scale2": 0.2,
6 | "move_all_cam_vec": [0.0, 0.0, 0.0]
7 | }
--------------------------------------------------------------------------------
/data/Cambridge/KingsCollege/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.995083419588323137e-01 -1.453974655309233331e-02 2.777895111190991154e-02 2.004095163645802913e+01
2 | -2.395968310872182219e-02 2.172811532927548528e-01 9.758149588979971867e-01 -2.354010655332784197e+01
3 | -2.022394471995193205e-02 -9.760007664924973403e-01 2.168259575466510436e-01 1.650110331018928678e+00
4 |
--------------------------------------------------------------------------------
/data/Cambridge/KingsCollege/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0.0,
3 | "far":10.0,
4 | "pose_scale": 0.3027,
5 | "pose_scale2": 0.2,
6 | "move_all_cam_vec": [0.0, 0.0, 0.0]
7 | }
--------------------------------------------------------------------------------
/data/Cambridge/OldHospital/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 9.997941252129602940e-01 6.239930741698496326e-03 1.930726428032739084e-02 1.319547963328867723e+01
2 | -3.333807443587469103e-03 -8.880897259859261705e-01 4.596580515189216398e-01 -6.473184854291670343e-01
3 | 2.001481745059596404e-02 -4.596277862168271500e-01 -8.878860879751624413e-01 2.310333011616541654e+01
4 |
--------------------------------------------------------------------------------
/data/Cambridge/OldHospital/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0.0,
3 | "far":10.0,
4 | "pose_scale": 0.3027,
5 | "pose_scale2": 0.2,
6 | "move_all_cam_vec": [0.0, 0.0, 5.0]
7 | }
--------------------------------------------------------------------------------
/data/Cambridge/ShopFacade/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | 2.084004683986779016e-01 1.972095064159990266e-02 9.778447365901210553e-01 -4.512817941282106560e+00
2 | -9.780353328393808221e-01 8.307943784904847639e-03 2.082735359757174609e-01 1.914896116567694540e+00
3 | -4.016526979027209426e-03 -9.997710048685441997e-01 2.101916590087021808e-02 1.768500113487243564e+00
4 |
--------------------------------------------------------------------------------
/data/Cambridge/ShopFacade/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0.0,
3 | "far":20.0,
4 | "pose_scale": 0.3027,
5 | "pose_scale2": 0.32,
6 | "move_all_cam_vec": [0.0, 0.0, 2.5]
7 | }
--------------------------------------------------------------------------------
/data/Cambridge/StMarysChurch/pose_avg_stats.txt:
--------------------------------------------------------------------------------
1 | -6.692001528162709878e-01 7.430812642562667492e-01 1.179059789653581552e-03 1.114036505648812359e+01
2 | 3.891382817260490012e-02 3.662925707351961935e-02 -9.985709847092467673e-01 -5.441265972613005403e-02
3 | -7.420625778515127502e-01 -6.681979738352623599e-01 -5.342844106669619036e-02 1.708768320112491068e+01
4 |
--------------------------------------------------------------------------------
/data/Cambridge/StMarysChurch/world_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "near":0.0,
3 | "far":10.0,
4 | "pose_scale": 0.3027,
5 | "pose_scale2": 0.2,
6 | "move_all_cam_vec": [0.0, 0.0, 0.0]
7 | }
--------------------------------------------------------------------------------
/data/deepslam_data/.gitignore:
--------------------------------------------------------------------------------
1 | 7Scenes
2 |
--------------------------------------------------------------------------------
/dataset_loaders/cambridge_scenes.py:
--------------------------------------------------------------------------------
1 | """
2 | pytorch data loader for the Cambridge Landmark dataset
3 | """
4 | import os
5 | import os.path as osp
6 | import numpy as np
7 | from PIL import Image
8 | import torch
9 | from torch.utils import data
10 | import sys
11 | import pdb
12 | import cv2
13 | from dataset_loaders.utils.color import rgb_to_yuv
14 | from torchvision.utils import save_image
15 | import json
16 |
17 | sys.path.insert(0, '../')
18 | #from common.pose_utils import process_poses
19 |
20 | ## NUMPY
21 | def qlog(q):
22 | """
23 | Applies logarithm map to q
24 | :param q: (4,)
25 | :return: (3,)
26 | """
27 | if all(q[1:] == 0):
28 | q = np.zeros(3)
29 | else:
30 | q = np.arccos(q[0]) * q[1:] / np.linalg.norm(q[1:])
31 | return q
32 |
33 | def process_poses_rotmat(poses_in, rot_mat):
34 | """
35 | processes the position + quaternion raw pose from dataset to position + rotation matrix
36 | :param poses_in: N x 7
37 | :param rot_mat: N x 3 x 3
38 | :return: processed poses N x 12
39 | """
40 |
41 | poses = np.zeros((poses_in.shape[0], 3, 4))
42 | poses[:,:3,:3] = rot_mat
43 | poses[...,:3,3] = poses_in[:, :3]
44 | poses = poses.reshape(poses_in.shape[0], 12)
45 | return poses
46 |
47 | from torchvision.datasets.folder import default_loader
48 | def load_image(filename, loader=default_loader):
49 | try:
50 | img = loader(filename)
51 | except IOError as e:
52 | print('Could not load image {:s}, IOError: {:s}'.format(filename, str(e)))
53 | return None
54 | except:
55 | print('Could not load image {:s}, unexpected error'.format(filename))
56 | return None
57 | return img
58 |
59 | def load_depth_image(filename):
60 | try:
61 | img_depth = Image.fromarray(np.array(Image.open(filename)).astype("uint16"))
62 | except IOError as e:
63 | print('Could not load image {:s}, IOError: {:s}'.format(filename, str(e)))
64 | return None
65 | return img_depth
66 |
67 | def normalize(x):
68 | return x / np.linalg.norm(x)
69 |
70 | def viewmatrix(z, up, pos):
71 | vec2 = normalize(z)
72 | vec1_avg = up
73 | vec0 = normalize(np.cross(vec1_avg, vec2))
74 | vec1 = normalize(np.cross(vec2, vec0))
75 | m = np.stack([vec0, vec1, vec2, pos], 1)
76 | return m
77 |
78 | def normalize_recenter_pose(poses, sc, hwf):
79 | ''' normalize xyz into [-1, 1], and recenter pose '''
80 | target_pose = poses.reshape(poses.shape[0],3,4)
81 | target_pose[:,:3,3] = target_pose[:,:3,3] * sc
82 |
83 |
84 | x_norm = target_pose[:,0,3]
85 | y_norm = target_pose[:,1,3]
86 | z_norm = target_pose[:,2,3]
87 |
88 | tpose_ = target_pose+0
89 |
90 | # find the center of pose
91 | center = np.array([x_norm.mean(), y_norm.mean(), z_norm.mean()])
92 | bottom = np.reshape([0,0,0,1.], [1,4])
93 |
94 | # pose avg
95 | vec2 = normalize(tpose_[:, :3, 2].sum(0))
96 | up = tpose_[:, :3, 1].sum(0)
97 | hwf=np.array(hwf).transpose()
98 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
99 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
100 |
101 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [tpose_.shape[0],1,1])
102 | poses = np.concatenate([tpose_[:,:3,:4], bottom], -2)
103 | poses = np.linalg.inv(c2w) @ poses
104 | return poses[:,:3,:].reshape(poses.shape[0],12)
105 |
106 | def downscale_pose(poses, sc):
107 | ''' downscale translation pose to [-1:1] only '''
108 | target_pose = poses.reshape(poses.shape[0],3,4)
109 | target_pose[:,:3,3] = target_pose[:,:3,3] * sc
110 | return target_pose.reshape(poses.shape[0],12)
111 |
112 | class Cambridge2(data.Dataset):
113 | def __init__(self, scene, data_path, train, transform=None,
114 | target_transform=None, mode=0, seed=7,
115 | skip_images=False, df=2., trainskip=1, testskip=1, hwf=[480,854,744.],
116 | ret_idx=False, fix_idx=False, ret_hist=False, hist_bin=10):
117 | """
118 | :param scene: scene name ['ShopFacade', 'KingsCollege', ...]
119 | :param data_path: root Cambridge Landmarks data directory.
120 | Usually '../data/Cambridge'
121 | :param train: if True, return the training images. If False, returns the
122 | testing images
123 | :param transform: transform to apply to the images
124 | :param target_transform: transform to apply to the poses
125 | :param mode: (Obsolete) 0: just color image, 1: color image in NeRF 0-1 and resized
126 | :param skip_images: If True, skip loading images and return None instead
127 | :param df: downscale factor
128 | :param trainskip: the training set is large, so it can be subsampled; # of train samples = full set / trainskip
129 | :param testskip: skip part of the test set; # of test samples = full set / testskip
130 | :param hwf: H,W,Focal from COLMAP
131 | """
132 |
133 | self.transform = transform
134 | self.target_transform = target_transform
135 | self.df = df
136 |
137 | self.H, self.W, self.focal = hwf
138 | self.H = int(self.H)
139 | self.W = int(self.W)
140 | np.random.seed(seed)
141 |
142 | self.train = train
143 | self.ret_idx = ret_idx
144 | self.fix_idx = fix_idx
145 | self.ret_hist = ret_hist
146 | self.hist_bin = hist_bin # histogram bin size
147 |
148 | if self.train:
149 | root_dir = osp.join(data_path, scene) + '/train'
150 | else:
151 | root_dir = osp.join(data_path, scene) + '/test'
152 |
153 | rgb_dir = root_dir + '/rgb/'
154 |
155 | pose_dir = root_dir + '/poses/'
156 |
157 | world_setup_fn = osp.join(data_path, scene) + '/world_setup.json'
158 |
159 | # collect poses and image names
160 | self.rgb_files = os.listdir(rgb_dir)
161 | self.rgb_files = [rgb_dir + f for f in self.rgb_files]
162 | self.rgb_files.sort()
163 |
164 | self.pose_files = os.listdir(pose_dir)
165 | self.pose_files = [pose_dir + f for f in self.pose_files]
166 | self.pose_files.sort()
167 |
168 | # remove some abnormal data, need to fix later
169 | if scene == 'ShopFacade' and self.train:
170 | del self.rgb_files[42]
171 | del self.rgb_files[35]
172 | del self.pose_files[42]
173 | del self.pose_files[35]
174 |
175 | if len(self.rgb_files) != len(self.pose_files):
176 | raise Exception('RGB file count does not match pose file count!')
177 |
178 | # read json file
179 | with open(world_setup_fn, 'r') as myfile:
180 | data=myfile.read()
181 |
182 | # parse json file
183 | obj = json.loads(data)
184 | self.near = obj['near']
185 | self.far = obj['far']
186 | self.pose_scale = obj['pose_scale']
187 | self.pose_scale2 = obj['pose_scale2']
188 | self.move_all_cam_vec = obj['move_all_cam_vec']
189 |
190 | # trainskip and testskip
191 | frame_idx = np.arange(len(self.rgb_files))
192 | if train and trainskip > 1:
193 | frame_idx_tmp = frame_idx[::trainskip]
194 | frame_idx = frame_idx_tmp
195 | elif not train and testskip > 1:
196 | frame_idx_tmp = frame_idx[::testskip]
197 | frame_idx = frame_idx_tmp
198 | self.gt_idx = frame_idx
199 |
200 | self.rgb_files = [self.rgb_files[i] for i in frame_idx]
201 | self.pose_files = [self.pose_files[i] for i in frame_idx]
202 |
203 | if len(self.rgb_files) != len(self.pose_files):
204 | raise Exception('RGB file count does not match pose file count!')
205 |
206 | # read poses
207 | poses = []
208 | for i in range(len(self.pose_files)):
209 | pose = np.loadtxt(self.pose_files[i])
210 | poses.append(pose)
211 | poses = np.array(poses) # [N, 4, 4]
212 | self.poses = poses[:, :3, :4].reshape(poses.shape[0], 12)
213 |
214 | # debug read one img and get the shape of the img
215 | img = load_image(self.rgb_files[0])
216 | img_np = (np.array(img) / 255.).astype(np.float32) # (480,854,3)
217 |
218 | self.H, self.W = img_np.shape[:2]
219 | if self.df != 1.:
220 | self.H = int(self.H//self.df)
221 | self.W = int(self.W//self.df)
222 | self.focal = self.focal/self.df
223 |
224 | def __len__(self):
225 | return len(self.rgb_files)
226 |
227 | def __getitem__(self, index):
228 | img = load_image(self.rgb_files[index])
229 | pose = self.poses[index]
230 | if self.df != 1.:
231 | img_np = (np.array(img) / 255.).astype(np.float32)
232 | dims = (self.W, self.H)
233 | img_half_res = cv2.resize(img_np, dims, interpolation=cv2.INTER_AREA) # (H, W, 3)
234 | img = img_half_res
235 |
236 | if self.transform is not None:
237 | img = self.transform(img)
238 |
239 | if self.target_transform is not None:
240 | pose = self.target_transform(pose)
241 |
242 | if self.ret_idx:
243 | if self.train and self.fix_idx==False:
244 | return img, pose, index
245 | else:
246 | return img, pose, 0
247 | if self.ret_hist:
248 | yuv = rgb_to_yuv(img)
249 | y_img = yuv[0] # extract y channel only
250 | hist = torch.histc(y_img, bins=self.hist_bin, min=0., max=1.) # compute intensity histogram
251 | hist = hist/(hist.sum())*100 # convert to histogram density, in terms of percentage per bin
252 | hist = torch.round(hist)
253 | return img, pose, hist
254 |
255 | return img, pose
256 |
257 | def main():
258 | """
259 | visualizes the dataset
260 | """
261 | #from common.vis_utils import show_batch, show_stereo_batch
262 | from torchvision.utils import make_grid
263 | import torchvision.transforms as transforms
264 | seq = 'ShopFacade'
265 | mode = 1
266 | # num_workers = 1
267 |
268 | # transformer
269 | data_transform = transforms.Compose([
270 | transforms.ToTensor(),
271 | ])
272 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x))
273 | kwargs = dict(ret_hist=True)
274 | dset = Cambridge2(seq, '../data/Cambridge/', True, data_transform, target_transform=target_transform, mode=mode, df=7.15, trainskip=2, **kwargs)
275 | print('Loaded Cambridge sequence {:s}, length = {:d}'.format(seq, len(dset)))
276 |
277 | data_loader = data.DataLoader(dset, batch_size=4, shuffle=False)
278 |
279 | batch_count = 0
280 | N = 2
281 | for batch in data_loader:
282 | print('Minibatch {:d}'.format(batch_count))
283 | pdb.set_trace()
284 |
285 | batch_count += 1
286 | if batch_count >= N:
287 | break
288 |
289 | if __name__ == '__main__':
290 | main()
291 |
--------------------------------------------------------------------------------
/dataset_loaders/load_Cambridge.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | import torch.cuda
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms, models
10 |
11 | from dataset_loaders.cambridge_scenes import Cambridge2, normalize_recenter_pose, load_image
12 | import pdb
13 |
14 | #from dataset_loaders.frustum.frustum_util import initK, generate_sampling_frustum, compute_frustums_overlap
15 |
16 | #focal_length = 555 # This is an approximate https://github.com/NVlabs/geomapnet/issues/8
17 | # Official says (585,585)
18 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
19 |
20 | # translation z axis
21 | trans_t = lambda t : np.array([
22 | [1,0,0,0],
23 | [0,1,0,0],
24 | [0,0,1,t],
25 | [0,0,0,1]]).astype(float)
26 |
27 | # x rotation
28 | rot_phi = lambda phi : np.array([
29 | [1,0,0,0],
30 | [0,np.cos(phi),-np.sin(phi),0],
31 | [0,np.sin(phi), np.cos(phi),0],
32 | [0,0,0,1]]).astype(float)
33 |
34 | # y rotation
35 | rot_theta = lambda th : np.array([
36 | [np.cos(th),0,-np.sin(th),0],
37 | [0,1,0,0],
38 | [np.sin(th),0, np.cos(th),0],
39 | [0,0,0,1]]).astype(float)
40 |
41 | # z rotation
42 | rot_psi = lambda psi : np.array([
43 | [np.cos(psi),-np.sin(psi),0,0],
44 | [np.sin(psi),np.cos(psi),0,0],
45 | [0,0,1,0],
46 | [0,0,0,1]]).astype(float)
47 |
48 | def is_inside_frustum(p, x_res, y_res):
49 | return (0 < p[0]) & (p[0] < x_res) & (0 < p[1]) & (p[1] < y_res)
50 |
51 | def initK(f, cx, cy):
52 | K = np.eye(3, 3)
53 | K[0, 0] = K[1, 1] = f
54 | K[0, 2] = cx
55 | K[1, 2] = cy
56 | return K
57 |
58 | def generate_sampling_frustum(step, depth, K, f, cx, cy, x_res, y_res):
59 | #pdb.set_trace()
60 | x_max = depth * (x_res - cx) / f
61 | x_min = -depth * cx / f
62 | y_max = depth * (y_res - cy) / f
63 | y_min = -depth * cy / f
64 |
65 | zs = np.arange(0, depth, step)
66 | xs = np.arange(x_min, x_max, step)
67 | ys = np.arange(y_min, y_max, step)
68 |
69 | X0 = []
70 | for z in zs:
71 | for x in xs:
72 | for y in ys:
73 | P = np.array([x, y, z])
74 | p = np.dot(K, P)
75 | if p[2] < 0.00001:
76 | continue
77 | p = p / p[2]
78 | if is_inside_frustum(p, x_res, y_res):
79 | X0.append(P)
80 | X0 = np.array(X0)
81 | return X0
82 |
83 | def compute_frustums_overlap(pose0, pose1, sampling_frustum, K, x_res, y_res):
84 | R0 = pose0[0:3, 0:3]
85 | t0 = pose0[0:3, 3]
86 | R1 = pose1[0:3, 0:3]
87 | t1 = pose1[0:3, 3]
88 |
89 | R10 = np.dot(R1.T, R0)
90 | t10 = np.dot(R1.T, t0 - t1)
91 |
92 | _P = np.dot(R10, sampling_frustum.T).T + t10
93 | p = np.dot(K, _P.T).T
94 | pn = p[:, 2]
95 | p = np.divide(p, pn[:, None])
96 | res = np.apply_along_axis(is_inside_frustum, 1, p, x_res, y_res)
97 | return np.sum(res) / float(res.shape[0])
98 |
99 | def perturb_rotation(c2w, theta, phi, psi=0):
100 | last_row = np.tile(np.array([0, 0, 0, 1]), (1, 1)) # (N_images, 1, 4)
101 | c2w = np.concatenate([c2w, last_row], 0) # (N_images, 4, 4) homogeneous coordinate
102 | c2w = rot_phi(phi/180.*np.pi) @ c2w
103 | c2w = rot_theta(theta/180.*np.pi) @ c2w
104 | c2w = rot_psi(psi/180.*np.pi) @ c2w
105 | c2w = c2w[:3,:4]
106 | return c2w
107 |
108 | def viewmatrix(z, up, pos):
109 | vec2 = normalize(z)
110 | vec1_avg = up
111 | vec0 = normalize(np.cross(vec1_avg, vec2))
112 | vec1 = normalize(np.cross(vec2, vec0))
113 | m = np.stack([vec0, vec1, vec2, pos], 1)
114 | return m
115 |
116 | def normalize(v):
117 | """Normalize a vector."""
118 | return v / np.linalg.norm(v)
119 |
120 | def average_poses(poses):
121 | """
122 | Calculate the average pose, which is then used to center all poses
123 | using @center_poses. Its computation is as follows:
124 | 1. Compute the center: the average of pose centers.
125 | 2. Compute the z axis: the normalized average z axis.
126 | 3. Compute axis y': the average y axis.
127 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
128 | 5. Compute the y axis: z cross product x.
129 | Note that at step 3, we cannot directly use y' as y axis since it's
130 | not necessarily orthogonal to z axis. We need to pass from x to y.
131 | Inputs:
132 | poses: (N_images, 3, 4)
133 | Outputs:
134 | pose_avg: (3, 4) the average pose
135 | """
136 | # 1. Compute the center
137 | center = poses[..., 3].mean(0) # (3)
138 | # 2. Compute the z axis
139 | z = normalize(poses[..., 2].mean(0)) # (3)
140 | # 3. Compute axis y' (no need to normalize as it's not the final output)
141 | y_ = poses[..., 1].mean(0) # (3)
142 | # 4. Compute the x axis
143 | x = normalize(np.cross(y_, z)) # (3)
144 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
145 | y = np.cross(z, x) # (3)
146 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
147 | return pose_avg
148 |
149 | def center_poses(poses, pose_avg_from_file=None):
150 | """
151 | Center the poses so that we can use NDC.
152 | See https://github.com/bmild/nerf/issues/34
153 |
154 | Inputs:
155 | poses: (N_images, 3, 4)
156 | pose_avg_from_file: if not None, pose_avg is loaded from pose_avg_stats.txt
157 |
158 | Outputs:
159 | poses_centered: (N_images, 3, 4) the centered poses
160 | pose_avg: (3, 4) the average pose
161 | """
162 |
163 |
164 | if pose_avg_from_file is None:
165 | pose_avg = average_poses(poses) # (3, 4) # this need to be fixed throughout dataset
166 | else:
167 | pose_avg = pose_avg_from_file
168 |
169 | pose_avg_homo = np.eye(4)
170 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation (4,4)
171 | # by simply adding 0, 0, 0, 1 as the last row
172 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
173 | poses_homo = \
174 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
175 |
176 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
177 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
178 |
179 | return poses_centered, pose_avg #np.linalg.inv(pose_avg_homo)
180 |
181 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
182 | render_poses = []
183 | rads = np.array(list(rads) + [1.])
184 | hwf = c2w[:,4:5] # it's empty here...
185 |
186 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
187 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
188 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
189 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
190 | return render_poses
191 |
192 | def average_poses(poses):
193 | """
194 | Same as in SingleCamVideoStatic.py
195 | Inputs:
196 | poses: (N_images, 3, 4)
197 | Outputs:
198 | pose_avg: (3, 4) the average pose
199 | """
200 | center = poses[..., 3].mean(0) # (3)
201 | z = normalize(poses[..., 2].mean(0)) # (3)
202 | y_ = poses[..., 1].mean(0) # (3)
203 | x = normalize(np.cross(y_, z)) # (3)
204 | y = np.cross(z, x) # (3)
205 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
206 | return pose_avg
207 |
208 | def generate_render_pose(poses, bds):
209 | idx = np.random.choice(poses.shape[0])
210 | c2w=poses[idx]
211 | print(c2w[:3,:4])
212 |
213 | ## Get spiral
214 | # Get average pose
215 | up = normalize(poses[:, :3, 1].sum(0))
216 |
217 | # Find a reasonable "focus depth" for this dataset
218 | close_depth, inf_depth = bds.min()*.9, bds.max()*5.
219 | dt = .75
220 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
221 | focal = mean_dz
222 |
223 | # Get radii for spiral path
224 | shrink_factor = .8
225 | zdelta = close_depth * .2
226 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T
227 | rads = np.percentile(np.abs(tt), 20, 0) # 20th percentile of |translation| as spiral radii
228 | c2w_path = c2w
229 | N_views = 120 # number of views in video
230 | N_rots = 2
231 |
232 | # Generate poses for spiral path
233 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
234 | return render_poses
235 |
236 | def perturb_render_pose(poses, bds, x, angle):
237 | """
238 | Inputs:
239 | poses: (3, 4)
240 | bds: bounds
241 | x: translational perturb range
242 | angle: rotation angle perturb range in degrees
243 | Outputs:
244 | new_c2w: (N_views, 3, 4) new poses
245 | """
246 | idx = np.random.choice(poses.shape[0])
247 | c2w=poses[idx]
248 |
249 | N_views = 10 # number of views in video
250 | new_c2w = np.zeros((N_views, 3, 4))
251 |
252 | # perturb translational pose
253 | for i in range(N_views):
254 | new_c2w[i] = c2w
255 | new_c2w[i,:,3] = new_c2w[i,:,3] + np.random.uniform(-x,x,3) # perturb pos between -1 to 1
256 | theta=np.random.uniform(-angle,angle,1) # in degrees
257 | phi=np.random.uniform(-angle,angle,1) # in degrees
258 | psi=np.random.uniform(-angle,angle,1) # in degrees
259 | new_c2w[i] = perturb_rotation(new_c2w[i], theta, phi, psi)
260 | return new_c2w, idx
261 |
262 | def remove_overlap_data(train_set, val_set):
263 | ''' Remove some overlap data in val set so that train set and val set do not have overlap '''
264 | train = train_set.gt_idx
265 | val = val_set.gt_idx
266 |
267 | # find redundant data index in val_set
268 | index = np.where(np.in1d(val, train) == True) # this is a tuple
269 | # delete redundant data
270 | val_set.gt_idx = np.delete(val_set.gt_idx, index)
271 | val_set.poses = np.delete(val_set.poses, index, axis=0)
272 | for i in sorted(index[0], reverse=True):
273 | val_set.c_imgs.pop(i)
274 | val_set.d_imgs.pop(i)
275 | return train_set, val_set
276 |
277 | def fix_coord(args, train_set, val_set, pose_avg_stats_file='', rescale_coord=True):
278 | ''' fix coord for Cambridge scenes to align with llff style dataset '''
279 |
280 | # This is only to store a pre-calculated pose average stats of the dataset
281 | if args.save_pose_avg_stats:
282 | pdb.set_trace()
283 | if pose_avg_stats_file == '':
284 | print('pose_avg_stats_file location unspecified, please double check...')
285 | sys.exit()
286 |
287 | all_poses = train_set.poses
288 | all_poses = all_poses.reshape(all_poses.shape[0], 3, 4)
289 | all_poses, pose_avg = center_poses(all_poses)
290 |
291 | # save pose_avg to pose_avg_stats.txt
292 | np.savetxt(pose_avg_stats_file, pose_avg)
293 | print('pose_avg_stats.txt successfully saved')
294 | sys.exit()
295 |
296 | # get all poses (train+val)
297 | train_poses = train_set.poses
298 |
299 | val_poses = val_set.poses
300 | all_poses = np.concatenate([train_poses, val_poses])
301 |
302 | # Center the poses for ndc
303 | all_poses = all_poses.reshape(all_poses.shape[0], 3, 4)
304 |
305 | # Here we either use pre-stored pose average stats or calculate them on the fly to center the poses
306 | if args.load_pose_avg_stats:
307 | pose_avg_from_file = np.loadtxt(pose_avg_stats_file)
308 | all_poses, pose_avg = center_poses(all_poses, pose_avg_from_file)
309 | else:
310 | all_poses, pose_avg = center_poses(all_poses)
311 |
312 | # Correct axis to LLFF Style y,z -> -y,-z
313 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(all_poses), 1, 1)) # (N_images, 1, 4)
314 | all_poses = np.concatenate([all_poses, last_row], 1)
315 |
316 | # rotate poses 180 degrees about the x axis # only corrects the translation position
317 | all_poses = rot_phi(180/180.*np.pi) @ all_poses
318 |
319 | # correct view direction except mirror with gt view
320 | all_poses[:,:3,:3] = -all_poses[:,:3,:3]
321 |
322 | # camera direction mirror at x axis mod1 R' = R @ mirror matrix
323 | # ref: https://gamedev.stackexchange.com/questions/149062/how-to-mirror-reflect-flip-a-4d-transformation-matrix
324 | all_poses[:,:3,:3] = all_poses[:,:3,:3] @ np.array([[-1,0,0],[0,1,0],[0,0,1]])
325 |
326 | all_poses = all_poses[:,:3,:4]
327 |
328 | bounds = np.array([train_set.near, train_set.far]) # manual tuned
329 |
330 | if rescale_coord:
331 | sc=train_set.pose_scale # manual tuned factor, align with colmap scale
332 | all_poses[:,:3,3] *= sc
333 |
334 | ### quite ugly ###
335 | # move center of camera pose
336 | if train_set.move_all_cam_vec != [0.,0.,0.]:
337 | all_poses[:, :3, 3] += train_set.move_all_cam_vec
338 |
339 | if train_set.pose_scale2 != 1.0:
340 | all_poses[:,:3,3] *= train_set.pose_scale2
341 | # end of new mod1
342 |
343 | # Return all poses to dataset loaders
344 | all_poses = all_poses.reshape(all_poses.shape[0], 12)
345 | train_set.poses = all_poses[:train_poses.shape[0]]
346 | val_set.poses = all_poses[train_poses.shape[0]:]
347 | return train_set, val_set, bounds
348 |
349 | def load_Cambridge_dataloader(args):
350 | ''' Data loader for Pose Regression Network '''
351 | if args.pose_only: # if train posenet is true
352 | pass
353 | else:
354 | raise Exception('load_Cambridge_dataloader() currently only supports PoseNet training, not NeRF training')
355 | data_dir, scene = osp.split(args.datadir) # ../data/Cambridge, ShopFacade
356 | dataset_folder, dataset = osp.split(data_dir) # ../data, Cambridge
357 |
358 | # transformer
359 | data_transform = transforms.Compose([
360 | transforms.ToTensor(),
361 | ])
362 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x))
363 |
364 | ret_idx = False # return frame index
365 | fix_idx = False # return frame index=0 in training
366 | ret_hist = False
367 |
368 | if 'NeRFH' in args:
369 | if args.NeRFH == True:
370 | ret_idx = True
371 | if args.fix_index:
372 | fix_idx = True
373 |
374 | # encode hist experiment
375 | if args.encode_hist:
376 | ret_idx = False
377 | fix_idx = False
378 | ret_hist = True
379 |
380 | kwargs = dict(scene=scene, data_path=data_dir,
381 | transform=data_transform, target_transform=target_transform,
382 | df=args.df, ret_idx=ret_idx, fix_idx=fix_idx,
383 | ret_hist=ret_hist, hist_bin=args.hist_bin)
384 |
385 | if args.finetune_unlabel: # direct-pn + unlabel
386 | train_set = Cambridge2(train=False, testskip=args.trainskip, **kwargs)
387 | val_set = Cambridge2(train=False, testskip=args.testskip, **kwargs)
388 |
389 | # if not args.eval:
390 | # # remove overlap data in val_set that was already in train_set,
391 | # train_set, val_set = remove_overlap_data(train_set, val_set)
392 | else:
393 | train_set = Cambridge2(train=True, trainskip=args.trainskip, **kwargs)
394 | val_set = Cambridge2(train=False, testskip=args.testskip, **kwargs)
395 | L = len(train_set)
396 |
397 | i_train = train_set.gt_idx
398 | i_val = val_set.gt_idx
399 | i_test = val_set.gt_idx
400 | # use a pose average stats computed earlier to unify posenet and nerf training
401 | if args.save_pose_avg_stats or args.load_pose_avg_stats:
402 | pose_avg_stats_file = osp.join(args.datadir, 'pose_avg_stats.txt')
403 | train_set, val_set, bounds = fix_coord(args, train_set, val_set, pose_avg_stats_file, rescale_coord=False) # only adjust coord. systems, rescale are done at training
404 | else:
405 | train_set, val_set, bounds = fix_coord(args, train_set, val_set, rescale_coord=False)
406 |
407 | train_shuffle=True
408 | if args.eval:
409 | train_shuffle=False
410 |
411 | train_dl = DataLoader(train_set, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=8) #num_workers=4 pin_memory=True
412 | val_dl = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, num_workers=2)
413 | test_dl = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=2)
414 |
415 | hwf = [train_set.H, train_set.W, train_set.focal]
416 | i_split = [i_train, i_val, i_test]
417 |
418 | return train_dl, val_dl, test_dl, hwf, i_split, bounds.min(), bounds.max()
419 |
420 | def load_Cambridge_dataloader_NeRF(args):
421 | ''' Data loader for NeRF '''
422 |
423 | data_dir, scene = osp.split(args.datadir) # ../data/Cambridge, ShopFacade
424 | dataset_folder, dataset = osp.split(data_dir) # ../data, Cambridge
425 |
426 | data_transform = transforms.Compose([
427 | transforms.ToTensor()])
428 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x))
429 |
430 | ret_idx = False # return frame index
431 | fix_idx = False # return frame index=0 in training
432 | ret_hist = False
433 |
434 | if 'NeRFH' in args:
435 | ret_idx = True
436 | if args.fix_index:
437 | fix_idx = True
438 |
439 | # encode hist experiment
440 | if args.encode_hist:
441 | ret_idx = False
442 | fix_idx = False
443 | ret_hist = True
444 |
445 | kwargs = dict(scene=scene, data_path=data_dir,
446 | transform=data_transform, target_transform=target_transform,
447 | df=args.df, ret_idx=ret_idx, fix_idx=fix_idx, ret_hist=ret_hist, hist_bin=args.hist_bin)
448 |
449 | train_set = Cambridge2(train=True, trainskip=args.trainskip, **kwargs)
450 | val_set = Cambridge2(train=False, testskip=args.testskip, **kwargs)
451 |
452 | i_train = train_set.gt_idx
453 | i_val = val_set.gt_idx
454 | i_test = val_set.gt_idx
455 |
456 | # use a pose average stats computed earlier to unify posenet and nerf training
457 | if args.save_pose_avg_stats or args.load_pose_avg_stats:
458 | pose_avg_stats_file = osp.join(args.datadir, 'pose_avg_stats.txt')
459 | train_set, val_set, bounds = fix_coord(args, train_set, val_set, pose_avg_stats_file)
460 | else:
461 | train_set, val_set, bounds = fix_coord(args, train_set, val_set)
462 |
463 | render_poses = None
464 | render_img = None
465 |
466 | train_shuffle=True
467 | if args.render_video_train or args.render_test or args.dataset_type == 'Cambridge2':
468 | train_shuffle=False
469 | train_dl = DataLoader(train_set, batch_size=1, shuffle=train_shuffle) # default
470 | # train_dl = DataLoader(train_set, batch_size=1, shuffle=False) # for debug only
471 | val_dl = DataLoader(val_set, batch_size=1, shuffle=False)
472 |
473 | hwf = [train_set.H, train_set.W, train_set.focal]
474 |
475 | i_split = [i_train, i_val, i_test]
476 |
477 | return train_dl, val_dl, hwf, i_split, bounds, render_poses, render_img
--------------------------------------------------------------------------------
/dataset_loaders/seven_scenes.py:
--------------------------------------------------------------------------------
1 | """
2 | pytorch data loader for the 7-scenes dataset
3 | """
4 | import os
5 | import os.path as osp
6 | import numpy as np
7 | from PIL import Image
8 | import torch
9 | from torch.utils import data
10 | import sys
11 | import pickle
12 | import pdb,copy
13 | import cv2
14 |
15 | sys.path.insert(0, '../')
16 | import transforms3d.quaternions as txq
17 | # see for formulas:
18 | # https://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-801-machine-vision-fall-2004/readings/quaternions.pdf
19 | # and "Quaternion and Rotation" - Yan-Bin Jia, September 18, 2016
20 | from dataset_loaders.utils.color import rgb_to_yuv
21 | import json
22 |
23 | def RT2QT(poses_in, mean_t, std_t):
24 | """
25 | processes the 1x12 raw pose from dataset by aligning and then normalizing
26 | :param poses_in: N x 12
27 | :param mean_t: 3
28 | :param std_t: 3
29 | :return: processed poses (translation + quaternion) N x 7
30 | """
31 | poses_out = np.zeros((len(poses_in), 7))
32 | poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]
33 |
34 | # align
35 | for i in range(len(poses_out)):
36 | R = poses_in[i].reshape((3, 4))[:3, :3]
37 | q = txq.mat2quat(R)
38 | q = q/(np.linalg.norm(q) + 1e-12) # normalize
39 | q *= np.sign(q[0]) # constrain to hemisphere
40 | poses_out[i, 3:] = q
41 |
42 | # normalize translation
43 | poses_out[:, :3] -= mean_t
44 | poses_out[:, :3] /= std_t
45 | return poses_out
46 |
47 | def qlog(q):
48 | """
49 | Applies logarithm map to q
50 | :param q: (4,)
51 | :return: (3,)
52 | """
53 | if all(q[1:] == 0):
54 | q = np.zeros(3)
55 | else:
56 | q = np.arccos(q[0]) * q[1:] / np.linalg.norm(q[1:])
57 | return q
58 |
59 | import transforms3d.quaternions as txq # Warning: outdated package
60 |
61 | def process_poses_rotmat(poses_in, mean_t, std_t, align_R, align_t, align_s):
62 | """
63 | identity mapping: returns the raw rotation-matrix poses unchanged
64 | (the alignment and normalization arguments are ignored)
65 | :param poses_in: N x 12
66 | :return: processed poses N x 12
67 | """
68 | return poses_in
69 |
70 | def process_poses_q(poses_in, mean_t, std_t, align_R, align_t, align_s):
71 | """
72 | processes the 1x12 raw pose from dataset by aligning and then normalizing
73 | produce quaternion
74 | :param poses_in: N x 12
75 | :param mean_t: 3
76 | :param std_t: 3
77 | :param align_R: 3 x 3
78 | :param align_t: 3
79 | :param align_s: 1
80 | :return: processed poses (translation + quaternion) N x 7
81 | """
82 | poses_out = np.zeros((len(poses_in), 7)) # (N,7): x,y,z + 4-d quaternion
83 | poses_out[:, 0:3] = poses_in[:, [3, 7, 11]] # x,y,z position
84 | # align
85 | for i in range(len(poses_out)):
86 | R = poses_in[i].reshape((3, 4))[:3, :3] # rotation
87 | q = txq.mat2quat(np.dot(align_R, R))
88 | q *= np.sign(q[0]) # constrain to hemisphere, first number, +1/-1, q.shape (1,4)
89 | poses_out[i, 3:] = q # quaternion rotation
90 | t = poses_out[i, :3] - align_t
91 | poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()
92 |
93 | # normalize translation
94 | poses_out[:, :3] -= mean_t #(1000, 6)
95 | poses_out[:, :3] /= std_t
96 | return poses_out
97 |
98 | def process_poses_logq(poses_in, mean_t, std_t, align_R, align_t, align_s):
99 | """
100 | processes the 1x12 raw pose from dataset by aligning and then normalizing
101 | produce logq
102 | :param poses_in: N x 12
103 | :param mean_t: 3
104 | :param std_t: 3
105 | :param align_R: 3 x 3
106 | :param align_t: 3
107 | :param align_s: 1
108 | :return: processed poses (translation + log quaternion) N x 6
109 | """
110 | poses_out = np.zeros((len(poses_in), 6)) # (1000,6)
111 | poses_out[:, 0:3] = poses_in[:, [3, 7, 11]] # x,y,z position
112 | # align
113 | for i in range(len(poses_out)):
114 | R = poses_in[i].reshape((3, 4))[:3, :3] # rotation
115 | q = txq.mat2quat(np.dot(align_R, R))
116 | q *= np.sign(q[0]) # constrain to hemisphere, first number, +1/-1, q.shape (1,4)
117 | q = qlog(q) # (1,3)
118 | poses_out[i, 3:] = q # logq rotation
119 | t = poses_out[i, :3] - align_t
120 | poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()
121 |
122 | # normalize translation
123 | poses_out[:, :3] -= mean_t #(1000, 6)
124 | poses_out[:, :3] /= std_t
125 | return poses_out
126 |
127 | from torchvision.datasets.folder import default_loader
128 | def load_image(filename, loader=default_loader):
129 | try:
130 | img = loader(filename)
131 | except IOError as e:
132 |         print('Could not load image {:s}, IOError: {:s}'.format(filename, str(e)))
133 | return None
134 | except:
135 | print('Could not load image {:s}, unexpected error'.format(filename))
136 | return None
137 | return img
138 |
139 | def load_depth_image(filename):
140 | try:
141 | img_depth = Image.fromarray(np.array(Image.open(filename)).astype("uint16"))
142 | except IOError as e:
143 |         print('Could not load image {:s}, IOError: {:s}'.format(filename, str(e)))
144 | return None
145 | return img_depth
146 |
147 | def normalize(x):
148 | return x / np.linalg.norm(x)
149 |
150 | def viewmatrix(z, up, pos):
151 | vec2 = normalize(z)
152 | vec1_avg = up
153 | vec0 = normalize(np.cross(vec1_avg, vec2))
154 | vec1 = normalize(np.cross(vec2, vec0))
155 | m = np.stack([vec0, vec1, vec2, pos], 1)
156 | return m
157 |
158 | def normalize_recenter_pose(poses, sc, hwf):
159 | ''' normalize xyz into [-1, 1], and recenter pose '''
160 | target_pose = poses.reshape(poses.shape[0],3,4)
161 | target_pose[:,:3,3] = target_pose[:,:3,3] * sc
162 |
163 | x_norm = target_pose[:,0,3]
164 | y_norm = target_pose[:,1,3]
165 | z_norm = target_pose[:,2,3]
166 |
167 | tpose_ = target_pose+0
168 |
169 | # find the center of pose
170 | center = np.array([x_norm.mean(), y_norm.mean(), z_norm.mean()])
171 | bottom = np.reshape([0,0,0,1.], [1,4])
172 |
173 | # pose avg
174 | vec2 = normalize(tpose_[:, :3, 2].sum(0))
175 | up = tpose_[:, :3, 1].sum(0)
176 | hwf=np.array(hwf).transpose()
177 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
178 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
179 |
180 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [tpose_.shape[0],1,1])
181 | poses = np.concatenate([tpose_[:,:3,:4], bottom], -2)
182 | poses = np.linalg.inv(c2w) @ poses
183 | return poses[:,:3,:].reshape(poses.shape[0],12)
184 |
185 | class SevenScenes(data.Dataset):
186 | def __init__(self, scene, data_path, train, transform=None,
187 | target_transform=None, mode=0, seed=7,
188 | df=1., trainskip=1, testskip=1, hwf=[480,640,585.],
189 | ret_idx=False, fix_idx=False, ret_hist=False, hist_bin=10):
190 | """
191 | :param scene: scene name ['chess', 'pumpkin', ...]
192 | :param data_path: root 7scenes data directory.
193 | Usually '../data/deepslam_data/7Scenes'
194 | :param train: if True, return the training images. If False, returns the
195 | testing images
196 | :param transform: transform to apply to the images
197 | :param target_transform: transform to apply to the poses
198 | :param mode: (Obsolete) 0: just color image, 1: color image in NeRF 0-1 and resized.
199 | :param df: downscale factor
200 |         :param trainskip: 7Scenes sequences are large, so use only every trainskip-th training frame (# of train frames = N/trainskip)
201 |         :param testskip: use only every testskip-th test frame (# of test frames = N/testskip)
202 | :param hwf: H,W,Focal from COLMAP
203 | :param ret_idx: bool, currently only used by NeRF-W
204 | """
205 |
206 | self.transform = transform
207 | self.target_transform = target_transform
208 | self.df = df
209 |
210 | self.H, self.W, self.focal = hwf
211 | self.H = int(self.H)
212 | self.W = int(self.W)
213 | np.random.seed(seed)
214 |
215 | self.train = train
216 | self.ret_idx = ret_idx
217 | self.fix_idx = fix_idx
218 | self.ret_hist = ret_hist
219 | self.hist_bin = hist_bin # histogram bin size
220 |
221 | # directories
222 |         base_dir = osp.join(osp.expanduser(data_path), scene) # e.g. '../data/deepslam_data/7Scenes/chess'
223 | data_dir = osp.join('..', 'data', '7Scenes', scene) # '../data/7Scenes/chess'
224 | world_setup_fn = data_dir + '/world_setup.json'
225 |
226 | # read json file
227 | with open(world_setup_fn, 'r') as myfile:
228 | data=myfile.read()
229 |
230 | # parse json file
231 | obj = json.loads(data)
232 | self.near = obj['near']
233 | self.far = obj['far']
234 | self.pose_scale = obj['pose_scale']
235 | self.pose_scale2 = obj['pose_scale2']
236 | self.move_all_cam_vec = obj['move_all_cam_vec']
237 |
238 | # decide which sequences to use
239 | if train:
240 | split_file = osp.join(base_dir, 'TrainSplit.txt')
241 | else:
242 | split_file = osp.join(base_dir, 'TestSplit.txt')
243 | with open(split_file, 'r') as f:
244 | seqs = [int(l.split('sequence')[-1]) for l in f if not l.startswith('#')] # parsing
245 |
246 | # read poses and collect image names
247 | self.c_imgs = []
248 | self.d_imgs = []
249 |         self.gt_idx = np.empty((0,), dtype=int)
250 | ps = {}
251 | vo_stats = {}
252 | gt_offset = int(0)
253 | for seq in seqs:
254 | seq_dir = osp.join(base_dir, 'seq-{:02d}'.format(seq))
255 | seq_data_dir = osp.join(data_dir, 'seq-{:02d}'.format(seq))
256 |
257 | p_filenames = [n for n in os.listdir(osp.join(seq_dir, '.')) if n.find('pose') >= 0]
258 | idxes = [int(n[6:12]) for n in p_filenames]
259 |
260 | frame_idx = np.array(sorted(idxes))
261 |
262 |
263 | # trainskip and testskip
264 | if train and trainskip > 1:
265 | frame_idx_tmp = frame_idx[::trainskip]
266 | frame_idx = frame_idx_tmp
267 | elif not train and testskip > 1:
268 | frame_idx_tmp = frame_idx[::testskip]
269 | frame_idx = frame_idx_tmp
270 |
271 | pss = [np.loadtxt(osp.join(seq_dir, 'frame-{:06d}.pose.txt'.
272 | format(i))).flatten()[:12] for i in frame_idx] # all the 3x4 pose matrices
273 | ps[seq] = np.asarray(pss) # list of all poses in file No. seq
274 | vo_stats[seq] = {'R': np.eye(3), 't': np.zeros(3), 's': 1}
275 |
276 | self.gt_idx = np.hstack((self.gt_idx, gt_offset+frame_idx))
277 | gt_offset += len(p_filenames)
278 | c_imgs = [osp.join(seq_dir, 'frame-{:06d}.color.png'.format(i)) for i in frame_idx]
279 | d_imgs = [osp.join(seq_dir, 'frame-{:06d}.depth.png'.format(i)) for i in frame_idx]
280 | self.c_imgs.extend(c_imgs)
281 | self.d_imgs.extend(d_imgs)
282 |
283 | pose_stats_filename = osp.join(data_dir, 'pose_stats.txt')
284 | if train:
285 | mean_t = np.zeros(3) # optionally, use the ps dictionary to calc stats
286 | std_t = np.ones(3)
287 | np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f')
288 | else:
289 | mean_t, std_t = np.loadtxt(pose_stats_filename)
290 |
291 | # convert pose to translation + log quaternion
292 | logq = False
293 | quat = False
294 | if logq: # (batch_num, 6)
295 | self.poses = np.empty((0, 6))
296 | elif quat: # (batch_num, 7)
297 | self.poses = np.empty((0, 7))
298 | else: # (batch_num, 12)
299 | self.poses = np.empty((0, 12))
300 |
301 | for seq in seqs:
302 | if logq:
303 | pss = process_poses_logq(poses_in=ps[seq], mean_t=mean_t, std_t=std_t, align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'], align_s=vo_stats[seq]['s']) # here returns t + logQed R
304 | self.poses = np.vstack((self.poses, pss))
305 | elif quat:
306 | pss = RT2QT(poses_in=ps[seq], mean_t=mean_t, std_t=std_t) # here returns t + quaternion R
307 | self.poses = np.vstack((self.poses, pss))
308 | else:
309 | pss = process_poses_rotmat(poses_in=ps[seq], mean_t=mean_t, std_t=std_t, align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'], align_s=vo_stats[seq]['s'])
310 | self.poses = np.vstack((self.poses, pss))
311 |
312 | # debug read one img and get the shape of the img
313 | img = load_image(self.c_imgs[0])
314 | img_np = (np.array(img) / 255.).astype(np.float32) # (480,640,3)
315 | self.H, self.W = img_np.shape[:2]
316 | if self.df != 1.:
317 | self.H = int(self.H//self.df)
318 | self.W = int(self.W//self.df)
319 | self.focal = self.focal/self.df
320 |
321 | def __len__(self):
322 | return self.poses.shape[0]
323 |
324 | def __getitem__(self, index):
325 | # print("index:", index)
326 | img = load_image(self.c_imgs[index]) # chess img.size = (640,480)
327 | pose = self.poses[index]
328 | if self.df != 1.:
329 | img_np = (np.array(img) / 255.).astype(np.float32)
330 | dims = (self.W, self.H)
331 | img_half_res = cv2.resize(img_np, dims, interpolation=cv2.INTER_AREA) # (H, W, 3)
332 | img = img_half_res
333 |
334 | if self.target_transform is not None:
335 | pose = self.target_transform(pose)
336 |
337 | if self.transform is not None:
338 | img = self.transform(img)
339 |
340 | if self.ret_idx:
341 | if self.train and self.fix_idx==False:
342 | return img, pose, index
343 | else:
344 | return img, pose, 0
345 |
346 | if self.ret_hist:
347 | yuv = rgb_to_yuv(img)
348 | y_img = yuv[0] # extract y channel only
349 | hist = torch.histc(y_img, bins=self.hist_bin, min=0., max=1.) # compute intensity histogram
350 | hist = hist/(hist.sum())*100 # convert to histogram density, in terms of percentage per bin
351 | hist = torch.round(hist)
352 | return img, pose, hist
353 |
354 | return img, pose
355 |
356 |
357 | def main():
358 | """
359 | visualizes the dataset
360 | """
361 | # from common.vis_utils import show_batch, show_stereo_batch
362 | from torchvision.utils import make_grid
363 | import torchvision.transforms as transforms
364 | seq = 'heads'
365 | mode = 1
366 | num_workers = 6
367 | transform = transforms.Compose([
368 |         transforms.Resize(256),
369 | transforms.CenterCrop(224),
370 | transforms.ToTensor(),
371 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
372 | ])
373 | target_transform = transforms.Lambda(lambda x: torch.Tensor(x))
374 | dset = SevenScenes(seq, '../data/deepslam_data/7Scenes', True, transform, target_transform=target_transform, mode=mode)
375 | print('Loaded 7Scenes sequence {:s}, length = {:d}'.format(seq, len(dset)))
376 | pdb.set_trace()
377 |
378 | data_loader = data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=num_workers)
379 |
380 | batch_count = 0
381 | N = 2
382 | for batch in data_loader:
383 | print('Minibatch {:d}'.format(batch_count))
384 | pdb.set_trace()
385 | # if mode < 2:
386 | # show_batch(make_grid(batch[0], nrow=1, padding=25, normalize=True))
387 | # elif mode == 2:
388 | # lb = make_grid(batch[0][0], nrow=1, padding=25, normalize=True)
389 | # rb = make_grid(batch[0][1], nrow=1, padding=25, normalize=True)
390 | # show_stereo_batch(lb, rb)
391 |
392 | batch_count += 1
393 | if batch_count >= N:
394 | break
395 |
396 | if __name__ == '__main__':
397 | main()
398 |
--------------------------------------------------------------------------------
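
The pose helpers above are easiest to follow on a synthetic example. A minimal sketch, assuming the repository root is on `PYTHONPATH` and no alignment or normalisation is applied:

```python
import numpy as np
from dataset_loaders.seven_scenes import RT2QT, process_poses_logq

# identity rotation with translation (1, 2, 3), flattened to the 1x12 layout used above
pose_3x4 = np.hstack([np.eye(3), np.array([[1.0], [2.0], [3.0]])])
poses_in = pose_3x4.reshape(1, 12)

# translation + unit quaternion, shape (1, 7)
tq = RT2QT(poses_in, mean_t=np.zeros(3), std_t=np.ones(3))
# translation + log quaternion, shape (1, 6); qlog of the identity quaternion is (0, 0, 0)
tlogq = process_poses_logq(poses_in, np.zeros(3), np.ones(3),
                           align_R=np.eye(3), align_t=np.zeros(3), align_s=1)
print(tq)     # [[1. 2. 3. 1. 0. 0. 0.]]
print(tlogq)  # [[1. 2. 3. 0. 0. 0.]]
```
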
/dataset_loaders/utils/color.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def rgb_to_yuv(image: torch.Tensor) -> torch.Tensor:
5 | r"""
6 | From Kornia.
7 | Convert an RGB image to YUV.
8 |
9 | .. image:: _static/img/rgb_to_yuv.png
10 |
11 | The image data is assumed to be in the range of (0, 1).
12 |
13 | Args:
14 | image: RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`.
15 |
16 | Returns:
17 | YUV version of the image with shape :math:`(*, 3, H, W)`.
18 |
19 | Example:
20 | >>> input = torch.rand(2, 3, 4, 5)
21 | >>> output = rgb_to_yuv(input) # 2x3x4x5
22 | """
23 | if not isinstance(image, torch.Tensor):
24 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
25 |
26 | if len(image.shape) < 3 or image.shape[-3] != 3:
27 | raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
28 |
29 | r: torch.Tensor = image[..., 0, :, :]
30 | g: torch.Tensor = image[..., 1, :, :]
31 | b: torch.Tensor = image[..., 2, :, :]
32 |
33 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b
34 | u: torch.Tensor = -0.147 * r - 0.289 * g + 0.436 * b
35 | v: torch.Tensor = 0.615 * r - 0.515 * g - 0.100 * b
36 |
37 | out: torch.Tensor = torch.stack([y, u, v], -3)
38 |
39 | return out
40 |
--------------------------------------------------------------------------------
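
A minimal sketch of how `rgb_to_yuv` feeds the histogram branch of `SevenScenes.__getitem__` above (the image here is random and its size arbitrary):

```python
import torch
from dataset_loaders.utils.color import rgb_to_yuv

img = torch.rand(3, 240, 427)                   # RGB image in [0, 1], C x H x W
y = rgb_to_yuv(img)[0]                          # luminance (Y) channel
hist = torch.histc(y, bins=10, min=0., max=1.)  # 10-bin intensity histogram
hist = torch.round(hist / hist.sum() * 100)     # percentage per bin, as in the 7Scenes loader
print(hist)                                     # ten values summing to roughly 100
```
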
/imgs/DFNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/imgs/DFNet.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Package Version
2 | ----------------------- -------------------
3 | absl-py 0.14.0
4 | antlr4-python3-runtime 4.8
5 | backcall 0.2.0
6 | beautifulsoup4 4.9.3
7 | brotlipy 0.7.0
8 | cached-property 1.5.2
9 | cachetools 4.2.4
10 | certifi 2021.5.30
11 | cffi 1.14.5
12 | chardet 3.0.4
13 | conda 4.10.1
14 | conda-build 3.21.4
15 | conda-package-handling 1.7.3
16 | ConfigArgParse 1.5.2
17 | cryptography 3.4.7
18 | cycler 0.10.0
19 | decorator 5.0.9
20 | dnspython 2.1.0
21 | efficientnet-pytorch 0.7.1
22 | einops 0.3.2
23 | filelock 3.0.12
24 | fvcore 0.1.5.post20210924
25 | glob2 0.7
26 | google-auth 1.35.0
27 | google-auth-oauthlib 0.4.6
28 | grpcio 1.41.0
29 | h5py 3.5.0
30 | idna 2.10
31 | imageio 2.9.0
32 | imageio-ffmpeg 0.4.5
33 | importlib-metadata 4.8.1
34 | iopath 0.1.9
35 | ipython 7.22.0
36 | ipython-genutils 0.2.0
37 | jedi 0.17.0
38 | Jinja2 3.0.0
39 | kiwisolver 1.3.2
40 | kornia 0.6.2
41 | libarchive-c 2.9
42 | Markdown 3.3.4
43 | MarkupSafe 2.0.1
44 | matplotlib 3.3.2
45 | mkl-fft 1.3.0
46 | mkl-random 1.2.1
47 | mkl-service 2.3.0
48 | numpy 1.20.2
49 | oauthlib 3.1.1
50 | olefile 0.46
51 | omegaconf 2.1.1
52 | opencv-python 4.4.0.46
53 | packaging 21.3
54 | parso 0.8.2
55 | pexpect 4.8.0
56 | pickleshare 0.7.5
57 | Pillow 8.2.0
58 | pip 21.1.2
59 | pkginfo 1.7.0
60 | portalocker 2.3.2
61 | prompt-toolkit 3.0.17
62 | protobuf 3.18.0
63 | psutil 5.8.0
64 | ptyprocess 0.7.0
65 | pyasn1 0.4.8
66 | pyasn1-modules 0.2.8
67 | pycosat 0.6.3
68 | pycparser 2.20
69 | Pygments 2.9.0
70 | pykalman 0.9.5
71 | pyOpenSSL 19.1.0
72 | pyparsing 2.4.7
73 | PySocks 1.7.1
74 | python-dateutil 2.8.2
75 | python-etcd 0.4.5
76 | pytorch3d 0.3.0
77 | pytz 2021.1
78 | PyYAML 5.4.1
79 | requests 2.24.0
80 | requests-oauthlib 1.3.0
81 | rsa 4.7.2
82 | ruamel-yaml-conda 0.15.100
83 | scipy 1.7.3
84 | setuptools 52.0.0.post20210125
85 | six 1.16.0
86 | soupsieve 2.2.1
87 | tabulate 0.8.9
88 | tensorboard 2.6.0
89 | tensorboard-data-server 0.6.1
90 | tensorboard-plugin-wit 1.8.0
91 | termcolor 1.1.0
92 | torch 1.11.0+cu113
93 | torchelastic 0.2.0
94 | torchsummary 1.5.1
95 | torchtext 0.10.0
96 | torchvision 0.10.0
97 | tqdm 4.51.0
98 | traitlets 5.0.5
99 | transforms3d 0.3.1
100 | trimesh 3.9.32
101 | typing-extensions 3.7.4.3
102 | urllib3 1.25.11
103 | wcwidth 0.2.5
104 | Werkzeug 2.0.1
105 | wheel 0.35.1
106 | yacs 0.1.8
107 | zipp 3.6.0
108 |
--------------------------------------------------------------------------------
/script/config_dfnet.txt:
--------------------------------------------------------------------------------
1 | ############################################### NeRF-Hist training example Cambridge ###############################################
2 | model_name=dfnet
3 | basedir=../logs/kings
4 | expname=nerfh
5 | datadir=../data/Cambridge/KingsCollege
6 | dataset_type=Cambridge
7 | trainskip=2 # train
8 | testskip=1 # train
9 | df=2
10 | load_pose_avg_stats=True
11 | NeRFH=True
12 | epochs=2000
13 | encode_hist=True
14 | tinyimg=True
15 | DFNet=True
16 | tripletloss=True
17 | featurenet_batch_size=4 # batch size, 4 or 8
18 | random_view_synthesis=True
19 | rvs_refresh_rate=20
20 | rvs_trans=3
21 | rvs_rotation=7.5
22 | d_max=1
23 | # pretrain_model_path = ../logs/kings/dfnet/checkpoint-0604-0.2688.pt # add your trained model for eval
24 | # eval=True # add this for eval
25 |
26 | ############################################### NeRF-Hist training example 7-Scenes ###############################################
27 | # model_name=dfnet
28 | # basedir=../logs/heads
29 | # expname=nerfh
30 | # datadir=../data/7Scenes/heads
31 | # dataset_type=7Scenes
32 | # trainskip=5 # train
33 | # testskip=1 #train
34 | # df=2 # train
35 | # load_pose_avg_stats=True
36 | # NeRFH=True
37 | # epochs=2000
38 | # encode_hist=True
39 | # batch_size=1 # NeRF loader batch size
40 | # tinyimg=True
41 | # DFNet=True
42 | # tripletloss=True
43 | # featurenet_batch_size=4
44 | # val_batch_size=8 # new
45 | # random_view_synthesis=True
46 | # rvs_refresh_rate=20
47 | # rvs_trans=0.2
48 | # rvs_rotation=10
49 | # d_max=0.2
50 | # # pretrain_model_path = ../logs/heads/dfnet/checkpoint-0888-0.0025.pt # add your trained model for eval
51 | # # eval=True # add this for eval
--------------------------------------------------------------------------------
/script/config_dfnetdm.txt:
--------------------------------------------------------------------------------
1 | ############################################### NeRF-Hist training example Cambridge ###############################################
2 | model_name=dfnetdm
3 | expname=nerfh
4 | basedir=../logs/kings # change this if change scenes
5 | datadir=../data/Cambridge/KingsCollege # change this if change scenes
6 | dataset_type=Cambridge
7 | pretrain_model_path=../logs/kings/dfnet/checkpoint-0604-0.2688.pt # this is your trained dfnet model for pose regression
8 | pretrain_featurenet_path=../logs/kings/dfnet/checkpoint-0604-0.2688.pt # this is your trained dfnet model for feature extraction
9 | trainskip=2 # train
10 | testskip=1 # train
11 | df=2
12 | load_pose_avg_stats=True
13 | NeRFH=True
14 | encode_hist=True
15 | freezeBN=True
16 | featuremetric=True
17 | pose_only=3
18 | svd_reg=True
19 | combine_loss = True
20 | combine_loss_w = [0., 0., 1.]
21 | finetune_unlabel=True
22 | i_eval=20
23 | DFNet=True
24 | val_on_psnr=True
25 | feature_matching_lvl = [0]
26 | # eval=True # add this for eval
27 | # pretrain_model_path=../logs/kings/dfnetdm/checkpoint-0267-17.1446.pt # add the trained model for eval
28 |
29 |
30 | ############################################### NeRF-Hist training example 7-Scenes ###############################################
31 | # model_name=dfnetdm
32 | # expname=nerfh
33 | # basedir=../logs/heads
34 | # datadir=../data/7Scenes/heads
35 | # dataset_type=7Scenes
36 | # pretrain_model_path=../logs/heads/dfnet/checkpoint-0888-0.0025.pt # this is your trained dfnet model for pose regression
37 | # pretrain_featurenet_path=../logs/heads/dfnet/checkpoint-0888-0.0025.pt # this is your trained dfnet model for feature extraction
38 | # trainskip=5 # train
39 | # testskip=1 #train
40 | # df=2
41 | # load_pose_avg_stats=True
42 | # NeRFH=True
43 | # encode_hist=True
44 | # freezeBN=True
45 | # featuremetric=True
46 | # pose_only=3
47 | # svd_reg=True
48 | # combine_loss = True
49 | # combine_loss_w = [0., 0., 1.]
50 | # finetune_unlabel=True
51 | # i_eval=20
52 | # DFNet=True
53 | # val_on_psnr=True
54 | # feature_matching_lvl = [0]
55 | # # eval=True # add this for eval
56 | # # pretrain_model_path=../logs/heads/dfnetdm/checkpoint-0317-17.5881.pt # add the trained model for eval
--------------------------------------------------------------------------------
/script/config_nerfh.txt:
--------------------------------------------------------------------------------
1 | ############################################### NeRF-Hist training example Cambridge ###############################################
2 | expname=nerfh
3 | basedir=../logs/kings
4 | datadir=../data/Cambridge/KingsCollege
5 | dataset_type=Cambridge
6 | lrate_decay=5
7 | trainskip=2
8 | testskip=1
9 | df=4
10 | load_pose_avg_stats=True
11 | NeRFH=True
12 | encode_hist=True
13 | # render_test=True # add this for eval
14 |
15 | ############################################### NeRF-Hist training example 7-Scenes ###############################################
16 | # expname=nerfh
17 | # basedir=../logs/heads
18 | # datadir=../data/7Scenes/heads
19 | # dataset_type=7Scenes
20 | # lrate_decay=0.754
21 | # trainskip=5
22 | # testskip=50
23 | # df=4
24 | # load_pose_avg_stats=True
25 | # NeRFH=True
26 | # encode_hist=True
27 | # # testskip=1 # add this for eval
28 | # # render_test=True # add this for eval
--------------------------------------------------------------------------------
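
The three `*.txt` files above are plain `key=value` configs consumed through configargparse's `--config` flag; the commented-out blocks are alternative scene setups and eval switches. A minimal parsing sketch, assuming it is run from the `script/` directory and using the parser defined in `script/dm/options.py` further below (the NeRF and feature entry points define near-identical parsers):

```python
from dm.options import config_parser

parser = config_parser()
# roughly equivalent to e.g. `python run_nerf.py --config config_nerfh.txt` on the command line
args = parser.parse_args(['--config', 'config_nerfh.txt'])
print(args.expname, args.datadir, args.dataset_type, args.NeRFH, args.df)
```
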
/script/dm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/script/dm/__init__.py
--------------------------------------------------------------------------------
/script/dm/callbacks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import os, pdb
5 |
6 | class Callback:
7 |     #TODO: Callback hooks, cf. https://dzlab.github.io/dl/2019/03/16/pytorch-training-loop/
8 | def __init__(self): pass
9 | def on_train_begin(self): pass
10 | def on_train_end(self): pass
11 | def on_epoch_begin(self): pass
12 | def on_epoch_end(self): pass
13 | def on_batch_begin(self): pass
14 | def on_batch_end(self): pass
15 | def on_loss_begin(self): pass
16 | def on_loss_end(self): pass
17 | def on_step_begin(self): pass
18 | def on_step_end(self): pass
19 |
20 | class EarlyStopping:
21 | """Early stops the training if validation loss doesn't improve after a given patience."""
22 | # source https://blog.csdn.net/qq_37430422/article/details/103638681
23 | def __init__(self, args, patience=50, verbose=False, delta=0):
24 | """
25 | Args:
26 | patience (int): How long to wait after last time validation loss improved.
27 | Default: 50
28 | verbose (bool): If True, prints a message for each validation loss improvement.
29 | Default: False
30 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
31 | Default: 0
32 | """
33 | self.val_on_psnr = args.val_on_psnr
34 | self.patience = patience
35 | self.verbose = verbose
36 | self.counter = 0
37 | self.best_score = None
38 | self.early_stop = False
39 | self.val_loss_min = np.Inf
40 | self.delta = delta
41 |
42 | self.basedir = args.basedir
43 | self.model_name = args.model_name
44 |
45 | self.out_folder = os.path.join(self.basedir, self.model_name)
46 | self.ckpt_save_path = os.path.join(self.out_folder, 'checkpoint.pt')
47 | if not os.path.isdir(self.out_folder):
48 | os.mkdir(self.out_folder)
49 |
50 | def __call__(self, val_loss, model, epoch=-1, save_multiple=False, save_all=False, val_psnr=None):
51 |
52 | # find maximum psnr
53 | if self.val_on_psnr:
54 | score = val_psnr
55 | if self.best_score is None:
56 | self.best_score = score
57 | self.save_checkpoint(val_psnr, model, epoch=epoch, save_multiple=save_multiple)
58 | elif score < self.best_score + self.delta:
59 | self.counter += 1
60 |
61 | if self.counter >= self.patience:
62 | self.early_stop = True
63 |
64 | if save_all: # save all ckpt
65 | self.save_checkpoint(val_psnr, model, epoch=epoch, save_multiple=True, update_best=False)
66 | else: # save best ckpt only
67 | self.best_score = score
68 | self.save_checkpoint(val_psnr, model, epoch=epoch, save_multiple=save_multiple)
69 | self.counter = 0
70 |
71 | # find minimum loss
72 | else:
73 | score = -val_loss
74 | if self.best_score is None:
75 | self.best_score = score
76 | self.save_checkpoint(val_loss, model, epoch=epoch, save_multiple=save_multiple)
77 | elif score < self.best_score + self.delta:
78 | self.counter += 1
79 |
80 | if self.counter >= self.patience:
81 | self.early_stop = True
82 |
83 | if save_all: # save all ckpt
84 | self.save_checkpoint(val_loss, model, epoch=epoch, save_multiple=True, update_best=False)
85 | else: # save best ckpt only
86 | self.best_score = score
87 | self.save_checkpoint(val_loss, model, epoch=epoch, save_multiple=save_multiple)
88 | self.counter = 0
89 |
90 | def save_checkpoint(self, val_loss, model, epoch=-1, save_multiple=False, update_best=True):
91 | '''Saves model when validation loss decrease.'''
92 | if self.verbose:
93 | tqdm.write(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
94 | ckpt_save_path = self.ckpt_save_path
95 | if save_multiple:
96 | ckpt_save_path = ckpt_save_path[:-3]+f'-{epoch:04d}-{val_loss:.4f}.pt'
97 |
98 | torch.save(model.state_dict(), ckpt_save_path)
99 | if update_best:
100 | self.val_loss_min = val_loss
101 |
102 | def isBestModel(self):
103 |         ''' Check whether the current model is the best one so far.
104 |         If the early-stop counter is 0, the latest validation loss is the best seen.
105 | '''
106 | return self.counter==0
--------------------------------------------------------------------------------
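
A minimal usage sketch for `EarlyStopping`, with a hypothetical `args` namespace standing in for the parsed options (it only reads `val_on_psnr`, `basedir`, and `model_name`, and `basedir` must already exist so the checkpoint folder can be created):

```python
from types import SimpleNamespace
import torch.nn as nn
from dm.callbacks import EarlyStopping  # assumes the script/ directory is on PYTHONPATH

args = SimpleNamespace(val_on_psnr=False, basedir='.', model_name='demo')
stopper = EarlyStopping(args, patience=3, verbose=True)
model = nn.Linear(4, 4)  # stand-in for the pose network

for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.95, 1.1]):
    stopper(val_loss, model, epoch=epoch)   # saves ./demo/checkpoint.pt on improvement
    if stopper.early_stop:
        print('early stop at epoch', epoch)
        break
```
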
/script/dm/options.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | def config_parser():
3 |     parser = configargparse.ArgumentParser()
4 |     parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
5 |     parser.add_argument("--multi_gpu", action='store_true', help='use multiple gpu on the server')
6 |
7 |     # 7Scenes
8 |     parser.add_argument("--trainskip", type=int, default=1, help='will load 1/N images from train sets, useful for large datasets like 7 Scenes')
9 |     parser.add_argument("--df", type=float, default=1., help='image downscale factor')
10 |     parser.add_argument("--reduce_embedding", type=int, default=-1, help='fourier embedding mode: -1: paper default, \
11 |                                         0: reduce by half, 1: remove embedding, 2: DNeRF embedding')
12 |     parser.add_argument("--epochToMaxFreq", type=int, default=-1, help='DNeRF embedding mode: (based on DNeRF paper): \
13 |                                         hyper-parameter for when α should reach the maximum number of frequencies m')
14 |     parser.add_argument("--render_pose_only", action='store_true', help='render a spiral video for 7 Scene')
15 |     parser.add_argument("--save_pose_avg_stats", action='store_true', help='save a pose avg stats to unify NeRF, posenet, nerf tracking training')
16 |     parser.add_argument("--load_pose_avg_stats", action='store_true', help='load precomputed pose avg stats to unify NeRF, posenet, nerf tracking training')
17 |     parser.add_argument("--finetune_unlabel", action='store_true', help='finetune unlabeled sequence like MapNet')
18 |     parser.add_argument("--i_eval", type=int, default=50, help='frequency of eval posenet result')
19 |     parser.add_argument("--save_all_ckpt", action='store_true', help='save all ckpts for each epoch')
20 |     parser.add_argument("--train_local_nerf", type=int, default=-1, help='train local NeRF with ith training sequence only, i.e. Stairs can pick 0~3')
21 |     parser.add_argument("--render_video_train", action='store_true', help='render train set NeRF and save as video, make sure i_eval is True')
22 |     parser.add_argument("--render_video_test", action='store_true', help='render val set NeRF and save as video, make sure i_eval is True')
23 |     parser.add_argument("--no_DNeRF_viewdir", action='store_true', default=False, help='will not use DNeRF in viewdir encoding')
24 |     parser.add_argument("--val_on_psnr", action='store_true', default=False, help='EarlyStopping with max validation psnr')
25 |     parser.add_argument("--feature_matching_lvl", nargs='+', type=int, default=[0,1,2],
26 |                     help='lvl of features used for feature matching, default use lvl 0, 1, 2')
27 |
28 |     ##################### PoseNet Settings ########################
29 |     parser.add_argument("--pose_only", type=int, default=0, help='posenet type to train, \
30 |                     1: train baseline posenet, 2: posenet+nerf manual optimize, \
31 |                     3: featurenet,')
32 |     parser.add_argument("--learning_rate", type=float, default=0.00001, help='learning rate')
33 |     parser.add_argument("--batch_size", type=int, default=1, help='train posenet only')
34 |     parser.add_argument("--pretrain_model_path", type=str, default='', help='model path of pretrained pose regression model')
35 |     parser.add_argument("--pretrain_featurenet_path", type=str, default='', help='model path of pretrained featurenet model')
36 |     parser.add_argument("--model_name", type=str, help='pose model output folder name')
37 |     parser.add_argument("--combine_loss", action='store_true',
38 |                     help='combined l2 pose loss + rgb mse loss')
39 |     parser.add_argument("--combine_loss_w", nargs='+', type=float, default=[0.5, 0.5],
40 |                     help='weights of combined loss ex, [0.5 0.5], \
41 |                     default None, only use when combine_loss is True')
42 |     parser.add_argument("--patience", nargs='+', type=int, default=[200, 50], help='set training schedule for patience [EarlyStopping, reduceLR]')
43 |     parser.add_argument("--resize_factor", type=int, default=2, help='image resize downsample ratio')
44 |     parser.add_argument("--freezeBN", action='store_true', help='Freeze the Batch Norm layer at training PoseNet')
45 |     parser.add_argument("--preprocess_ImgNet", action='store_true', help='Normalize input data for PoseNet')
46 |     parser.add_argument("--eval", action='store_true', help='eval model')
47 |     parser.add_argument("--no_save_multiple", action='store_true', help='default, save multiple posenet model, if true, save only one posenet model')
48 |     parser.add_argument("--resnet34", action='store_true', default=False, help='use resnet34 backbone instead of mobilenetV2')
49 |     parser.add_argument("--efficientnet", action='store_true', default=False, help='use efficientnet-b3 backbone instead of mobilenetV2')
50 |     parser.add_argument("--efficientnet_block", type=int, default=6, help='choose which features from feature block (0-6) of efficientnet to use')
51 |     parser.add_argument("--dropout", type=float, default=0.5, help='dropout rate for resnet34 backbone')
52 |     parser.add_argument("--DFNet", action='store_true', default=False, help='use DFNet')
53 |     parser.add_argument("--DFNet_s", action='store_true', default=False, help='use accelerated DFNet, performance is similar to DFNet but slightly faster')
54 |     parser.add_argument("--val_batch_size", type=int, default=1, help='batch_size for validation, higher number leads to faster speed')
55 |
56 |     ##################### NeRF Settings ########################
57 |     parser.add_argument('--config', is_config_file=True, help='config file path')
58 |     parser.add_argument("--expname", type=str, help='experiment name')
59 |     parser.add_argument("--basedir", type=str, default='../logs/', help='where to store ckpts and logs')
60 |     parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')
61 |
62 |     # training options
63 |     parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
64 |     parser.add_argument("--netwidth", type=int, default=128, help='channels per layer')
65 |     parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
66 |     parser.add_argument("--netwidth_fine", type=int, default=128, help='channels per layer in fine network')
67 |     parser.add_argument("--N_rand", type=int, default=1536, help='batch size (number of random rays per gradient step)')
68 |     parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
69 |     parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')
70 |     parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
71 |     parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
72 |     parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')
73 |     parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
74 |     parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')
75 |     parser.add_argument("--no_grad_update", default=True, help='do not update nerf in training')
76 |     parser.add_argument("--per_channel", default=False, action='store_true', help='using per channel cosine similarity loss instead of per pixel, default False')
77 |
78 |     # NeRF-Hist training options
79 |     parser.add_argument("--NeRFH", action='store_true', help='new implementation for NeRFH')
80 |     parser.add_argument("--N_vocab", type=int, default=1000,
81 |                         help='''number of vocabulary (number of images)
82 |                         in the dataset for nn.Embedding''')
83 |     parser.add_argument("--fix_index", action='store_true', help='fix training frame index as 0')
84 |     parser.add_argument("--encode_hist", default=False, action='store_true', help='encode histogram instead of frame index')
85 |     parser.add_argument("--hist_bin", type=int, default=10, help='image histogram bin size')
86 |     parser.add_argument("--in_channels_a", type=int, default=50, help='appearance embedding dimension, hist_bin*N_a when embedding histogram')
87 |     parser.add_argument("--in_channels_t", type=int, default=20, help='transient embedding dimension, hist_bin*N_tau when embedding histogram')
88 |     parser.add_argument("--svd_reg", default=False, action='store_true', help='use svd regularize output at training')
89 |
90 |     # rendering options
91 |     parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
92 |     parser.add_argument("--N_importance", type=int, default=64,help='number of additional fine samples per ray')
93 |     parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. for jitter')
94 |     parser.add_argument("--use_viewdirs", action='store_true', default=True, help='use full 5D input instead of 3D')
95 |     parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
96 |     parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
97 |     parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
98 |     parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
99 |
100 |     parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
101 |     parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
102 |     parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
103 |
104 |     # legacy mesh options
105 |     parser.add_argument("--mesh_only", action='store_true', help='do not optimize, reload weights and save mesh to a file')
106 |     parser.add_argument("--mesh_grid_size", type=int, default=80,help='number of grid points to sample in each dimension for marching cubes')
107 |
108 |     # training options
109 |     parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops')
110 |     parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops')
111 |
112 |     # dataset options
113 |     parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels')
114 |     parser.add_argument("--testskip", type=int, default=1, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
115 |
116 |     ## legacy blender flags
117 |     parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
118 |     parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')
119 |
120 |     ## llff flags
121 |     parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
122 |     parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
123 |     parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
124 |     parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
125 |     parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')
126 |     parser.add_argument("--no_bd_factor", action='store_true', default=False, help='do not use bd factor')
127 |
128 |     # featuremetric supervision
129 |     parser.add_argument("--featuremetric", action='store_true', help='use featuremetric supervision if true')
130 |
131 |     # logging/saving options
132 |     parser.add_argument("--i_print", type=int, default=1, help='frequency of console printout and metric logging')
133 |     parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
134 |     parser.add_argument("--i_weights", type=int, default=200, help='frequency of weight ckpt saving')
135 |     parser.add_argument("--i_testset", type=int, default=200, help='frequency of testset saving')
136 |     parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
137 |
138 |     return parser
--------------------------------------------------------------------------------
/script/dm/pose_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init
5 | import numpy as np
6 | from torchvision import models
7 | from efficientnet_pytorch import EfficientNet
8 | from torch.utils.tensorboard import SummaryWriter
9 | from tqdm import tqdm
10 | import pdb
11 | import matplotlib.pyplot as plt
12 |
13 | import math
14 | import time
15 |
16 | import pytorch3d.transforms as transforms
17 |
18 | def preprocess_data(inputs, device):
19 | # normalize inputs according to https://pytorch.org/hub/pytorch_vision_mobilenet_v2/
20 | mean = torch.Tensor([0.485, 0.456, 0.406]).to(device) # per channel subtraction
21 | std = torch.Tensor([0.229, 0.224, 0.225]).to(device) # per channel division
22 | inputs = (inputs - mean[None,:,None,None])/std[None,:,None,None]
23 | return inputs
24 |
25 | def filter_hook(m, g_in, g_out):
26 | g_filtered = []
27 | for g in g_in:
28 | g = g.clone()
29 | g[g != g] = 0
30 | g_filtered.append(g)
31 | return tuple(g_filtered)
32 |
33 | def vis_pose(vis_info):
34 | '''
35 | visualize predicted pose result vs. gt pose
36 | '''
37 | pdb.set_trace()
38 | pose = vis_info['pose']
39 | pose_gt = vis_info['pose_gt']
40 | theta = vis_info['theta']
41 | ang_threshold=10
42 | seq_num = theta.shape[0]
43 | # # create figure object
44 | # plot translation traj.
45 | fig = plt.figure(figsize = (8,6))
46 | plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
47 | ax1 = fig.add_axes([0, 0.2, 0.9, 0.85], projection='3d')
48 | ax1.scatter(pose[10:,0],pose[10:,1],zs=pose[10:,2], c='r', s=3**2,depthshade=0) # predict
49 | ax1.scatter(pose_gt[:,0], pose_gt[:,1], zs=pose_gt[:,2], c='g', s=3**2,depthshade=0) # GT
50 | ax1.scatter(pose[0:10,0],pose[0:10,1],zs=pose[0:10,2], c='k', s=3**2,depthshade=0) # predict
51 | # ax1.plot(pose[:,0],pose[:,1],zs=pose[:,2], c='r') # predict
52 | # ax1.plot(pose_gt[:,0], pose_gt[:,1], zs=pose_gt[:,2], c='g') # GT
53 | ax1.view_init(30, 120)
54 | ax1.set_xlabel('x (m)')
55 | ax1.set_ylabel('y (m)')
56 | ax1.set_zlabel('z (m)')
57 | # ax1.set_xlim(-10, 10)
58 | # ax1.set_ylim(-10, 10)
59 | # ax1.set_zlim(-10, 10)
60 |
61 | ax1.set_xlim(-1, 1)
62 | ax1.set_ylim(-1, 1)
63 | ax1.set_zlim(-1, 1)
64 |
65 | # ax1.set_xlim(-3, 3)
66 | # ax1.set_ylim(-3, 3)
67 | # ax1.set_zlim(-3, 3)
68 |
69 | # plot angular error
70 | ax2 = fig.add_axes([0.1, 0.05, 0.75, 0.2])
71 | err = theta.reshape(1, seq_num)
72 | err = np.tile(err, (20, 1))
73 | ax2.imshow(err, vmin=0,vmax=ang_threshold, aspect=3)
74 | ax2.set_yticks([])
75 | ax2.set_xticks([0, seq_num*1/5, seq_num*2/5, seq_num*3/5, seq_num*4/5, seq_num])
76 | fname = './vis_pose.png'
77 | plt.savefig(fname, dpi=50)
78 |
79 | def compute_error_in_q(args, dl, model, device, results, batch_size=1):
80 |     use_SVD=True # Turn on for the reported Direct-PN and Direct-PN+U results, although it makes only a minuscule difference
81 | time_spent = []
82 | predict_pose_list = []
83 | gt_pose_list = []
84 | ang_error_list = []
85 | pose_result_raw = []
86 | pose_GT = []
87 | i = 0
88 |
89 | for batch in dl:
90 | if args.NeRFH:
91 | data, pose, img_idx = batch
92 | else:
93 | data, pose = batch
94 | data = data.to(device) # input
95 | pose = pose.reshape((batch_size,3,4)).numpy() # label
96 |
97 | if args.preprocess_ImgNet:
98 | data = preprocess_data(data, device)
99 |
100 | if use_SVD:
101 | # using SVD to make sure predict rotation is normalized rotation matrix
102 | with torch.no_grad():
103 | if args.featuremetric:
104 | _, predict_pose = model(data)
105 | else:
106 | predict_pose = model(data)
107 |
108 | R_torch = predict_pose.reshape((batch_size, 3, 4))[:,:3,:3] # debug
109 | predict_pose = predict_pose.reshape((batch_size, 3, 4)).cpu().numpy()
110 |
111 | R = predict_pose[:,:3,:3]
112 | res = R@np.linalg.inv(R)
113 | # print('R@np.linalg.inv(R):', res)
114 |
115 | u,s,v=torch.svd(R_torch)
116 | Rs = torch.matmul(u, v.transpose(-2,-1))
117 | predict_pose[:,:3,:3] = Rs[:,:3,:3].cpu().numpy()
118 | else:
119 | start_time = time.time()
120 | # inference NN
121 | with torch.no_grad():
122 | predict_pose = model(data)
123 | predict_pose = predict_pose.reshape((batch_size, 3, 4)).cpu().numpy()
124 | time_spent.append(time.time() - start_time)
125 |
126 | pose_q = transforms.matrix_to_quaternion(torch.Tensor(pose[:,:3,:3]))#.cpu().numpy() # gnd truth in quaternion
127 | pose_x = pose[:, :3, 3] # gnd truth position
128 | predicted_q = transforms.matrix_to_quaternion(torch.Tensor(predict_pose[:,:3,:3]))#.cpu().numpy() # predict in quaternion
129 | predicted_x = predict_pose[:, :3, 3] # predict position
130 | pose_q = pose_q.squeeze()
131 | pose_x = pose_x.squeeze()
132 | predicted_q = predicted_q.squeeze()
133 | predicted_x = predicted_x.squeeze()
134 |
135 | #Compute Individual Sample Error
136 | q1 = pose_q / torch.linalg.norm(pose_q)
137 | q2 = predicted_q / torch.linalg.norm(predicted_q)
138 | d = torch.abs(torch.sum(torch.matmul(q1,q2)))
139 | d = torch.clamp(d, -1., 1.) # acos can only input [-1~1]
140 | theta = (2 * torch.acos(d) * 180/math.pi).numpy()
141 | error_x = torch.linalg.norm(torch.Tensor(pose_x-predicted_x)).numpy()
142 | results[i,:] = [error_x, theta]
143 | #print ('Iteration: {} Error XYZ (m): {} Error Q (degrees): {}'.format(i, error_x, theta))
144 |
145 | # save results for visualization
146 | predict_pose_list.append(predicted_x)
147 | gt_pose_list.append(pose_x)
148 | ang_error_list.append(theta)
149 | pose_result_raw.append(predict_pose)
150 | pose_GT.append(pose)
151 | i += 1
152 | # pdb.set_trace()
153 | predict_pose_list = np.array(predict_pose_list)
154 | gt_pose_list = np.array(gt_pose_list)
155 | ang_error_list = np.array(ang_error_list)
156 | pose_result_raw = np.asarray(pose_result_raw)[:,0,:,:]
157 | pose_GT = np.asarray(pose_GT)[:,0,:,:]
158 | vis_info_ret = {"pose": predict_pose_list, "pose_gt": gt_pose_list, "theta": ang_error_list, "pose_result_raw": pose_result_raw, "pose_GT": pose_GT}
159 | return results, vis_info_ret
160 |
161 | # # pytorch
162 | def get_error_in_q(args, dl, model, sample_size, device, batch_size=1):
163 |     ''' Convert rotation matrices to quaternions, then compute the translation and rotation errors; metric originally from the PoseNet paper '''
164 | model.eval()
165 |
166 | results = np.zeros((sample_size, 2))
167 | results, vis_info = compute_error_in_q(args, dl, model, device, results, batch_size)
168 | median_result = np.median(results,axis=0)
169 | mean_result = np.mean(results,axis=0)
170 |
171 | # standard log
172 | print ('Median error {}m and {} degrees.'.format(median_result[0], median_result[1]))
173 | print ('Mean error {}m and {} degrees.'.format(mean_result[0], mean_result[1]))
174 |
175 | # timing log
176 | #print ('Avg execution time (sec): {:.3f}'.format(np.mean(time_spent)))
177 |
178 | # standard log2
179 | # num_translation_less_5cm = np.asarray(np.where(results[:,0]<0.05))[0]
180 | # num_rotation_less_5 = np.asarray(np.where(results[:,1]<5))[0]
181 | # print ('translation error less than 5cm {}/{}.'.format(num_translation_less_5cm.shape[0], results.shape[0]))
182 | # print ('rotation error less than 5 degree {}/{}.'.format(num_rotation_less_5.shape[0], results.shape[0]))
183 | # print ('results:', results)
184 |
185 | # save for direct-pn paper log
186 | # if 0:
187 | # filename='Direct-PN+U_' + args.datadir.split('/')[-1] + '_result.txt'
188 | # np.savetxt(filename, predict_pose)
189 |
190 | # visualize results
191 | # vis_pose(vis_info)
192 |
193 | class EfficientNetB3(nn.Module):
194 | ''' EfficientNet-B3 backbone,
195 | model ref: https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
196 | '''
197 | def __init__(self, feat_dim=12):
198 | super(EfficientNetB3, self).__init__()
199 | self.backbone_net = EfficientNet.from_pretrained('efficientnet-b3')
200 | self.feature_extractor = self.backbone_net.extract_features
201 | self.avgpool = nn.AdaptiveAvgPool2d(1)
202 | self.fc_pose = nn.Linear(1536, feat_dim) # 1280 for efficientnet-b0, 1536 for efficientnet-b3
203 |
204 | def forward(self, Input):
205 | x = self.feature_extractor(Input)
206 | x = self.avgpool(x)
207 | x = x.reshape(x.size(0), -1)
208 | predict = self.fc_pose(x)
209 | return predict
210 |
211 | # PoseNet (SE(3)) w/ mobilev2 backbone
212 | class PoseNetV2(nn.Module):
213 | def __init__(self, feat_dim=12):
214 | super(PoseNetV2, self).__init__()
215 | self.backbone_net = models.mobilenet_v2(pretrained=True)
216 | self.feature_extractor = self.backbone_net.features
217 | self.avgpool = nn.AdaptiveAvgPool2d(1)
218 | self.fc_pose = nn.Linear(1280, feat_dim)
219 |
220 | def forward(self, Input):
221 | x = self.feature_extractor(Input)
222 | x = self.avgpool(x)
223 | x = x.reshape(x.size(0), -1)
224 | predict = self.fc_pose(x)
225 | # pdb.set_trace()
226 | return predict
227 |
228 | # PoseNet (SE(3)) w/ resnet34 backbone. We found the dropout layer unnecessary, so droprate is set to 0 in reported results.
229 | class PoseNet_res34(nn.Module):
230 | def __init__(self, droprate=0.5, pretrained=True,
231 | feat_dim=2048):
232 | super(PoseNet_res34, self).__init__()
233 | self.droprate = droprate
234 |
235 | # replace the last FC layer in feature extractor
236 | self.feature_extractor = models.resnet34(pretrained=True)
237 | self.feature_extractor.avgpool = nn.AdaptiveAvgPool2d(1)
238 | fe_out_planes = self.feature_extractor.fc.in_features
239 | self.feature_extractor.fc = nn.Linear(fe_out_planes, feat_dim)
240 | self.fc_pose = nn.Linear(feat_dim, 12)
241 |
242 | # initialize
243 | if pretrained:
244 | init_modules = [self.feature_extractor.fc]
245 | else:
246 | init_modules = self.modules()
247 |
248 | for m in init_modules:
249 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
250 | nn.init.kaiming_normal_(m.weight.data)
251 | if m.bias is not None:
252 | nn.init.constant_(m.bias.data, 0)
253 |
254 | def forward(self, x):
255 | x = self.feature_extractor(x)
256 | x = F.relu(x)
257 | if self.droprate > 0:
258 | x = F.dropout(x, p=self.droprate)
259 | predict = self.fc_pose(x)
260 | return predict
261 |
262 |
263 | # from MapNet paper CVPR 2018
264 | class PoseNet(nn.Module):
265 | def __init__(self, feature_extractor, droprate=0.5, pretrained=True,
266 | feat_dim=2048, filter_nans=False):
267 | super(PoseNet, self).__init__()
268 | self.droprate = droprate
269 |
270 | # replace the last FC layer in feature extractor
271 | self.feature_extractor = models.resnet34(pretrained=True)
272 | self.feature_extractor.avgpool = nn.AdaptiveAvgPool2d(1)
273 | fe_out_planes = self.feature_extractor.fc.in_features
274 | self.feature_extractor.fc = nn.Linear(fe_out_planes, feat_dim)
275 |
276 | self.fc_xyz = nn.Linear(feat_dim, 3)
277 | self.fc_wpqr = nn.Linear(feat_dim, 3)
278 | if filter_nans:
279 | self.fc_wpqr.register_backward_hook(hook=filter_hook)
280 | # initialize
281 | if pretrained:
282 | init_modules = [self.feature_extractor.fc, self.fc_xyz, self.fc_wpqr]
283 | else:
284 | init_modules = self.modules()
285 |
286 | for m in init_modules:
287 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
288 | nn.init.kaiming_normal_(m.weight.data)
289 | if m.bias is not None:
290 | nn.init.constant_(m.bias.data, 0)
291 |
292 | def forward(self, x):
293 | x = self.feature_extractor(x)
294 | x = F.relu(x)
295 | if self.droprate > 0:
296 | x = F.dropout(x, p=self.droprate)
297 |
298 | xyz = self.fc_xyz(x)
299 | wpqr = self.fc_wpqr(x)
300 | return torch.cat((xyz, wpqr), 1)
301 |
302 | class MapNet(nn.Module):
303 | """
304 | Implements the MapNet model (green block in Fig. 2 of paper)
305 | """
306 | def __init__(self, mapnet):
307 | """
308 | :param mapnet: the MapNet (two CNN blocks inside the green block in Fig. 2
309 | of paper). Not to be confused with MapNet, the model!
310 | """
311 | super(MapNet, self).__init__()
312 | self.mapnet = mapnet
313 |
314 | def forward(self, x):
315 | """
316 | :param x: image blob (N x T x C x H x W)
317 | :return: pose outputs
318 | (N x T x 6)
319 | """
320 | s = x.size()
321 | x = x.view(-1, *s[2:])
322 | poses = self.mapnet(x)
323 | poses = poses.view(s[0], s[1], -1)
324 | return poses
325 |
326 | def eval_on_epoch(args, dl, model, optimizer, loss_func, device):
327 | model.eval()
328 | val_loss_epoch = []
329 | for data, pose in dl:
330 | inputs = data.to(device)
331 | labels = pose.to(device)
332 | if args.preprocess_ImgNet:
333 | inputs = preprocess_data(inputs, device)
334 | predict = model(inputs)
335 | loss = loss_func(predict, labels)
336 | val_loss_epoch.append(loss.item())
337 | val_loss_epoch_mean = np.mean(val_loss_epoch)
338 | return val_loss_epoch_mean
339 |
340 |
341 | def train_on_epoch(args, dl, model, optimizer, loss_func, device):
342 | model.train()
343 | train_loss_epoch = []
344 | for data, pose in dl:
345 |         inputs = data.to(device) # (N, Ch, H, W) ~ (4,3,200,200); 7Scenes gives an odd [4, 3, 256, 341] shape
346 | labels = pose.to(device)
347 | if args.preprocess_ImgNet:
348 | inputs = preprocess_data(inputs, device)
349 |
350 | predict = model(inputs)
351 | loss = loss_func(predict, labels)
352 | loss.backward()
353 | optimizer.step()
354 | optimizer.zero_grad()
355 | train_loss_epoch.append(loss.item())
356 | train_loss_epoch_mean = np.mean(train_loss_epoch)
357 | return train_loss_epoch_mean
358 |
359 | def train_posenet(args, train_dl, val_dl, model, epochs, optimizer, loss_func, scheduler, device, early_stopping):
360 | writer = SummaryWriter()
361 | model_log = tqdm(total=0, position=1, bar_format='{desc}')
362 | for epoch in tqdm(range(epochs), desc='epochs'):
363 |
364 | # train 1 epoch
365 | train_loss = train_on_epoch(args, train_dl, model, optimizer, loss_func, device)
366 | writer.add_scalar("Loss/train", train_loss, epoch)
367 |
368 | # validate every epoch
369 | val_loss = eval_on_epoch(args, val_dl, model, optimizer, loss_func, device)
370 | writer.add_scalar("Loss/val", val_loss, epoch)
371 |
372 | # reduce LR on plateau
373 | scheduler.step(val_loss)
374 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch)
375 |
376 | # logging
377 | tqdm.write('At epoch {0:6d} : train loss: {1:.4f}, val loss: {2:.4f}'.format(epoch, train_loss, val_loss))
378 |
379 |         # check whether to stop early
380 | early_stopping(val_loss, model, epoch=epoch, save_multiple=(not args.no_save_multiple), save_all=args.save_all_ckpt)
381 | if early_stopping.early_stop:
382 | print("Early stopping")
383 | break
384 |
385 | model_log.set_description_str(f'Best val loss: {early_stopping.val_loss_min:.4f}')
386 |
387 | if epoch % args.i_eval == 0:
388 | get_error_in_q(args, val_dl, model, len(val_dl.dataset), device, batch_size=1)
389 |
390 |
391 | writer.flush()
392 |
--------------------------------------------------------------------------------
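
A minimal inference sketch for the `PoseNetV2` regressor above (assuming the `script/` directory is on `PYTHONPATH`; constructing the model downloads ImageNet-pretrained MobileNetV2 weights on first use, and the input size here is arbitrary):

```python
import torch
from dm.pose_model import PoseNetV2, preprocess_data

model = PoseNetV2(feat_dim=12).eval()
img = torch.rand(1, 3, 240, 427)          # dummy RGB batch in [0, 1]
img = preprocess_data(img, device='cpu')  # ImageNet mean/std normalisation
with torch.no_grad():
    pred = model(img)                     # (1, 12)
pose_3x4 = pred.reshape(1, 3, 4)          # rows of [R | t], as reshaped in compute_error_in_q
print(pose_3x4.shape)
```
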
/script/dm/prepare_data.py:
--------------------------------------------------------------------------------
1 | import utils.set_sys_path
2 | import torch
3 | from torch.utils.data import TensorDataset, DataLoader
4 | from torchvision import transforms
5 | import numpy as np
6 |
7 | from dataset_loaders.load_llff import load_llff_data
8 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
9 |
10 | def prepare_data(args, images, poses_train, i_split):
11 | ''' prepare data for ready to train posenet, return dataloaders '''
12 | #TODO: Convert GPU friendly data generator later: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel
13 | #TODO: Probably a better implementation style here: https://github.com/PyTorchLightning/pytorch-lightning
14 |
15 | i_train, i_val, i_test = i_split
16 |
17 | img_train = torch.Tensor(images[i_train]).permute(0, 3, 1, 2) # now shape is [N, CH, H, W]
18 | pose_train = torch.Tensor(poses_train[i_train])
19 |
20 | trainset = TensorDataset(img_train, pose_train)
21 | train_dl = DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
22 |
23 | img_val = torch.Tensor(images[i_val]).permute(0, 3, 1, 2) # now shape is [N, CH, H, W]
24 | pose_val = torch.Tensor(poses_train[i_val])
25 |
26 | valset = TensorDataset(img_val, pose_val)
27 | val_dl = DataLoader(valset)
28 |
29 | img_test = torch.Tensor(images[i_test]).permute(0, 3, 1, 2) # now shape is [N, CH, H, W]
30 | pose_test = torch.Tensor(poses_train[i_test])
31 |
32 | testset = TensorDataset(img_test, pose_test)
33 | test_dl = DataLoader(testset)
34 |
35 | return train_dl, val_dl, test_dl
36 |
37 | def load_dataset(args):
38 | ''' load posenet training data '''
39 | if args.dataset_type == 'llff':
40 | if args.no_bd_factor:
41 | bd_factor = None
42 | else:
43 | bd_factor = 0.75
44 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
45 | recenter=True, bd_factor=bd_factor,
46 | spherify=args.spherify)
47 |
48 | hwf = poses[0,:3,-1]
49 | poses = poses[:,:3,:4]
50 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
51 | if not isinstance(i_test, list):
52 | i_test = [i_test]
53 |
54 | if args.llffhold > 0:
55 | print('Auto LLFF holdout,', args.llffhold)
56 | i_test = np.arange(images.shape[0])[::args.llffhold]
57 |
58 | i_val = i_test
59 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
60 | (i not in i_test and i not in i_val)])
61 |
62 | print('DEFINING BOUNDS')
63 | if args.no_ndc:
64 | near = np.ndarray.min(bds) * .9
65 | far = np.ndarray.max(bds) * 1.
66 |
67 | else:
68 | near = 0.
69 | far = 1.
70 |
71 | if args.finetune_unlabel:
72 | i_train = i_test
73 | i_split = [i_train, i_val, i_test]
74 | else:
75 | print('Unknown dataset type', args.dataset_type, 'exiting')
76 | return
77 |
78 | poses_train = poses[:,:3,:].reshape((poses.shape[0],12)) # get rid of last row [0,0,0,1]
79 | print("images.shape {}, poses_train.shape {}".format(images.shape, poses_train.shape))
80 |
81 | INPUT_SHAPE = images[0].shape
82 | print("=====================================================================")
83 | print("INPUT_SHAPE:", INPUT_SHAPE)
84 | return images, poses_train, render_poses, hwf, i_split, near, far
--------------------------------------------------------------------------------
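
A minimal sketch of `prepare_data` on dummy arrays. It assumes `prepare_data` (defined above) is already in scope, since importing `dm.prepare_data` also pulls in `dataset_loaders.load_llff`; the function expects HWC float images, flattened 3x4 poses, an `(i_train, i_val, i_test)` split, and reads only `batch_size` from `args`:

```python
from types import SimpleNamespace
import numpy as np

# prepare_data is assumed to be in scope (pasted or imported from script/dm/prepare_data.py)
images = np.random.rand(10, 60, 80, 3).astype(np.float32)   # N x H x W x 3
poses_train = np.random.rand(10, 12).astype(np.float32)     # N x 12 flattened 3x4 poses
i_split = [np.arange(6), np.arange(6, 8), np.arange(8, 10)]

args = SimpleNamespace(batch_size=2)
train_dl, val_dl, test_dl = prepare_data(args, images, poses_train, i_split)
imgs, poses = next(iter(train_dl))
print(imgs.shape, poses.shape)   # torch.Size([2, 3, 60, 80]) torch.Size([2, 12])
```
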
/script/feature/dfnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from typing import List
6 |
7 | # VGG-16 Layer Names and Channels
8 | vgg16_layers = {
9 | "conv1_1": 64,
10 | "relu1_1": 64,
11 | "conv1_2": 64,
12 | "relu1_2": 64,
13 | "pool1": 64,
14 | "conv2_1": 128,
15 | "relu2_1": 128,
16 | "conv2_2": 128,
17 | "relu2_2": 128,
18 | "pool2": 128,
19 | "conv3_1": 256,
20 | "relu3_1": 256,
21 | "conv3_2": 256,
22 | "relu3_2": 256,
23 | "conv3_3": 256,
24 | "relu3_3": 256,
25 | "pool3": 256,
26 | "conv4_1": 512,
27 | "relu4_1": 512,
28 | "conv4_2": 512,
29 | "relu4_2": 512,
30 | "conv4_3": 512,
31 | "relu4_3": 512,
32 | "pool4": 512,
33 | "conv5_1": 512,
34 | "relu5_1": 512,
35 | "conv5_2": 512,
36 | "relu5_2": 512,
37 | "conv5_3": 512,
38 | "relu5_3": 512,
39 | "pool5": 512,
40 | }
41 |
42 | class AdaptLayers(nn.Module):
43 | """Small adaptation layers.
44 | """
45 |
46 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
47 | """Initialize one adaptation layer for every extraction point.
48 |
49 | Args:
50 | hypercolumn_layers: The list of the hypercolumn layer names.
51 | output_dim: The output channel dimension.
52 | """
53 | super(AdaptLayers, self).__init__()
54 | self.layers = []
55 | channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers]
56 | for i, l in enumerate(channel_sizes):
57 | layer = nn.Sequential(
58 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
59 | nn.ReLU(),
60 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
61 | nn.BatchNorm2d(output_dim),
62 | )
63 | self.layers.append(layer)
64 | self.add_module("adapt_layer_{}".format(i), layer) # ex: adapt_layer_0
65 |
66 | def forward(self, features: List[torch.tensor]):
67 | """Apply adaptation layers. # here is list of three levels of features
68 | """
69 |
70 | for i, _ in enumerate(features):
71 | features[i] = getattr(self, "adapt_layer_{}".format(i))(features[i])
72 | return features
73 |
74 | class DFNet(nn.Module):
75 | ''' DFNet implementation '''
76 | default_conf = {
77 | 'hypercolumn_layers': ["conv1_2", "conv3_3", "conv5_3"],
78 | 'output_dim': 128,
79 | }
80 | mean = [0.485, 0.456, 0.406]
81 | std = [0.229, 0.224, 0.225]
82 |
83 | def __init__(self, feat_dim=12, places365_model_path=''):
84 | super(DFNet, self).__init__()
85 |
86 | self.layer_to_index = {k: v for v, k in enumerate(vgg16_layers.keys())}
87 | self.hypercolumn_indices = [self.layer_to_index[n] for n in self.default_conf['hypercolumn_layers']] # [2, 14, 28]
88 |
89 | # Initialize architecture
90 | vgg16 = models.vgg16(pretrained=True)
91 |
92 | self.encoder = nn.Sequential(*list(vgg16.features.children()))
93 |
94 | self.scales = []
95 | current_scale = 0
96 | for i, layer in enumerate(self.encoder):
97 | if isinstance(layer, torch.nn.MaxPool2d):
98 | current_scale += 1
99 | if i in self.hypercolumn_indices:
100 | self.scales.append(2**current_scale)
101 |
102 | ## adaptation layers, see off branches from fig.3 in S2DNet paper
103 | self.adaptation_layers = AdaptLayers(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim'])
104 |
105 | # pose regression layers
106 | self.avgpool = nn.AdaptiveAvgPool2d(1)
107 | self.fc_pose = nn.Linear(512, feat_dim)
108 |
109 | def forward(self, x, return_feature=False, isSingleStream=False, return_pose=True, upsampleH=240, upsampleW=427):
110 | '''
111 | inference DFNet. It can regress camera pose as well as extract intermediate layer features.
112 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream
113 | :param return_feature: whether to return features as output
114 | :param isSingleStream: whether it's a single-stream inference or a siamese (two-stream) inference
115 | :param upsampleH: feature upsample size H
116 | :param upsampleW: feature upsample size W
117 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None
118 | :return predict: [2B, 12] or [B, 12]
119 | '''
120 | # normalize input data
121 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std)
122 | x = (x - mean[:, None, None]) / std[:, None, None]
123 |
124 | ### encoder ###
125 | feature_maps = []
126 | for i in range(len(self.encoder)):
127 | x = self.encoder[i](x)
128 |
129 | if i in self.hypercolumn_indices:
130 | feature = x.clone()
131 | feature_maps.append(feature)
132 |
133 | if i==self.hypercolumn_indices[-1]:
134 | if return_pose==False:
135 | predict = None
136 | break
137 |
138 | ### extract and process intermediate features ###
139 | if return_feature:
140 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer
141 |
142 | if isSingleStream: # not siamese network style inference
143 | feature_stacks = []
144 | for f in feature_maps:
145 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f))
146 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W])
147 | else: # siamese network style inference
148 | feature_stacks_t = []
149 | feature_stacks_r = []
150 | for f in feature_maps:
151 | # split real and nerf batches
152 | batch = f.shape[0] # should be target batch_size + rgb batch_size
153 | feature_t = f[:batch//2]
154 | feature_r = f[batch//2:]
155 |
156 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img
157 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img
158 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W]
159 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W]
160 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W])
161 | else:
162 | feature_maps = None
163 |
164 | if return_pose==False:
165 | return feature_maps, predict
166 |
167 | ### pose regression head ###
168 | x = self.avgpool(x)
169 | x = x.reshape(x.size(0), -1)
170 | predict = self.fc_pose(x)
171 |
172 | return feature_maps, predict
173 |
174 | class DFNet_s(nn.Module):
175 | ''' A slightly accelerated version of DFNet; we experimentally found its performance to be similar to the original DFNet while running inference faster '''
176 | default_conf = {
177 | 'hypercolumn_layers': ["conv1_2"],
178 | 'output_dim': 128,
179 | }
180 | mean = [0.485, 0.456, 0.406]
181 | std = [0.229, 0.224, 0.225]
182 |
183 | def __init__(self, feat_dim=12, places365_model_path=''):
184 | super(DFNet_s, self).__init__()
185 |
186 | self.layer_to_index = {k: v for v, k in enumerate(vgg16_layers.keys())}
187 | self.hypercolumn_indices = [self.layer_to_index[n] for n in self.default_conf['hypercolumn_layers']] # [2] for the single conv1_2 layer
188 |
189 | # Initialize architecture
190 | vgg16 = models.vgg16(pretrained=True)
191 |
192 | self.encoder = nn.Sequential(*list(vgg16.features.children()))
193 |
194 | self.scales = []
195 | current_scale = 0
196 | for i, layer in enumerate(self.encoder):
197 | if isinstance(layer, torch.nn.MaxPool2d):
198 | current_scale += 1
199 | if i in self.hypercolumn_indices:
200 | self.scales.append(2**current_scale)
201 |
202 | ## adaptation layers, see off branches from fig.3 in S2DNet paper
203 | self.adaptation_layers = AdaptLayers(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim'])
204 |
205 | # pose regression layers
206 | self.avgpool = nn.AdaptiveAvgPool2d(1)
207 | self.fc_pose = nn.Linear(512, feat_dim)
208 |
209 | def forward(self, x, return_feature=False, isSingleStream=False, return_pose=True, upsampleH=240, upsampleW=427):
210 | '''
211 | inference DFNet_s. It can regress camera pose as well as extract intermediate layer features.
212 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream
213 | :param return_feature: whether to return features as output
214 | :param isSingleStream: whether it's a single-stream inference or a siamese (two-stream) inference
215 | :param upsampleH: feature upsample size H
216 | :param upsampleW: feature upsample size W
217 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None
218 | :return predict: [2B, 12] or [B, 12]
219 | '''
220 |
221 | # normalize input data
222 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std)
223 | x = (x - mean[:, None, None]) / std[:, None, None]
224 |
225 | ### encoder ###
226 | feature_maps = []
227 | for i in range(len(self.encoder)):
228 | x = self.encoder[i](x)
229 |
230 | if i in self.hypercolumn_indices:
231 | feature = x.clone()
232 | feature_maps.append(feature)
233 |
234 | if i==self.hypercolumn_indices[-1]:
235 | if return_pose==False:
236 | predict = None
237 | break
238 |
239 | ### extract and process intermediate features ###
240 | if return_feature:
241 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer
242 |
243 | if isSingleStream: # not siamese network style inference
244 | feature_stacks = []
245 | for f in feature_maps:
246 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f))
247 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W])
248 | else: # siamese network style inference
249 | feature_stacks_t = []
250 | feature_stacks_r = []
251 | for f in feature_maps:
252 | # split real and nerf batches
253 | batch = f.shape[0] # should be target batch_size + rgb batch_size
254 | feature_t = f[:batch//2]
255 | feature_r = f[batch//2:]
256 |
257 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img
258 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img
259 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W]
260 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W]
261 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W])
262 | else:
263 | feature_maps = None
264 |
265 | if return_pose==False:
266 | return feature_maps, predict
267 |
268 | ### pose regression head ###
269 | x = self.avgpool(x)
270 | x = x.reshape(x.size(0), -1)
271 | predict = self.fc_pose(x)
272 |
273 | return feature_maps, predict
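
# --- Editor's note: hypothetical usage sketch, not part of the original file. ---
# Shows the two calling modes of DFNet.forward(): a single-stream pass that only
# regresses the 12-d pose, and a siamese pass over a concatenated [real; rendered]
# batch that also returns the upsampled hypercolumn features.
if __name__ == '__main__':
    model = DFNet(feat_dim=12).eval()
    real = torch.rand(2, 3, 240, 427)        # B real images
    rendered = torch.rand(2, 3, 240, 427)    # B NeRF-rendered images
    with torch.no_grad():
        # pose only, single stream
        _, pose = model(real, return_feature=False, isSingleStream=True)
        # features + pose, siamese style (real and rendered stacked along the batch dim)
        feats, pose2 = model(torch.cat([real, rendered]), return_feature=True, isSingleStream=False)
    print(pose.shape, pose2.shape, feats[0].shape)  # [2, 12], [4, 12], [3, 2, 128, 240, 427]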
--------------------------------------------------------------------------------
/script/feature/efficientnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import copy, pdb
5 | from typing import List
6 | from efficientnet_pytorch import EfficientNet
7 |
8 | # efficientnet-B3 Layer Names and Channels
9 | EB3_layers = {
10 | "reduction_1": 24, # torch.Size([2, 24, 120, 213])
11 | "reduction_2": 32, # torch.Size([2, 32, 60, 106])
12 | "reduction_3": 48, # torch.Size([2, 48, 30, 53])
13 | "reduction_4": 136, # torch.Size([2, 136, 15, 26])
14 | "reduction_5": 384, # torch.Size([2, 384, 8, 13])
15 | "reduction_6": 1536, # torch.Size([2, 1536, 8, 13])
16 | }
17 |
18 | # efficientnet-B0 Layer Names and Channels
19 | EB0_layers = {
20 | "reduction_1": 16, # torch.Size([2, 16, 120, 213])
21 | "reduction_2": 24, # torch.Size([2, 24, 60, 106])
22 | "reduction_3": 40, # torch.Size([2, 40, 30, 53])
23 | "reduction_4": 112, # torch.Size([2, 112, 15, 26])
24 | "reduction_5": 320, # torch.Size([2, 320, 8, 13])
25 | "reduction_6": 1280, # torch.Size([2, 1280, 8, 13])
26 | }
27 |
28 | class AdaptLayers(nn.Module):
29 | """Small adaptation layers.
30 | """
31 |
32 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
33 | """Initialize one adaptation layer for every extraction point.
34 |
35 | Args:
36 | hypercolumn_layers: The list of the hypercolumn layer names.
37 | output_dim: The output channel dimension.
38 | """
39 | super(AdaptLayers, self).__init__()
40 | self.layers = []
41 | channel_sizes = [EB3_layers[name] for name in hypercolumn_layers]
42 | for i, l in enumerate(channel_sizes):
43 | layer = nn.Sequential(
44 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
45 | nn.ReLU(),
46 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
47 | nn.BatchNorm2d(output_dim),
48 | )
49 | self.layers.append(layer)
50 | self.add_module("adapt_layer_{}".format(i), layer) # ex: adapt_layer_0
51 |
52 | def forward(self, features: List[torch.tensor]):
53 | """Apply adaptation layers. # here is list of three levels of features
54 | """
55 |
56 | for i, _ in enumerate(features):
57 | features[i] = getattr(self, "adapt_layer_{}".format(i))(features[i])
58 | return features
59 |
60 | class EfficientNetB3(nn.Module):
61 | ''' DFNet with EB3 backbone '''
62 | default_conf = {
63 | # 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_6"],
64 | 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_5"],
65 | # 'hypercolumn_layers': ["reduction_2", "reduction_4", "reduction_6"],
66 | 'output_dim': 128,
67 | }
68 | mean = [0.485, 0.456, 0.406]
69 | std = [0.229, 0.224, 0.225]
70 |
71 | def __init__(self, feat_dim=12, places365_model_path=''):
72 | super(EfficientNetB3, self).__init__()
73 | # Initialize architecture
74 | self.backbone_net = EfficientNet.from_pretrained('efficientnet-b3')
75 | self.feature_extractor = self.backbone_net.extract_endpoints
76 |
77 | # self.feature_block_index = [1, 3, 6] # same as the 'hypercolumn_layers'
78 | self.feature_block_index = [1, 3, 5] # same as the 'hypercolumn_layers'
79 | # self.feature_block_index = [2, 4, 6] # same as the 'hypercolumn_layers'
80 |
81 | ## adaptation layers, see off branches from fig.3 in S2DNet paper
82 | self.adaptation_layers = AdaptLayers(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim'])
83 |
84 | # pose regression layers
85 | self.avgpool = nn.AdaptiveAvgPool2d(1)
86 | self.fc_pose = nn.Linear(1536, feat_dim)
87 |
88 | def forward(self, x, return_feature=False, isSingleStream=False, upsampleH=120, upsampleW=213):
89 | '''
90 | inference DFNet. It can regress camera pose as well as extract intermediate layer features.
91 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream
92 | :param return_feature: whether to return features as output
93 | :param isSingleStream: whether it's a single-stream inference or a siamese (two-stream) inference
94 | :param upsampleH: feature upsample size H
95 | :param upsampleW: feature upsample size W
96 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None
97 | :return predict: [2B, 12] or [B, 12]
98 | '''
99 | # normalize input data
100 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std)
101 | x = (x - mean[:, None, None]) / std[:, None, None]
102 |
103 | ### encoder ###
104 | feature_maps = []
105 | list_x = self.feature_extractor(x)
106 |
107 | x = list_x['reduction_6'] # features to save
108 | for i in self.feature_block_index:
109 | fe = list_x['reduction_'+str(i)].clone()
110 | feature_maps.append(fe)
111 |
112 | ### extract and process intermediate features ###
113 | if return_feature:
114 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer
115 |
116 | # pdb.set_trace() # leftover debug breakpoint, disabled so forward() can run end-to-end
117 | if isSingleStream: # not siamese network style inference
118 | feature_stacks = []
119 |
120 | for f in feature_maps:
121 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f))
122 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W])
123 | else: # siamese network style inference
124 | feature_stacks_t = []
125 | feature_stacks_r = []
126 |
127 | for f in feature_maps:
128 | # split real and nerf batches
129 | batch = f.shape[0] # should be target batch_size + rgb batch_size
130 | feature_t = f[:batch//2]
131 | feature_r = f[batch//2:]
132 |
133 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img
134 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img
135 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W]
136 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W]
137 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W])
138 |
139 | else:
140 | feature_maps = None
141 |
142 | ### pose regression head ###
143 | x = self.avgpool(x)
144 | x = x.reshape(x.size(0), -1)
145 | predict = self.fc_pose(x)
146 |
147 | return feature_maps, predict
148 |
149 | class AdaptLayers2(nn.Module):
150 | """Small adaptation layers.
151 | """
152 |
153 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
154 | """Initialize one adaptation layer for every extraction point.
155 |
156 | Args:
157 | hypercolumn_layers: The list of the hypercolumn layer names.
158 | output_dim: The output channel dimension.
159 | """
160 | super(AdaptLayers2, self).__init__()
161 | self.layers = []
162 | channel_sizes = [EB0_layers[name] for name in hypercolumn_layers]
163 | for i, l in enumerate(channel_sizes):
164 | layer = nn.Sequential(
165 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
166 | nn.ReLU(),
167 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
168 | nn.BatchNorm2d(output_dim),
169 | )
170 | self.layers.append(layer)
171 | self.add_module("adapt_layer_{}".format(i), layer) # ex: adapt_layer_0
172 |
173 | def forward(self, features: List[torch.tensor]):
174 | """Apply adaptation layers. # here is list of three levels of features
175 | """
176 |
177 | for i, _ in enumerate(features):
178 | features[i] = getattr(self, "adapt_layer_{}".format(i))(features[i])
179 | return features
180 |
181 | class EfficientNetB0(nn.Module):
182 | ''' DFNet with EB0 backbone, feature levels can be customized '''
183 | default_conf = {
184 | # 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_6"],
185 | 'hypercolumn_layers': ["reduction_1", "reduction_3", "reduction_5"],
186 | # 'hypercolumn_layers': ["reduction_2", "reduction_4", "reduction_6"],
187 | # 'hypercolumn_layers': ["reduction_1"],
188 | 'output_dim': 128,
189 | }
190 | mean = [0.485, 0.456, 0.406]
191 | std = [0.229, 0.224, 0.225]
192 |
193 | def __init__(self, feat_dim=12, places365_model_path=''):
194 | super(EfficientNetB0, self).__init__()
195 | # Initialize architecture
196 | self.backbone_net = EfficientNet.from_pretrained('efficientnet-b0')
197 | self.feature_extractor = self.backbone_net.extract_endpoints
198 |
199 | # self.feature_block_index = [1, 3, 6] # same as the 'hypercolumn_layers'
200 | self.feature_block_index = [1, 3, 5] # same as the 'hypercolumn_layers'
201 | # self.feature_block_index = [2, 4, 6] # same as the 'hypercolumn_layers'
202 | # self.feature_block_index = [1]
203 |
204 | ## adaptation layers, see off branches from fig.3 in S2DNet paper
205 | self.adaptation_layers = AdaptLayers2(self.default_conf['hypercolumn_layers'], self.default_conf['output_dim'])
206 |
207 | # pose regression layers
208 | self.avgpool = nn.AdaptiveAvgPool2d(1)
209 | self.fc_pose = nn.Linear(1280, feat_dim)
210 |
211 | def forward(self, x, return_feature=False, isSingleStream=False, return_pose=False, upsampleH=120, upsampleW=213):
212 | '''
213 | inference DFNet. It can regress camera pose as well as extract intermediate layer features.
214 | :param x: image blob (2B x C x H x W) two stream or (B x C x H x W) single stream
215 | :param return_feature: whether to return features as output
216 | :param isSingleStream: whether it's a single-stream inference or a siamese (two-stream) inference
217 | :param return_pose: TODO: if only return_pose, we don't need to compute return_feature part
218 | :param upsampleH: feature upsample size H
219 | :param upsampleW: feature upsample size W
220 | :return feature_maps: (2, [B, C, H, W]) or (1, [B, C, H, W]) or None
221 | :return predict: [2B, 12] or [B, 12]
222 | '''
223 | # normalize input data
224 | mean, std = x.new_tensor(self.mean), x.new_tensor(self.std)
225 | x = (x - mean[:, None, None]) / std[:, None, None]
226 |
227 | ### encoder ###
228 | feature_maps = []
229 | list_x = self.feature_extractor(x)
230 |
231 | x = list_x['reduction_6'] # features to save
232 | for i in self.feature_block_index:
233 | fe = list_x['reduction_'+str(i)].clone()
234 | feature_maps.append(fe)
235 |
236 | ### extract and process intermediate features ###
237 | if return_feature:
238 | feature_maps = self.adaptation_layers(feature_maps) # (3, [B, C, H', W']), H', W' are different in each layer
239 |
240 | if isSingleStream: # not siamese network style inference
241 | feature_stacks = []
242 |
243 | for f in feature_maps:
244 | feature_stacks.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(f))
245 | feature_maps = [torch.stack(feature_stacks)] # (1, [3, B, C, H, W])
246 | else: # siamese network style inference
247 | feature_stacks_t = []
248 | feature_stacks_r = []
249 |
250 | for f in feature_maps:
251 |
252 | # split real and nerf batches
253 | batch = f.shape[0] # should be target batch_size + rgb batch_size
254 | feature_t = f[:batch//2]
255 | feature_r = f[batch//2:]
256 |
257 | feature_stacks_t.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_t)) # GT img
258 | feature_stacks_r.append(torch.nn.UpsamplingBilinear2d(size=(upsampleH, upsampleW))(feature_r)) # render img
259 | feature_stacks_t = torch.stack(feature_stacks_t) # [3, B, C, H, W]
260 | feature_stacks_r = torch.stack(feature_stacks_r) # [3, B, C, H, W]
261 | feature_maps = [feature_stacks_t, feature_stacks_r] # (2, [3, B, C, H, W])
262 |
263 | else:
264 | feature_maps = None
265 |
266 | ### pose regression head ###
267 | x = self.avgpool(x)
268 | x = x.reshape(x.size(0), -1)
269 | predict = self.fc_pose(x)
270 |
271 | return feature_maps, predict
272 |
273 | def main():
274 | """
275 | test model
276 | """
277 | from torchsummary import summary
278 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
279 | feat_model = EfficientNetB3()
280 | # feat_model.load_state_dict(torch.load(''))
281 | feat_model.to(device)
282 | summary(feat_model, (3, 240, 427))
283 |
284 | if __name__ == '__main__':
285 | main()
286 |
--------------------------------------------------------------------------------
/script/feature/options.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | def config_parser():
3 | parser = configargparse.ArgumentParser()
4 | parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
5 | parser.add_argument("--multi_gpu", action='store_true', help='use multiple gpu on the server')
6 | parser.add_argument('--config', is_config_file=True, help='config file path')
7 | parser.add_argument("--expname", type=str, help='experiment name')
8 | parser.add_argument("--basedir", type=str, default='../logs', help='where to store ckpts and logs')
9 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')
10 | parser.add_argument("--places365_model_path", type=str, default='', help='ckpt path of places365 pretrained model')
11 |
12 | # 7Scenes
13 | parser.add_argument("--trainskip", type=int, default=1, help='will load 1/N images from train sets, useful for large datasets like 7 Scenes')
14 | parser.add_argument("--df", type=float, default=1., help='image downscale factor')
15 | parser.add_argument("--reduce_embedding", type=int, default=-1, help='fourier embedding mode: -1: paper default, \
16 | 0: reduce by half, 1: remove embedding, 2: DNeRF embedding')
17 | parser.add_argument("--epochToMaxFreq", type=int, default=-1, help='DNeRF embedding mode: (based on Nerfie paper): \
18 | hyper-parameter for when α should reach the maximum number of frequencies m')
19 | parser.add_argument("--render_pose_only", action='store_true', help='render a spiral video for 7 Scene')
20 | parser.add_argument("--save_pose_avg_stats", action='store_true', help='save a pose avg stats to unify NeRF, posenet, direct-pn training')
21 | parser.add_argument("--load_pose_avg_stats", action='store_true', help='load precomputed pose avg stats to unify NeRF, posenet, nerf tracking training')
22 | parser.add_argument("--train_local_nerf", type=int, default=-1, help='train local NeRF with ith training sequence only, ie. Stairs can pick 0~3')
23 | parser.add_argument("--render_video_train", action='store_true', help='render train set NeRF and save as video, make sure render_test is True')
24 | parser.add_argument("--render_video_test", action='store_true', help='render val set NeRF and save as video, make sure render_test is True')
25 | parser.add_argument("--frustum_overlap_th", type=float, help='frustum overlap threshold')
26 | parser.add_argument("--no_DNeRF_viewdir", action='store_true', default=False, help='will not use DNeRF in viewdir encoding')
27 | parser.add_argument("--load_unique_view_stats", action='store_true', help='load unique views frame index')
28 | parser.add_argument("--finetune_unlabel", action='store_true', help='finetune unlabeled sequence like MapNet')
29 | parser.add_argument("--i_eval", type=int, default=20, help='frequency of eval posenet result')
30 | parser.add_argument("--save_all_ckpt", action='store_true', help='save all ckpts for each epoch')
31 | parser.add_argument("--val_on_psnr", action='store_true', default=False, help='EarlyStopping with max validation psnr')
32 |
33 | # NeRF training options
34 | parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
35 | parser.add_argument("--netwidth", type=int, default=128, help='channels per layer')
36 | parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
37 | parser.add_argument("--netwidth_fine", type=int, default=128, help='channels per layer in fine network')
38 | parser.add_argument("--N_rand", type=int, default=1536, help='batch size (number of random rays per gradient step)')
39 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
40 | parser.add_argument("--lrate_decay", type=float, default=250, help='exponential learning rate decay (in 1000 steps)')
41 | parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
42 | parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
43 | parser.add_argument("--no_batching", action='store_true', default=True, help='only take random rays from 1 image at a time')
44 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
45 | parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')
46 |
47 | # NeRF-Hist training options
48 | parser.add_argument("--NeRFH", action='store_true', default=True, help='new implementation for NeRFH, please add --encode_hist')
49 | parser.add_argument("--N_vocab", type=int, default=1000,
50 | help='''number of vocabulary (number of images)
51 | in the dataset for nn.Embedding''')
52 | parser.add_argument("--fix_index", action='store_true', help='fix training frame index as 0')
53 | parser.add_argument("--encode_hist", default=False, action='store_true', help='encode histogram instead of frame index')
54 | parser.add_argument("--hist_bin", type=int, default=10, help='image histogram bin size')
55 | parser.add_argument("--in_channels_a", type=int, default=50, help='appearance embedding dimension, hist_bin*N_a when embedding histogram')
56 | parser.add_argument("--in_channels_t", type=int, default=20, help='transient embedding dimension, hist_bin*N_tau when embedding histogram')
57 | parser.add_argument("--svd_reg", default=False, action='store_true', help='use svd regularize output at training')
58 |
59 | # NeRF rendering options
60 | parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
61 | parser.add_argument("--N_importance", type=int, default=64,help='number of additional fine samples per ray')
62 | parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. for jitter')
63 | parser.add_argument("--use_viewdirs", action='store_true', default=True, help='use full 5D input instead of 3D')
64 | parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
65 | parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
66 | parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
67 | parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
68 | parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
69 | parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
70 | parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
71 | parser.add_argument("--no_grad_update", action='store_true', default=False, help='do not update nerf in training')
72 | parser.add_argument("--tinyimg", action='store_true', default=False, help='render the nerf image at a tiny scale, this is a temporary compromise for direct feature matching, must FIX later')
73 | parser.add_argument("--tinyscale", type=float, default=4., help='determine the scale of downsizing nerf rendering, must FIX later')
74 |
75 | ##################### PoseNet Settings ########################
76 | parser.add_argument("--pose_only", type=int, default=1, help='posenet type to train, \
77 | 1: train baseline posenet, 2: posenet+nerf manual optimize, \
78 | 3: VLocNet, 4: DGRNet')
79 | parser.add_argument("--learning_rate", type=float, default=0.0001, help='learning rate')
80 | parser.add_argument("--batch_size", type=int, default=1, help='dataloader batch size, Attention: this is NOT the actual training batch size, \
81 | please use --featurenet_batch_size for training')
82 | parser.add_argument("--featurenet_batch_size", type=int, default=8, help='featurenet training batch size, choose smaller batch size')
83 | parser.add_argument("--pretrain_model_path", type=str, default='', help='model path of pretrained model')
84 | parser.add_argument("--model_name", type=str, help='pose model output folder name')
85 | parser.add_argument("--combine_loss_w", nargs='+', type=float, default=[1, 1, 1],
86 | help='weights of combined loss ex, [1, 1, 1], \
87 | default None, only use when combine_loss is True')
88 | parser.add_argument("--patience", nargs='+', type=int, default=[200, 50], help='set training schedule for patience [EarlyStopping, reduceLR]')
89 | parser.add_argument("--resize_factor", type=int, default=2, help='image resize downsample ratio')
90 | parser.add_argument("--freezeBN", action='store_true', help='freeze the BatchNorm layers when training PoseNet')
91 | parser.add_argument("--preprocess_ImgNet", action='store_true', help='Normalize input data for PoseNet')
92 | parser.add_argument("--eval", action='store_true', help='eval model')
93 | parser.add_argument("--no_save_multiple", action='store_true', help='by default, save multiple posenet models; if set, save only one posenet model')
94 | parser.add_argument("--resnet34", action='store_true', default=False, help='use resnet34 backbone instead of mobilenetV2')
95 | parser.add_argument("--efficientnet", action='store_true', default=False, help='use efficientnet-b3 backbone instead of mobilenetV2')
96 | parser.add_argument("--dropout", type=float, default=0.5, help='dropout rate for resnet34 backbone')
97 | parser.add_argument("--DFNet", action='store_true', default=False, help='use DFNet')
98 | parser.add_argument("--DFNet_s", action='store_true', default=False, help='use accelerated DFNet, performance is similar to DFNet but slightly faster')
99 | parser.add_argument("--featurelossonly", action='store_true', default=False, help='only use feature loss to train feature extraction')
100 | parser.add_argument("--random_view_synthesis", action='store_true', default=False, help='add random view synthesis')
101 | parser.add_argument("--rvs_refresh_rate", type=int, default=2, help='re-synthesis new views per X epochs')
102 | parser.add_argument("--rvs_trans", type=float, default=5, help='jitter range for rvs on translation')
103 | parser.add_argument("--rvs_rotation", type=float, default=1.2, help='jitter range for rvs on rotation, this is in log_10 uniform range, log(15) = 1.2')
104 | parser.add_argument("--d_max", type=float, default=1, help='rvs bounds d_max')
105 | parser.add_argument("--val_batch_size", type=int, default=1, help='batch_size for validation, higher number leads to faster speed')
106 |
107 | # legacy mesh options
108 | parser.add_argument("--mesh_only", action='store_true', help='do not optimize, reload weights and save mesh to a file')
109 | parser.add_argument("--mesh_grid_size", type=int, default=80,help='number of grid points to sample in each dimension for marching cubes')
110 |
111 | # training options
112 | parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops')
113 | parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops')
114 | parser.add_argument("--epochs", type=int, default=2000,help='number of epochs to train')
115 | parser.add_argument("--poselossonly", action='store_true', help='train with pose loss only (no feature loss)')
116 | parser.add_argument("--tripletloss", action='store_true', help='use triplet loss at training featurenet, this is to prevent catastrophic failure')
117 | parser.add_argument("--triplet_margin", type=float,default=1., help='triplet loss margin hyperparameter')
118 | parser.add_argument("--render_feature_only", action='store_true', default=False, help='render features and save to a path')
119 |
120 | # dataset options
121 | parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / 7Scenes')
122 | parser.add_argument("--testskip", type=int, default=1, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
123 |
124 | ## legacy blender flags
125 | parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
126 | parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')
127 |
128 | ## llff flags
129 | parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
130 | parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
131 | parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
132 | parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
133 | parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')
134 | parser.add_argument("--no_bd_factor", action='store_true', default=False, help='do not use bd factor')
135 |
136 | # logging/saving options
137 | parser.add_argument("--i_print", type=int, default=1, help='frequency of console printout and metric logging')
138 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
139 | parser.add_argument("--i_weights", type=int, default=200, help='frequency of weight ckpt saving')
140 | parser.add_argument("--i_testset", type=int, default=200, help='frequency of testset saving')
141 | parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
142 |
143 | return parser
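
# --- Editor's note: hypothetical usage sketch, not part of the original file. ---
# config_parser() returns a configargparse parser, so options can come from a config
# file (e.g. the ones under script/) and/or the command line. The values below are
# arbitrary examples, not recommended settings.
if __name__ == '__main__':
    parser = config_parser()
    args = parser.parse_args(['--expname', 'demo', '--datadir', './data/7Scenes/heads'])
    print(args.expname, args.datadir, args.dataset_type, args.learning_rate)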
--------------------------------------------------------------------------------
/script/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ActiveVisionLab/DFNet/2c8fa7e324f8d17352ed469a8b793e0167e4c592/script/models/__init__.py
--------------------------------------------------------------------------------
/script/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import pdb
4 |
5 | class ColorLoss(nn.Module):
6 | def __init__(self, coef=1):
7 | super().__init__()
8 | self.coef = coef
9 | self.loss = nn.MSELoss(reduction='mean')
10 |
11 | def forward(self, inputs, targets):
12 | loss = self.loss(inputs['rgb_coarse'], targets)
13 | if 'rgb_fine' in inputs:
14 | loss += self.loss(inputs['rgb_fine'], targets)
15 |
16 | return self.coef * loss
17 |
18 |
19 | class NerfWLoss(nn.Module):
20 | """
21 | Equation 13 in the NeRF-W paper.
22 | Name abbreviations:
23 | c_l: coarse color loss
24 | f_l: fine color loss (1st term in equation 13)
25 | b_l: beta loss (2nd term in equation 13)
26 | s_l: sigma loss (3rd term in equation 13)
27 | targets # [N, 3]
28 | inputs['rgb_coarse'] # [N, 3]
29 | inputs['rgb_fine'] # [N, 3]
30 | inputs['beta'] # [N]
31 | inputs['transient_sigmas'] # [N, 2*N_Samples]
32 | :return:
33 | """
34 | def __init__(self, coef=1, lambda_u=0.01):
35 | """
36 | lambda_u: in equation 13
37 | """
38 | super().__init__()
39 | self.coef = coef
40 | self.lambda_u = lambda_u
41 |
42 | def forward(self, inputs, targets, use_hier_rgbs=False, rgb_h=None, rgb_w=None):
43 |
44 | ret = {}
45 | ret['c_l'] = 0.5 * ((inputs['rgb_coarse']-targets)**2).mean()
46 | if 'rgb_fine' in inputs:
47 | if 'beta' not in inputs: # no transient head, normal MSE loss
48 | ret['f_l'] = 0.5 * ((inputs['rgb_fine']-targets)**2).mean()
49 | else:
50 | ret['f_l'] = ((inputs['rgb_fine']-targets)**2/(2*inputs['beta'].unsqueeze(1)**2)).mean()
51 | ret['b_l'] = 3 + torch.log(inputs['beta']).mean() # +3 to make it positive
52 | ret['s_l'] = self.lambda_u * inputs['transient_sigmas'].mean()
53 |
54 | for k, v in ret.items():
55 | ret[k] = self.coef * v
56 |
57 | return ret
58 |
59 | loss_dict = {'color': ColorLoss,
60 | 'nerfw': NerfWLoss}
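
# --- Editor's note: hypothetical usage sketch, not part of the original file. ---
# NerfWLoss returns a dict of weighted terms (c_l, f_l, b_l, s_l); the training
# objective is their sum. Shapes follow the docstring above.
if __name__ == '__main__':
    N, n_samples = 1024, 64
    inputs = {'rgb_coarse': torch.rand(N, 3),
              'rgb_fine': torch.rand(N, 3),
              'beta': torch.rand(N) + 0.1,                    # keep beta > 0 for the log term
              'transient_sigmas': torch.rand(N, 2 * n_samples)}
    targets = torch.rand(N, 3)
    loss_terms = loss_dict['nerfw'](coef=1)(inputs, targets)
    print({k: round(v.item(), 4) for k, v in loss_terms.items()}, sum(loss_terms.values()).item())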
--------------------------------------------------------------------------------
/script/models/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia.losses import ssim as dssim
3 |
4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'):
5 | value = (image_pred-image_gt)**2
6 | if valid_mask is not None:
7 | value = value[valid_mask]
8 | if reduction == 'mean':
9 | return torch.mean(value)
10 | return value
11 |
12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'):
13 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction))
14 |
15 | def ssim(image_pred, image_gt, reduction='mean'):
16 | """
17 | image_pred and image_gt: (1, 3, H, W)
18 | """
19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1]
20 | return 1-2*dssim_ # in [-1, 1]
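
# --- Editor's note: hypothetical usage sketch, not part of the original file. ---
# psnr is derived from mse; identical images give mse 0 and psnr +inf.
if __name__ == '__main__':
    pred, gt = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
    print(mse(pred, gt).item(), psnr(pred, gt).item())
    # ssim(pred, gt) additionally requires the kornia version pinned in requirements.txt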
--------------------------------------------------------------------------------
/script/models/nerf.py:
--------------------------------------------------------------------------------
1 | import os, sys # sys is used by run_network_DNeRF's error exit
2 | import torch
3 | torch.autograd.set_detect_anomaly(True)
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | # Misc
9 | img2mse = lambda x, y : torch.mean((x - y) ** 2)
10 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
11 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
12 |
13 | def batchify(fn, chunk):
14 | """Constructs a version of 'fn' that applies to smaller batches.
15 | """
16 | if chunk is None:
17 | return fn
18 | def ret(inputs):
19 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
20 | return ret
21 |
22 |
23 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
24 | """Prepares inputs and applies network 'fn'.
25 | """
26 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
27 | embedded = embed_fn(inputs_flat)
28 |
29 | if viewdirs is not None:
30 | input_dirs = viewdirs[:,None].expand(inputs.shape)
31 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
32 | embedded_dirs = embeddirs_fn(input_dirs_flat)
33 | embedded = torch.cat([embedded, embedded_dirs], -1)
34 |
35 | outputs_flat = batchify(fn, netchunk)(embedded)
36 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
37 | return outputs
38 |
39 | def run_network_DNeRF(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64, epoch=None, no_DNeRF_viewdir=False):
40 | """Prepares inputs and applies network 'fn'.
41 | """
42 | if epoch is None or epoch < 0: # check None first to avoid a TypeError from None < 0
43 | print("Error: run_network_DNeRF(): Invalid epoch")
44 | sys.exit()
45 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
46 | embedded = embed_fn(inputs_flat, epoch)
47 | # add weighted function here
48 | if viewdirs is not None:
49 | input_dirs = viewdirs[:,None].expand(inputs.shape)
50 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
51 |
52 | if no_DNeRF_viewdir:
53 | embedded_dirs = embeddirs_fn(input_dirs_flat)
54 | else:
55 | embedded_dirs = embeddirs_fn(input_dirs_flat, epoch)
56 | embedded = torch.cat([embedded, embedded_dirs], -1)
57 |
58 | outputs_flat = batchify(fn, netchunk)(embedded)
59 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
60 | return outputs
61 |
62 |
63 | # Positional encoding (section 5.1)
64 | class Embedder:
65 | def __init__(self, **kwargs):
66 | self.kwargs = kwargs
67 | self.N_freqs = 0
68 | self.N = -1 # epoch to max frequency, for Nerfie embedding only
69 | self.create_embedding_fn()
70 |
71 | def create_embedding_fn(self):
72 | embed_fns = []
73 | d = self.kwargs['input_dims']
74 | out_dim = 0
75 | if self.kwargs['include_input']:
76 | embed_fns.append(lambda x : x)
77 | out_dim += d
78 |
79 | max_freq = self.kwargs['max_freq_log2']
80 | self.N_freqs = self.kwargs['num_freqs']
81 |
82 | if self.kwargs['log_sampling']:
83 | freq_bands = 2.**torch.linspace(0., max_freq, steps=self.N_freqs) # tensor([ 1., 2., 4., 8., 16., 32., 64., 128., 256., 512.])
84 | else:
85 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=self.N_freqs)
86 |
87 | for freq in freq_bands: # 10 iters for 3D location, 4 iters for 2D direction
88 | for p_fn in self.kwargs['periodic_fns']:
89 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
90 | out_dim += d
91 | self.embed_fns = embed_fns
92 | self.out_dim = out_dim
93 |
94 | def embed(self, inputs):
95 | # inputs [65536, 3]
96 | if self.kwargs['max_freq_log2'] != 0:
97 | ret = torch.cat([fn(inputs) for fn in self.embed_fns], -1) # cos, sin embedding # ret.shape [65536, 63]
98 | else:
99 | ret = inputs
100 | return ret
101 |
102 | def get_embed_weight(self, epoch, num_freqs, N):
103 | ''' Nerfie Paper Eq.(8) '''
104 | alpha = num_freqs * epoch / N
105 | W_j = []
106 | for i in range(num_freqs):
107 | tmp = torch.clamp(torch.Tensor([alpha - i]), 0, 1)
108 | tmp2 = (1 - torch.cos(torch.Tensor([np.pi]) * tmp)) / 2
109 | W_j.append(tmp2)
110 | return W_j
111 |
112 | def embed_DNeRF(self, inputs, epoch):
113 | ''' Nerfie paper section 3.5 Coarse-to-Fine Deformation Regularization '''
114 | # get weight for each frequency band j
115 | W_j = self.get_embed_weight(epoch, self.N_freqs, self.N) # W_j: [W_0, W_1, W_2, ..., W_{m-1}]
116 |
117 | # Fourier embedding
118 | out = []
119 | for fn in self.embed_fns: # 17, embed_fns:[input, cos, sin, cos, sin, ..., cos, sin]
120 | out.append(fn(inputs))
121 |
122 | # apply weighted positional encoding, only to cos&sins
123 | for i in range(len(W_j)):
124 | out[2*i+1] = W_j[i] * out[2*i+1]
125 | out[2*i+2] = W_j[i] * out[2*i+2]
126 | ret = torch.cat(out, -1)
127 | return ret
128 |
129 | def update_N(self, N):
130 | self.N=N
131 |
132 |
133 | def get_embedder(multires, i=0, reduce_mode=-1, epochToMaxFreq=-1):
134 | if i == -1:
135 | return nn.Identity(), 3, None # return a 3-tuple so callers can unpack (embed_fn, out_dim, embedder_obj)
136 |
137 | if reduce_mode == 0:
138 | # reduce embedding
139 | embed_kwargs = {
140 | 'include_input' : True,
141 | 'input_dims' : 3,
142 | 'max_freq_log2' : (multires-1)//2,
143 | 'num_freqs' : multires//2,
144 | 'log_sampling' : True,
145 | 'periodic_fns' : [torch.sin, torch.cos],
146 | }
147 | elif reduce_mode == 1:
148 | # remove embedding
149 | embed_kwargs = {
150 | 'include_input' : True,
151 | 'input_dims' : 3,
152 | 'max_freq_log2' : 0,
153 | 'num_freqs' : 0,
154 | 'log_sampling' : True,
155 | 'periodic_fns' : [torch.sin, torch.cos],
156 | }
157 | elif reduce_mode == 2:
158 | # DNeRF embedding
159 | embed_kwargs = {
160 | 'include_input' : True,
161 | 'input_dims' : 3,
162 | 'max_freq_log2' : multires-1,
163 | 'num_freqs' : multires,
164 | 'log_sampling' : True,
165 | 'periodic_fns' : [torch.sin, torch.cos],
166 | }
167 | else:
168 | # paper default
169 | embed_kwargs = {
170 | 'include_input' : True,
171 | 'input_dims' : 3,
172 | 'max_freq_log2' : multires-1,
173 | 'num_freqs' : multires,
174 | 'log_sampling' : True,
175 | 'periodic_fns' : [torch.sin, torch.cos],
176 | }
177 |
178 | embedder_obj = Embedder(**embed_kwargs)
179 | if reduce_mode == 2:
180 | embedder_obj.update_N(epochToMaxFreq)
181 | embed = lambda x, epoch, eo=embedder_obj: eo.embed_DNeRF(x, epoch)
182 | else:
183 | embed = lambda x, eo=embedder_obj : eo.embed(x)
184 | return embed, embedder_obj.out_dim, embedder_obj # 63 for pos, 27 for view dir
185 |
186 | # Model
187 | class NeRF(nn.Module):
188 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
189 | """
190 | """
191 | super(NeRF, self).__init__()
192 | self.D = D
193 | self.W = W
194 | self.input_ch = input_ch
195 | self.input_ch_views = input_ch_views
196 | self.skips = skips
197 | self.use_viewdirs = use_viewdirs
198 |
199 | self.pts_linears = nn.ModuleList(
200 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
201 |
202 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
203 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
204 |
205 | ### Implementation according to the NeRF paper
206 | # self.views_linears = nn.ModuleList(
207 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
208 |
209 | if use_viewdirs:
210 | self.feature_linear = nn.Linear(W, W)
211 | self.alpha_linear = nn.Linear(W, 1)
212 | self.rgb_linear = nn.Linear(W//2, 3)
213 | else:
214 | self.output_linear = nn.Linear(W, output_ch)
215 |
216 | def forward(self, x):
217 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
218 | h = input_pts
219 | for i, l in enumerate(self.pts_linears):
220 | h = self.pts_linears[i](h)
221 | h = F.relu(h)
222 | if i in self.skips:
223 | h = torch.cat([input_pts, h], -1)
224 |
225 | if self.use_viewdirs:
226 | alpha = self.alpha_linear(h)
227 | feature = self.feature_linear(h)
228 | h = torch.cat([feature, input_views], -1)
229 |
230 | for i, l in enumerate(self.views_linears):
231 | h = self.views_linears[i](h)
232 | h = F.relu(h)
233 |
234 | rgb = self.rgb_linear(h)
235 | outputs = torch.cat([rgb, alpha], -1)
236 | else:
237 | outputs = self.output_linear(h)
238 |
239 | return outputs
240 |
241 | def load_weights_from_keras(self, weights):
242 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
243 |
244 | # Load pts_linears
245 | for i in range(self.D):
246 | idx_pts_linears = 2 * i
247 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
248 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
249 |
250 | # Load feature_linear
251 | idx_feature_linear = 2 * self.D
252 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
253 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
254 |
255 | # Load views_linears
256 | idx_views_linears = 2 * self.D + 2
257 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
258 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
259 |
260 | # Load rgb_linear
261 | idx_rbg_linear = 2 * self.D + 4
262 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
263 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))
264 |
265 | # Load alpha_linear
266 | idx_alpha_linear = 2 * self.D + 6
267 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
268 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
269 |
270 | def create_nerf(args):
271 | """Instantiate NeRF's MLP model.
272 | """
273 | if args.reduce_embedding==2: # use DNeRF embedding
274 | embed_fn, input_ch, embedder_obj = get_embedder(args.multires, args.i_embed, args.reduce_embedding, args.epochToMaxFreq) # input_ch.shape=63
275 | else:
276 | embed_fn, input_ch, _ = get_embedder(args.multires, args.i_embed, args.reduce_embedding) # input_ch.shape=63
277 |
278 | input_ch_views = 0
279 | embeddirs_fn = None
280 | if args.use_viewdirs:
281 | if args.reduce_embedding==2: # use DNeRF embedding
282 | if args.no_DNeRF_viewdir: # no DNeRF embedding for viewdir
283 | embeddirs_fn, input_ch_views, _ = get_embedder(args.multires_views, args.i_embed)
284 | else:
285 | embeddirs_fn, input_ch_views, embeddirs_obj = get_embedder(args.multires_views, args.i_embed, args.reduce_embedding, args.epochToMaxFreq)
286 | else:
287 | embeddirs_fn, input_ch_views, _ = get_embedder(args.multires_views, args.i_embed, args.reduce_embedding) # input_ch_views.shape=27
288 | output_ch = 5 if args.N_importance > 0 else 4
289 | skips = [4]
290 | model = NeRF(D=args.netdepth, W=args.netwidth, input_ch=input_ch, output_ch=output_ch, skips=skips,
291 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
292 | device = torch.device("cuda")
293 | if args.multi_gpu:
294 | model = torch.nn.DataParallel(model).to(device)
295 | else:
296 | model = model.to(device)
297 | grad_vars = list(model.parameters())
298 |
299 | model_fine = None
300 | if args.N_importance > 0:
301 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, input_ch=input_ch, output_ch=output_ch, skips=skips,
302 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
303 | if args.multi_gpu:
304 | model_fine = torch.nn.DataParallel(model_fine).to(device)
305 | else:
306 | model_fine = model_fine.to(device)
307 | grad_vars += list(model_fine.parameters())
308 |
309 | if args.reduce_embedding==2: # use DNeRF embedding
310 | network_query_fn = lambda inputs, viewdirs, network_fn, epoch: run_network_DNeRF(inputs, viewdirs, network_fn,
311 | embed_fn=embed_fn,
312 | embeddirs_fn=embeddirs_fn,
313 | netchunk=args.netchunk,
314 | epoch=epoch, no_DNeRF_viewdir=args.no_DNeRF_viewdir)
315 | else:
316 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
317 | embed_fn=embed_fn,
318 | embeddirs_fn=embeddirs_fn,
319 | netchunk=args.netchunk)
320 |
321 | # Create optimizer
322 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
323 |
324 | start = 0
325 | basedir = args.basedir
326 | expname = args.expname
327 |
328 | # Load checkpoints
329 | if args.ft_path is not None and args.ft_path!='None':
330 | ckpts = [args.ft_path]
331 | else:
332 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
333 |
334 | print('Found ckpts', ckpts)
335 | if len(ckpts) > 0 and not args.no_reload:
336 | ckpt_path = ckpts[-1]
337 | print('Reloading from', ckpt_path)
338 | ckpt = torch.load(ckpt_path)
339 |
340 | start = ckpt['global_step']
341 | optimizer.load_state_dict(ckpt['optimizer_state_dict'])
342 | # Load model
343 | model.load_state_dict(ckpt['network_fn_state_dict'])
344 |
345 | if model_fine is not None:
346 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
347 |
348 | ##########################
349 |
350 | render_kwargs_train = {
351 | 'network_query_fn' : network_query_fn,
352 | 'perturb' : args.perturb,
353 | 'N_importance' : args.N_importance,
354 | 'network_fine' : model_fine,
355 | 'N_samples' : args.N_samples,
356 | 'network_fn' : model,
357 | 'use_viewdirs' : args.use_viewdirs,
358 | 'white_bkgd' : args.white_bkgd,
359 | 'raw_noise_std' : args.raw_noise_std,
360 | }
361 |
362 | # NDC only good for LLFF-style forward facing data
363 | if args.dataset_type != 'llff' or args.no_ndc:
364 | print('Not ndc!')
365 | render_kwargs_train['ndc'] = False
366 | render_kwargs_train['lindisp'] = args.lindisp
367 |
368 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
369 | render_kwargs_test['perturb'] = False
370 | render_kwargs_test['raw_noise_std'] = 0.
371 |
372 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
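
# --- Editor's note: hypothetical usage sketch, not part of the original file. ---
# The default positional encoding maps a 3D input to 3 + 3*2*num_freqs channels:
# 63 dims for locations (multires=10) and 27 dims for view directions (multires_views=4).
if __name__ == '__main__':
    embed_fn, out_dim, _ = get_embedder(10, i=0)        # 3D location encoding
    _, out_dim_dirs, _ = get_embedder(4, i=0)           # viewing direction encoding
    pts = torch.rand(8, 3)
    print(out_dim, out_dim_dirs, embed_fn(pts).shape)   # 63 27 torch.Size([8, 63])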
--------------------------------------------------------------------------------
/script/models/options.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | def config_parser():
3 | parser = configargparse.ArgumentParser()
4 | parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
5 | parser.add_argument("--device", type=int, default=-1, help='CUDA_VISIBLE_DEVICES')
6 | parser.add_argument("--multi_gpu", action='store_true', help='use multiple gpu on the server')
7 | parser.add_argument('--config', is_config_file=True, help='config file path')
8 | parser.add_argument("--expname", type=str, help='experiment name')
9 | parser.add_argument("--basedir", type=str, default='../logs', help='where to store ckpts and logs')
10 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')
11 |
12 | # 7Scenes
13 | parser.add_argument("--trainskip", type=int, default=1, help='will load 1/N images from train sets, useful for large datasets like 7 Scenes')
14 | parser.add_argument("--df", type=float, default=1., help='image downscale factor')
15 | parser.add_argument("--reduce_embedding", type=int, default=-1, help='fourier embedding mode: -1: paper default, \
16 | 0: reduce by half, 1: remove embedding, 2: DNeRF embedding')
17 | parser.add_argument("--epochToMaxFreq", type=int, default=-1, help='DNeRF embedding mode: (based on Nerfie paper): \
18 | hyper-parameter for when α should reach the maximum number of frequencies m')
19 | parser.add_argument("--render_pose_only", action='store_true', help='render a spiral video for 7 Scene')
20 | parser.add_argument("--save_pose_avg_stats", action='store_true', help='save a pose avg stats to unify NeRF, posenet, direct-pn training')
21 | parser.add_argument("--load_pose_avg_stats", action='store_true', help='load precomputed pose avg stats to unify NeRF, posenet, nerf tracking training')
22 | parser.add_argument("--train_local_nerf", type=int, default=-1, help='train local NeRF with ith training sequence only, ie. Stairs can pick 0~3')
23 | parser.add_argument("--render_video_train", action='store_true', help='render train set NeRF and save as video, make sure render_test is True')
24 | parser.add_argument("--render_video_test", action='store_true', help='render val set NeRF and save as video, make sure render_test is True')
25 | parser.add_argument("--frustum_overlap_th", type=float, help='frustum overlap threshold')
26 | parser.add_argument("--no_DNeRF_viewdir", action='store_true', default=False, help='will not use DNeRF in viewdir encoding')
27 | parser.add_argument("--load_unique_view_stats", action='store_true', help='load unique views frame index')
28 |
29 | # NeRF training options
30 | parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
31 | parser.add_argument("--netwidth", type=int, default=128, help='channels per layer')
32 | parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
33 | parser.add_argument("--netwidth_fine", type=int, default=128, help='channels per layer in fine network')
34 | parser.add_argument("--N_rand", type=int, default=1536, help='batch size (number of random rays per gradient step)')
35 | parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
36 | parser.add_argument("--lrate_decay", type=float, default=250, help='exponential learning rate decay (in 1000 steps)')
37 | parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
38 | parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
39 | parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')
40 | parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
41 | parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')
42 | parser.add_argument("--no_grad_update", action='store_true', default=False, help='do not update nerf in training')
43 |
44 | # NeRF-Hist training options
45 | parser.add_argument("--NeRFH", action='store_true', help='my implementation for NeRFH, to enable NeRF-Hist training, please make sure to add --encode_hist, otherwise it is similar to NeRFW')
46 | parser.add_argument("--N_vocab", type=int, default=1000,
47 | help='''number of vocabulary (number of images)
48 | in the dataset for nn.Embedding''')
49 | parser.add_argument("--fix_index", action='store_true', help='fix training frame index as 0')
50 | parser.add_argument("--encode_hist", default=False, action='store_true', help='encode histogram instead of frame index')
51 | parser.add_argument("--hist_bin", type=int, default=10, help='image histogram bin size')
52 | parser.add_argument("--in_channels_a", type=int, default=50, help='appearance embedding dimension, hist_bin*N_a when embedding histogram')
53 | parser.add_argument("--in_channels_t", type=int, default=20, help='transient embedding dimension, hist_bin*N_tau when embedding histogram')
54 |
55 | # NeRF rendering options
56 | parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
57 | parser.add_argument("--N_importance", type=int, default=64,help='number of additional fine samples per ray')
58 | parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. for jitter')
59 | parser.add_argument("--use_viewdirs", default=True, action='store_true', help='use full 5D input instead of 3D')
60 | parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
61 | parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
62 | parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
63 | parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
64 | parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
65 | parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
66 | parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
67 |
68 | # legacy mesh options
69 | parser.add_argument("--mesh_only", action='store_true', help='do not optimize, reload weights and save mesh to a file')
70 | parser.add_argument("--mesh_grid_size", type=int, default=80,help='number of grid points to sample in each dimension for marching cubes')
71 |
72 | # training options
73 | parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops')
74 | parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops')
75 | parser.add_argument("--epochs", type=int, default=600,help='number of epochs to train')
76 |
77 | # dataset options
78 | parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / 7Scenes')
79 | parser.add_argument("--testskip", type=int, default=1, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
80 |
81 | ## legacy blender flags
82 | parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
83 | # parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')
84 |
85 | ## llff flags
86 | parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
87 | parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
88 | parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
89 | parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
90 | parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')
91 | parser.add_argument("--no_bd_factor", action='store_true', default=False, help='do not use bd factor')
92 |
93 | # logging/saving options
94 | parser.add_argument("--i_print", type=int, default=1, help='frequency of console printout and metric loggin')
95 | parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
96 | parser.add_argument("--i_weights", type=int, default=200, help='frequency of weight ckpt saving')
97 | parser.add_argument("--i_testset", type=int, default=200, help='frequency of testset saving')
98 | parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
99 |
100 | return parser
--------------------------------------------------------------------------------
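
A minimal sketch of how the options above are typically consumed. This is a hedged usage example, assuming the parts of `models/options.py` not shown here define no additional required flags and that argparse defaults apply when no config file is passed; it is run from the `script/` directory.

# hypothetical usage sketch for models/options.py
from models.options import config_parser

parser = config_parser()
args = parser.parse_args(['--dataset_type', '7Scenes', '--N_rand', '1024', '--epochs', '600'])

# flags not passed on the command line keep the defaults defined above
assert args.N_rand == 1024 and args.epochs == 600
assert args.N_samples == 64 and args.N_importance == 64 and args.lrate == 5e-4
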
/script/models/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | # Ray helpers
5 | def get_rays(H, W, focal, c2w):
6 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H), indexing='ij') # pytorch's meshgrid has indexing='ij'
7 | i = i.t()
8 | j = j.t()
9 | dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
10 |
11 | # Rotate ray directions from camera frame to the world frame
12 |     rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equivalent to: [c2w.dot(dir) for dir in dirs]
13 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
14 | rays_o = c2w[:3,-1].expand(rays_d.shape)
15 |     return rays_o, rays_d # rays_o (H, W, 3), rays_d (H, W, 3)
16 |
17 |
18 | def get_rays_np(H, W, focal, c2w):
19 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
20 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
21 | # Rotate ray directions from camera frame to the world frame
22 |     rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equivalent to: [c2w.dot(dir) for dir in dirs]
23 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
24 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
25 | return rays_o, rays_d
26 |
27 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
28 | # Shift ray origins to near plane
29 | t = -(near + rays_o[...,2]) / rays_d[...,2] # t_n = −(n + o_z)/d_z move o to the ray's intersection with near plane
30 | rays_o = rays_o + t[...,None] * rays_d
31 |
32 |     # Projection formula (20)
33 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
34 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
35 | o2 = 1. + 2. * near / rays_o[...,2]
36 |     # Formula (21)
37 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
38 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
39 | d2 = -2. * near / rays_o[...,2]
40 |
41 | rays_o = torch.stack([o0,o1,o2], -1) # o'
42 | rays_d = torch.stack([d0,d1,d2], -1) # d'
43 |
44 | return rays_o, rays_d
--------------------------------------------------------------------------------
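
A small sanity-check sketch for the ray helpers above. This is illustrative only: it assumes it is run from the `script/` directory and uses an identity camera-to-world pose that is not part of the repo.

# sanity-check sketch: the torch and numpy ray generators should agree
import numpy as np
import torch
from models.ray_utils import get_rays, get_rays_np

H, W, focal = 4, 6, 10.0
c2w = torch.eye(4)[:3]                          # 3x4 identity pose: camera at the origin, looking down -z

rays_o, rays_d = get_rays(H, W, focal, c2w)
rays_o_np, rays_d_np = get_rays_np(H, W, focal, c2w.numpy())

assert rays_o.shape == (H, W, 3) and rays_d.shape == (H, W, 3)
assert rays_o_np.shape == (H, W, 3) and rays_d_np.shape == (H, W, 3)
assert np.allclose(rays_d.numpy(), rays_d_np, atol=1e-6)     # both implementations match
assert torch.allclose(rays_o, torch.zeros_like(rays_o))      # every ray starts at the camera centre
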
/script/run_nerf.py:
--------------------------------------------------------------------------------
1 | import utils.set_sys_path
2 | import os, sys
3 | import numpy as np
4 | import imageio
5 | import json
6 | import random
7 | import time
8 | import torch
9 | import torch.nn as nn
10 | # from torch.utils.tensorboard import SummaryWriter
11 | from tqdm import tqdm, trange
12 |
13 | from models.ray_utils import *
14 | from models.nerfw import * # NeRF-w and NeRF-hist
15 | from models.options import config_parser
16 | from models.rendering import *
17 | from dataset_loaders.load_7Scenes import load_7Scenes_dataloader_NeRF
18 | from dataset_loaders.load_Cambridge import load_Cambridge_dataloader_NeRF
19 |
20 | # losses
21 | from models.losses import loss_dict
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 | np.random.seed(0)
25 | torch.manual_seed(0)
26 | import random
27 | random.seed(0)
28 |
29 | parser = config_parser()
30 | args = parser.parse_args()
31 |
32 | def train_on_epoch_nerfw(args, train_dl, H, W, focal, N_rand, optimizer, loss_func, global_step, render_kwargs_train):
33 | for batch_idx, (target, pose, img_idx) in enumerate(train_dl):
34 | target = target[0].permute(1,2,0).to(device)
35 |         pose = pose.reshape(3,4).to(device) # reshape to a 3x4 pose matrix [R|t]
36 | img_idx = img_idx.to(device)
37 |
38 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
39 | if N_rand is not None:
40 | rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
41 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W), indexing='ij'), -1) # (H, W, 2)
42 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
43 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
44 | select_coords = coords[select_inds].long() # (N_rand, 2)
45 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
46 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
47 | batch_rays = torch.stack([rays_o, rays_d], 0)
48 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
49 |
50 | # ##### Core optimization loop #####
51 | rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays, retraw=True, img_idx=img_idx, **render_kwargs_train)
52 | optimizer.zero_grad()
53 |
54 | # compute loss
55 | results = {}
56 | results['rgb_fine'] = rgb
57 | results['rgb_coarse'] = extras['rgb0']
58 | results['beta'] = extras['beta']
59 | results['transient_sigmas'] = extras['transient_sigmas']
60 |
61 | loss_d = loss_func(results, target_s)
62 | loss = sum(l for l in loss_d.values())
63 |
64 | with torch.no_grad():
65 | img_loss = img2mse(rgb, target_s)
66 | psnr = mse2psnr(img_loss)
67 | loss.backward()
68 | optimizer.step()
69 |
70 | # NOTE: IMPORTANT!
71 | ### update learning rate ###
72 | decay_rate = 0.1
73 | decay_steps = args.lrate_decay * 1000
74 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
75 | for param_group in optimizer.param_groups:
76 | param_group['lr'] = new_lrate
77 | ################################
78 |
79 | torch.set_default_tensor_type('torch.FloatTensor')
80 | return loss, psnr
81 |
82 | def train_nerf(args, train_dl, val_dl, hwf, i_split, near, far, render_poses=None, render_img=None):
83 |
84 | i_train, i_val, i_test = i_split
85 | # Cast intrinsics to right types
86 | H, W, focal = hwf
87 | H, W = int(H), int(W)
88 | hwf = [H, W, focal]
89 |
90 | # Create log dir and copy the config file
91 | basedir = args.basedir
92 | expname = args.expname
93 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
94 | f = os.path.join(basedir, expname, 'args.txt')
95 | with open(f, 'w') as file:
96 | for arg in sorted(vars(args)):
97 | attr = getattr(args, arg)
98 | file.write('{} = {}\n'.format(arg, attr))
99 | if args.config is not None:
100 | f = os.path.join(basedir, expname, 'config.txt')
101 | with open(f, 'w') as file:
102 | file.write(open(args.config, 'r').read())
103 |
104 | # Create nerf model
105 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
106 | global_step = start
107 |
108 | bds_dict = {
109 | 'near' : near,
110 | 'far' : far,
111 | }
112 | render_kwargs_train.update(bds_dict)
113 | render_kwargs_test.update(bds_dict)
114 | if args.reduce_embedding==2:
115 | render_kwargs_train['i_epoch'] = -1
116 | render_kwargs_test['i_epoch'] = -1
117 |
118 | if args.render_test:
119 | print('TRAIN views are', i_train)
120 | print('TEST views are', i_test)
121 | print('VAL views are', i_val)
122 | if args.reduce_embedding==2:
123 | render_kwargs_test['i_epoch'] = global_step
124 | render_test(args, train_dl, val_dl, hwf, start, render_kwargs_test)
125 | return
126 |
127 | # Prepare raybatch tensor if batching random rays
128 | N_rand = args.N_rand
129 | # use_batching = not args.no_batching
130 |
131 | N_epoch = args.epochs + 1 # epoch
132 | print('Begin')
133 | print('TRAIN views are', i_train)
134 | print('TEST views are', i_test)
135 | print('VAL views are', i_val)
136 |
137 |
138 | # loss function
139 | loss_func = loss_dict['nerfw'](coef=1)
140 |
141 | for i in trange(start, N_epoch):
142 | time0 = time.time()
143 | if args.reduce_embedding==2:
144 | render_kwargs_train['i_epoch'] = i
145 | loss, psnr = train_on_epoch_nerfw(args, train_dl, H, W, focal, N_rand, optimizer, loss_func, global_step, render_kwargs_train)
146 | dt = time.time()-time0
147 | ##### end #####
148 |
149 | # Rest is logging
150 | if i%args.i_weights==0 and i!=0:
151 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
152 | if args.N_importance > 0: # have fine sample network
153 | torch.save({
154 | 'global_step': global_step,
155 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
156 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
157 | 'embedding_a_state_dict': render_kwargs_train['embedding_a'].state_dict(),
158 | 'embedding_t_state_dict': render_kwargs_train['embedding_t'].state_dict(),
159 | 'optimizer_state_dict': optimizer.state_dict(),
160 | }, path)
161 | else:
162 | torch.save({
163 | 'global_step': global_step,
164 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
165 | 'optimizer_state_dict': optimizer.state_dict(),
166 | }, path)
167 | print('Saved checkpoints at', path)
168 |
169 | if i%args.i_testset==0 and i > 0: # run thru all validation set
170 |
171 | # clean GPU memory before testing, try to avoid OOM
172 | torch.cuda.empty_cache()
173 |
174 | if args.reduce_embedding==2:
175 | render_kwargs_test['i_epoch'] = i
176 | trainsavedir = os.path.join(basedir, expname, 'trainset_{:06d}'.format(i))
177 | os.makedirs(trainsavedir, exist_ok=True)
178 | images_train = []
179 | poses_train = []
180 | index_train = []
181 |         j_skip = 10 # render every j_skip-th training view as a holdout sample
182 |         # pick holdout views from the training set at a fixed stride
183 | for batch_idx, (img, pose, img_idx) in enumerate(train_dl):
184 | if batch_idx % j_skip != 0:
185 | continue
186 | img_val = img.permute(0,2,3,1) # (1,H,W,3)
187 | pose_val = torch.zeros(1,4,4)
188 | pose_val[0,:3,:4] = pose.reshape(3,4)[:3,:4] # (1,3,4))
189 | pose_val[0,3,3] = 1.
190 | images_train.append(img_val)
191 | poses_train.append(pose_val)
192 | index_train.append(img_idx)
193 | images_train = torch.cat(images_train, dim=0).numpy()
194 | poses_train = torch.cat(poses_train, dim=0).to(device)
195 | index_train = torch.cat(index_train, dim=0).to(device)
196 | print('train poses shape', poses_train.shape)
197 |
198 | with torch.no_grad():
199 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
200 | render_path(args, poses_train, hwf, args.chunk, render_kwargs_test, gt_imgs=images_train, savedir=trainsavedir, img_ids=index_train)
201 | torch.set_default_tensor_type('torch.FloatTensor')
202 | print('Saved train set')
203 | del images_train
204 | del poses_train
205 |
206 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
207 | os.makedirs(testsavedir, exist_ok=True)
208 | images_val = []
209 | poses_val = []
210 | index_val = []
211 | # views from validation set
212 | for img, pose, img_idx in val_dl:
213 | img_val = img.permute(0,2,3,1) # (1,H,W,3)
214 | pose_val = torch.zeros(1,4,4)
215 | pose_val[0,:3,:4] = pose.reshape(3,4)[:3,:4] # (1,3,4))
216 | pose_val[0,3,3] = 1.
217 | images_val.append(img_val)
218 | poses_val.append(pose_val)
219 | index_val.append(img_idx)
220 |
221 | images_val = torch.cat(images_val, dim=0).numpy()
222 | poses_val = torch.cat(poses_val, dim=0).to(device)
223 | index_val = torch.cat(index_val, dim=0).to(device)
224 | print('test poses shape', poses_val.shape)
225 |
226 | with torch.no_grad():
227 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
228 | render_path(args, poses_val, hwf, args.chunk, render_kwargs_test, gt_imgs=images_val, savedir=testsavedir, img_ids=index_val)
229 | torch.set_default_tensor_type('torch.FloatTensor')
230 | print('Saved test set')
231 |
232 | # clean GPU memory after testing
233 | torch.cuda.empty_cache()
234 | del images_val
235 | del poses_val
236 |
237 | if i%args.i_print==0:
238 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
239 |
240 | global_step += 1
241 |
242 | def train():
243 |
244 | print(parser.format_values())
245 |
246 | # Load data
247 | if args.dataset_type == '7Scenes':
248 |
249 | train_dl, val_dl, hwf, i_split, bds, render_poses, render_img = load_7Scenes_dataloader_NeRF(args)
250 | near = bds[0]
251 | far = bds[1]
252 |
253 | print('NEAR FAR', near, far)
254 | train_nerf(args, train_dl, val_dl, hwf, i_split, near, far, render_poses, render_img)
255 | return
256 |
257 | elif args.dataset_type == 'Cambridge':
258 |
259 | train_dl, val_dl, hwf, i_split, bds, render_poses, render_img = load_Cambridge_dataloader_NeRF(args)
260 | near = bds[0]
261 | far = bds[1]
262 |
263 | print('NEAR FAR', near, far)
264 | train_nerf(args, train_dl, val_dl, hwf, i_split, near, far, render_poses, render_img)
265 | return
266 |
267 | else:
268 | print('Unknown dataset type', args.dataset_type, 'exiting')
269 | return
270 |
271 | if __name__=='__main__':
272 |
273 | train()
--------------------------------------------------------------------------------
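
One detail worth noting in train_on_epoch_nerfw above: the learning rate follows the standard NeRF exponential decay, but `global_step` is advanced once per epoch (see the `global_step += 1` at the end of the epoch loop), so the decay is counted in epochs rather than ray batches. A worked sketch of the schedule under the defaults (`--lrate 5e-4`, `--lrate_decay 250`):

# worked sketch of new_lrate = lrate * 0.1 ** (global_step / (lrate_decay * 1000))
lrate, lrate_decay, decay_rate = 5e-4, 250, 0.1
decay_steps = lrate_decay * 1000          # one decade of decay every 250k steps

for step in (0, 50_000, 250_000, 500_000):
    print(step, lrate * decay_rate ** (step / decay_steps))
# 0        5.0e-04
# 50000    ~3.15e-04   (5e-4 * 0.1 ** 0.2)
# 250000   5.0e-05
# 500000   5.0e-06
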
/script/train.py:
--------------------------------------------------------------------------------
1 | import utils.set_sys_path
2 | import os, sys
3 | import numpy as np
4 | import imageio
5 | import random
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from tqdm import tqdm, trange
12 | from torchsummary import summary
13 | # from torchinfo import summary
14 | import matplotlib.pyplot as plt
15 |
16 | from dm.pose_model import *
17 | from dm.direct_pose_model import *
18 | from dm.callbacks import EarlyStopping
19 | # from dm.prepare_data import prepare_data, load_dataset
20 | from dm.options import config_parser
21 | from models.rendering import render_path
22 | from models.nerfw import to8b
23 | from dataset_loaders.load_7Scenes import load_7Scenes_dataloader
24 | from dataset_loaders.load_Cambridge import load_Cambridge_dataloader
25 | from utils.utils import freeze_bn_layer
26 | from feature.direct_feature_matching import train_feature_matching
27 | # import torch.onnx
28 |
29 | parser = config_parser()
30 | args = parser.parse_args()
31 | device = torch.device('cuda:0') # this is really controlled in train.sh
32 |
33 | def render_test(args, train_dl, val_dl, hwf, start, model, device, render_kwargs_test):
34 | model.eval()
35 |
36 | # ### Eval Training set result
37 | if args.render_video_train:
38 | images_train = []
39 | poses_train = []
40 | # views from train set
41 | for img, pose in train_dl:
42 | predict_pose = inference_pose_regression(args, img, device, model)
43 | device_cpu = torch.device('cpu')
44 | predict_pose = predict_pose.to(device_cpu) # put predict pose back to cpu
45 |
46 | img_val = img.permute(0,2,3,1) # (1,240,320,3)
47 | pose_val = torch.zeros(1,4,4)
48 | pose_val[0,:3,:4] = predict_pose.reshape(3,4)[:3,:4] # (1,3,4))
49 | pose_val[0,3,3] = 1.
50 | images_train.append(img_val)
51 | poses_train.append(pose_val)
52 |
53 | images_train = torch.cat(images_train, dim=0).numpy()
54 | poses_train = torch.cat(poses_train, dim=0)
55 | print('train poses shape', poses_train.shape)
56 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
57 | with torch.no_grad():
58 | rgbs, disps = render_path(poses_train.to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images_train, savedir=None)
59 | torch.set_default_tensor_type('torch.FloatTensor')
60 | print('Saving trainset as video', rgbs.shape, disps.shape)
61 | moviebase = os.path.join(args.basedir, args.model_name, '{}_trainset_{:06d}_'.format(args.model_name, start))
62 | imageio.mimwrite(moviebase + 'train_rgb.mp4', to8b(rgbs), fps=15, quality=8)
63 | imageio.mimwrite(moviebase + 'train_disp.mp4', to8b(disps / np.max(disps)), fps=15, quality=8)
64 |
65 | ### Eval Validation set result
66 | if args.render_video_test:
67 | images_val = []
68 | poses_val = []
69 | # views from val set
70 | for img, pose in val_dl:
71 | predict_pose = inference_pose_regression(args, img, device, model)
72 | device_cpu = torch.device('cpu')
73 | predict_pose = predict_pose.to(device_cpu) # put predict pose back to cpu
74 |
75 | img_val = img.permute(0,2,3,1) # (1,240,360,3)
76 | pose_val = torch.zeros(1,4,4)
77 | pose_val[0,:3,:4] = predict_pose.reshape(3,4)[:3,:4] # (1,3,4))
78 | pose_val[0,3,3] = 1.
79 | images_val.append(img_val)
80 | poses_val.append(pose_val)
81 |
82 | images_val = torch.cat(images_val, dim=0).numpy()
83 | poses_val = torch.cat(poses_val, dim=0)
84 | print('test poses shape', poses_val.shape)
85 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
86 | with torch.no_grad():
87 | rgbs, disps = render_path(poses_val.to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images_val, savedir=None)
88 | torch.set_default_tensor_type('torch.FloatTensor')
89 | print('Saving testset as video', rgbs.shape, disps.shape)
90 | moviebase = os.path.join(args.basedir, args.model_name, '{}_test_{:06d}_'.format(args.model_name, start))
91 | imageio.mimwrite(moviebase + 'test_rgb.mp4', to8b(rgbs), fps=15, quality=8)
92 | imageio.mimwrite(moviebase + 'test_disp.mp4', to8b(disps / np.max(disps)), fps=15, quality=8)
93 | return
94 |
95 | def train():
96 | print(parser.format_values())
97 | # Load data
98 | if args.dataset_type == '7Scenes':
99 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_7Scenes_dataloader(args)
100 | elif args.dataset_type == 'Cambridge':
101 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_Cambridge_dataloader(args)
102 | else:
103 | print("please choose dataset_type: 7Scenes or Cambridge, exiting...")
104 | sys.exit()
105 |
106 | ### pose regression module, here requires a pretrained DFNet for Pose Estimator F
107 | assert(args.pretrain_model_path != '') # make sure to add a valid PATH using --pretrain_model_path
108 | # load pretrained DFNet model
109 | model = load_exisiting_model(args)
110 |
111 | if args.freezeBN:
112 | model = freeze_bn_layer(model)
113 | model.to(device)
114 |
115 | ### feature extraction module, here requires a pretrained DFNet for Feature Extractor G using --pretrain_featurenet_path
116 | if args.pretrain_featurenet_path == '':
117 | print('Use the same DFNet for Feature Extraction and Pose Regression')
118 | feat_model = load_exisiting_model(args)
119 | else:
120 | # you can optionally load different pretrained DFNet for feature extractor and pose estimator
121 | feat_model = load_exisiting_model(args, isFeatureNet=True)
122 |
123 | feat_model.eval()
124 | feat_model.to(device)
125 |
126 | # set optimizer
127 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) #weight_decay=weight_decay, **kwargs
128 |
129 | # set callbacks parameters
130 | early_stopping = EarlyStopping(args, patience=args.patience[0], verbose=False)
131 |
132 | # start training
133 | if args.dataset_type == '7Scenes':
134 | train_feature_matching(args, model, feat_model, optimizer, i_split, hwf, near, far, device, early_stopping, train_dl=train_dl, val_dl=val_dl, test_dl=test_dl)
135 | elif args.dataset_type == 'Cambridge':
136 | train_feature_matching(args, model, feat_model, optimizer, i_split, hwf, near, far, device, early_stopping, train_dl=train_dl, val_dl=val_dl, test_dl=test_dl)
137 |
138 | def eval():
139 | print(parser.format_values())
140 | # Load data
141 | if args.dataset_type == '7Scenes':
142 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_7Scenes_dataloader(args)
143 | elif args.dataset_type == 'Cambridge':
144 | train_dl, val_dl, test_dl, hwf, i_split, near, far = load_Cambridge_dataloader(args)
145 | else:
146 | print("please choose dataset_type: 7Scenes or Cambridge, exiting...")
147 | sys.exit()
148 |
149 | # load pretrained DFNet_dm model
150 | model = load_exisiting_model(args)
151 | if args.freezeBN:
152 | model = freeze_bn_layer(model)
153 | model.to(device)
154 |
155 | print(len(test_dl.dataset))
156 |
157 | get_error_in_q(args, test_dl, model, len(test_dl.dataset), device, batch_size=1)
158 |
159 | if __name__ == '__main__':
160 | if args.eval:
161 | torch.manual_seed(0)
162 | random.seed(0)
163 | np.random.seed(0)
164 | eval()
165 | else:
166 | train()
167 |
--------------------------------------------------------------------------------
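
Both branches of render_test above pad the predicted 3x4 pose into a homogeneous 4x4 matrix before rendering. The pattern, isolated as a standalone sketch (the helper name and dummy pose are hypothetical, not part of the repo):

import torch

def to_homogeneous(pose_3x4: torch.Tensor) -> torch.Tensor:
    """Pad a (3, 4) camera-to-world pose to a (1, 4, 4) homogeneous matrix, as done in render_test."""
    pose_val = torch.zeros(1, 4, 4)
    pose_val[0, :3, :4] = pose_3x4.reshape(3, 4)
    pose_val[0, 3, 3] = 1.0
    return pose_val

pose = torch.eye(4)[:3]                  # dummy 3x4 pose for illustration
assert to_homogeneous(pose).shape == (1, 4, 4)
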
/script/utils/set_sys_path.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
--------------------------------------------------------------------------------
/script/utils/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | helper functions to train robust feature extractors
3 | '''
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import datetime
9 | from torchvision.utils import make_grid
10 | import matplotlib.pyplot as plt
11 | import pdb
12 | from PIL import Image
13 | from torchvision.utils import save_image
14 | from math import pi
15 | import cv2
16 | # from pykalman import KalmanFilter # uncomment for Kalman1D/Kalman3D below (requires pykalman)
17 |
18 | def freeze_bn_layer(model):
19 |     ''' freeze BN layers by disabling their gradients; they still behave differently under model.train() vs. model.eval() '''
20 | print("Freezing BatchNorm Layers...")
21 | for module in model.modules():
22 | if isinstance(module, nn.BatchNorm2d):
23 | # print("this is a BN layer:", module)
24 | if hasattr(module, 'weight'):
25 | module.weight.requires_grad_(False)
26 | if hasattr(module, 'bias'):
27 | module.bias.requires_grad_(False)
28 | return model
29 |
30 | def freeze_bn_layer_train(model):
31 |     ''' set BatchNorm layers to eval()
32 |         useful to align training and testing behaviour
33 |     '''
34 | # model.train()
35 | # print("Freezing BatchNorm Layers...")
36 | for module in model.modules():
37 | if isinstance(module, nn.BatchNorm2d):
38 | module.eval()
39 | return model
40 |
41 | def save_image_saliancy(tensor, path, normalize: bool = False, scale_each: bool = False,):
42 | """
43 | Modification based on TORCHVISION.UTILS
44 |     :param tensor: (batch, channel, H, W)
45 | """
46 | # grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=32)
47 | grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=6)
48 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
49 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
50 | fig = plt.figure()
51 | plt.imshow(ndarr[:,:,0], cmap='jet') # viridis, plasma
52 | plt.axis('off')
53 | fig.savefig(path, bbox_inches='tight',dpi=fig.dpi,pad_inches=0.0)
54 | plt.close()
55 |
56 | def save_image_saliancy_single(tensor, path, normalize: bool = False, scale_each: bool = False,):
57 | """
58 | Modification based on TORCHVISION.UTILS, save single feature map
59 |     :param tensor: (batch, channel, H, W)
60 | """
61 | # grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=32)
62 | grid = make_grid(tensor.detach(), normalize=normalize, scale_each=scale_each, nrow=1)
63 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
64 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
65 | fig = plt.figure()
66 | # plt.imshow(ndarr[:,:,0], cmap='plasma') # viridis, jet
67 | plt.imshow(ndarr[:,:,0], cmap='jet') # viridis, jet
68 | plt.axis('off')
69 | fig.savefig(path, bbox_inches='tight',dpi=fig.dpi,pad_inches=0.0)
70 | plt.close()
71 |
72 | def print_feature_examples(features, path):
73 | """
74 | print feature maps
75 |     :param features: list of feature maps
76 | """
77 | kwargs = {'normalize' : True, } # 'scale_each' : True
78 |
79 | for i in range(len(features)):
80 | fn = path + '{}.png'.format(i)
81 | # save_image(features[i].permute(1,0,2,3), fn, **kwargs)
82 | save_image_saliancy(features[i].permute(1,0,2,3), fn, normalize=True)
83 | # pdb.set_trace()
84 | ###
85 |
86 | def plot_features(features, path='f', isList=True):
87 | """
88 | print feature maps
89 | :param features: (3, [batch, H, W]) or [3, batch, H, W]
90 | :param path: save image path
91 |     :param isList: whether the features argument is a list
92 | :return:
93 | """
94 | kwargs = {'normalize' : True, } # 'scale_each' : True
95 |
96 | if isList:
97 | dim = features[0].dim()
98 | else:
99 | dim = features.dim()
100 | assert(dim==3 or dim==4)
101 |
102 | if dim==4 and isList:
103 | print_feature_examples(features, path)
104 | elif dim==4 and (isList==False):
105 | fn = path
106 | lvl, b, H, W = features.shape
107 | for i in range(features.shape[0]):
108 | fn = path + '{}.png'.format(i)
109 | save_image_saliancy(features[i][None,...].permute(1,0,2,3).cpu(), fn, normalize=True)
110 |
111 | # # concat everything
112 | # features = features.reshape([-1, H, W])
113 | # # save_image(features[None,...].permute(1,0,2,3).cpu(), fn, **kwargs)
114 | # save_image_saliancy(features[None,...].permute(1,0,2,3).cpu(), fn, normalize=True)
115 |
116 | elif dim==3 and isList: # print all images in the list
117 | for i in range(len(features)):
118 | fn = path + '{}.png'.format(i)
119 | # save_image(features[i][None,...].permute(1,0,2,3).cpu(), fn, **kwargs)
120 | save_image_saliancy(features[i][None,...].permute(1,0,2,3).cpu(), fn, normalize=True)
121 | elif dim==3 and (isList==False):
122 | fn = path
123 | save_image_saliancy(features[None,...].permute(1,0,2,3).cpu(), fn, normalize=True)
124 |
125 | def sample_homography_np(
126 | shape, shift=0, perspective=True, scaling=True, rotation=True, translation=True,
127 | n_scales=5, n_angles=25, scaling_amplitude=0.1, perspective_amplitude_x=0.1,
128 | perspective_amplitude_y=0.1, patch_ratio=0.5, max_angle=pi/2,
129 | allow_artifacts=False, translation_overflow=0.):
130 | """Sample a random valid homography.
131 |
132 | Computes the homography transformation between a random patch in the original image
133 | and a warped projection with the same image size.
134 | As in `tf.contrib.image.transform`, it maps the output point (warped patch) to a
135 | transformed input point (original patch).
136 | The original patch, which is initialized with a simple half-size centered crop, is
137 | iteratively projected, scaled, rotated and translated.
138 |
139 | Arguments:
140 |         shape: An array of two elements `[H, W]` giving the height and width of the original image.
141 | perspective: A boolean that enables the perspective and affine transformations.
142 | scaling: A boolean that enables the random scaling of the patch.
143 | rotation: A boolean that enables the random rotation of the patch.
144 | translation: A boolean that enables the random translation of the patch.
145 | n_scales: The number of tentative scales that are sampled when scaling.
146 | n_angles: The number of tentatives angles that are sampled when rotating.
147 | scaling_amplitude: Controls the amount of scale.
148 | perspective_amplitude_x: Controls the perspective effect in x direction.
149 | perspective_amplitude_y: Controls the perspective effect in y direction.
150 | patch_ratio: Controls the size of the patches used to create the homography. (like crop size)
151 | max_angle: Maximum angle used in rotations.
152 | allow_artifacts: A boolean that enables artifacts when applying the homography.
153 | translation_overflow: Amount of border artifacts caused by translation.
154 |
155 | Returns:
156 |         A numpy array of shape `(3, 3)` containing the sampled homography matrix.
157 | """
158 |
159 | # print("debugging")
160 |
161 |
162 | # Corners of the output image
163 | pts1 = np.stack([[0., 0.], [0., 1.], [1., 1.], [1., 0.]], axis=0)
164 | # Corners of the input patch
165 | margin = (1 - patch_ratio) / 2
166 | pts2 = margin + np.array([[0, 0], [0, patch_ratio],
167 | [patch_ratio, patch_ratio], [patch_ratio, 0]])
168 |
169 | from numpy.random import normal
170 | from numpy.random import uniform
171 | from scipy.stats import truncnorm
172 |
173 | # Random perspective and affine perturbations
174 | # lower, upper = 0, 2
175 | std_trunc = 2
176 | # pdb.set_trace()
177 | if perspective:
178 | if not allow_artifacts:
179 | perspective_amplitude_x = min(perspective_amplitude_x, margin)
180 | perspective_amplitude_y = min(perspective_amplitude_y, margin)
181 | perspective_displacement = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_y/2).rvs(1)
182 | h_displacement_left = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1)
183 | h_displacement_right = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1)
184 | pts2 += np.array([[h_displacement_left, perspective_displacement],
185 | [h_displacement_left, -perspective_displacement],
186 | [h_displacement_right, perspective_displacement],
187 | [h_displacement_right, -perspective_displacement]]).squeeze()
188 |
189 | # Random scaling
190 | # sample several scales, check collision with borders, randomly pick a valid one
191 | if scaling:
192 | scales = truncnorm(-1*std_trunc, std_trunc, loc=1, scale=scaling_amplitude/2).rvs(n_scales)
193 | scales = np.concatenate((np.array([1]), scales), axis=0)
194 |
195 | center = np.mean(pts2, axis=0, keepdims=True)
196 | scaled = (pts2 - center)[np.newaxis, :, :] * scales[:, np.newaxis, np.newaxis] + center
197 | if allow_artifacts:
198 | valid = np.arange(n_scales) # all scales are valid except scale=1
199 | else:
200 | valid = (scaled >= 0.) * (scaled < 1.)
201 | valid = valid.prod(axis=1).prod(axis=1)
202 | valid = np.where(valid)[0]
203 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int)
204 | pts2 = scaled[idx,:,:]
205 |
206 | # Random translation
207 | if translation:
208 | # pdb.set_trace()
209 | t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0)
210 | if allow_artifacts:
211 | t_min += translation_overflow
212 | t_max += translation_overflow
213 | pts2 += np.array([uniform(-t_min[0], t_max[0],1), uniform(-t_min[1], t_max[1], 1)]).T
214 |
215 | # Random rotation
216 | # sample several rotations, check collision with borders, randomly pick a valid one
217 | if rotation:
218 | angles = np.linspace(-max_angle, max_angle, num=n_angles)
219 | angles = np.concatenate((angles, np.array([0.])), axis=0) # in case no rotation is valid
220 | center = np.mean(pts2, axis=0, keepdims=True)
221 | rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles),
222 | np.cos(angles)], axis=1), [-1, 2, 2])
223 | rotated = np.matmul( (pts2 - center)[np.newaxis,:,:], rot_mat) + center
224 | if allow_artifacts:
225 |             valid = np.arange(n_angles) # all angles are valid except the appended 0
226 | else: # find multiple valid option and choose the valid one
227 | valid = (rotated >= 0.) * (rotated < 1.)
228 | valid = valid.prod(axis=1).prod(axis=1)
229 | valid = np.where(valid)[0]
230 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int)
231 | pts2 = rotated[idx,:,:]
232 |
233 | # Rescale to actual size
234 | shape = shape[::-1] # different convention [y, x]
235 | pts1 *= shape[np.newaxis,:]
236 | pts2 *= shape[np.newaxis,:]
237 |
238 | homography = cv2.getPerspectiveTransform(np.float32(pts1+shift), np.float32(pts2+shift))
239 | return homography
240 |
241 | def warp_points(points, homographies, device='cpu'):
242 | """
243 | Warp a list of points with the given homography.
244 |
245 | Arguments:
246 | points: list of N points, shape (N, 2(x, y))).
247 |         homographies: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
248 |
249 | Returns: a Tensor of shape (N, 2) or (B, N, 2(x, y)) (depending on whether the homography
250 | is batched) containing the new coordinates of the warped points.
251 |
252 | """
253 | # expand points len to (x, y, 1)
254 | no_batches = len(homographies.shape) == 2
255 | homographies = homographies.unsqueeze(0) if no_batches else homographies
256 |
257 | batch_size = homographies.shape[0]
258 | points = torch.cat((points.float(), torch.ones((points.shape[0], 1)).to(device)), dim=1)
259 | points = points.to(device)
260 | homographies = homographies.view(batch_size*3,3)
261 |
262 | warped_points = homographies@points.transpose(0,1)
263 |
264 | # normalize the points
265 | warped_points = warped_points.view([batch_size, 3, -1])
266 | warped_points = warped_points.transpose(2, 1)
267 | warped_points = warped_points[:, :, :2] / warped_points[:, :, 2:]
268 | return warped_points[0,:,:] if no_batches else warped_points
269 |
270 | def inv_warp_image_batch(img, mat_homo_inv, device='cpu', mode='bilinear'):
271 | '''
272 | Inverse warp images in batch
273 |
274 | :param img:
275 | batch of images
276 | tensor [batch_size, 1, H, W]
277 | :param mat_homo_inv:
278 | batch of homography matrices
279 | tensor [batch_size, 3, 3]
280 | :param device:
281 | GPU device or CPU
282 | :return:
283 | batch of warped images
284 | tensor [batch_size, 1, H, W]
285 | '''
286 | # compute inverse warped points
287 | if len(img.shape) == 2 or len(img.shape) == 3:
288 | img = img.view(1,1,img.shape[0], img.shape[1])
289 | if len(mat_homo_inv.shape) == 2:
290 | mat_homo_inv = mat_homo_inv.view(1,3,3)
291 |
292 | Batch, channel, H, W = img.shape
293 | coor_cells = torch.stack(torch.meshgrid(torch.linspace(-1, 1, W), torch.linspace(-1, 1, H), indexing='ij'), dim=2)
294 | coor_cells = coor_cells.transpose(0, 1)
295 | coor_cells = coor_cells.to(device)
296 | coor_cells = coor_cells.contiguous()
297 |
298 | src_pixel_coords = warp_points(coor_cells.view([-1, 2]), mat_homo_inv, device)
299 | src_pixel_coords = src_pixel_coords.view([Batch, H, W, 2])
300 | src_pixel_coords = src_pixel_coords.float()
301 |
302 | warped_img = F.grid_sample(img, src_pixel_coords, mode=mode, align_corners=True)
303 | return warped_img
304 |
305 | def compute_valid_mask(image_shape, inv_homography, device='cpu', erosion_radius=0):
306 | """
307 | Compute a boolean mask of the valid pixels resulting from an homography applied to
308 | an image of a given shape. Pixels that are False correspond to bordering artifacts.
309 | A margin can be discarded using erosion.
310 |
311 | Arguments:
312 |         image_shape: sequence of two ints giving the image shape, i.e. `[H, W]`.
313 |         inv_homography: Tensor of shape (B, 3, 3) or (3, 3), where B is the batch size.
314 |         erosion_radius: radius of the margin to be discarded.
315 | 
316 |     Returns: a float Tensor of shape (B, H, W).
317 | """
318 |
319 | if inv_homography.dim() == 2:
320 | inv_homography = inv_homography.view(-1, 3, 3)
321 | batch_size = inv_homography.shape[0]
322 | mask = torch.ones(batch_size, 1, image_shape[0], image_shape[1]).to(device)
323 | mask = inv_warp_image_batch(mask, inv_homography, device=device, mode='nearest')
324 | mask = mask.view(batch_size, image_shape[0], image_shape[1])
325 | mask = mask.cpu().numpy()
326 | if erosion_radius > 0:
327 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erosion_radius*2,)*2)
328 | for i in range(batch_size):
329 | mask[i, :, :] = cv2.erode(mask[i, :, :], kernel, iterations=1)
330 |
331 | return torch.tensor(mask).to(device)
332 |
333 | def Kalman1D(observations,damping=1):
334 | # To return the smoothed time series data
335 | observation_covariance = damping
336 | initial_value_guess = observations[0]
337 | transition_matrix = 1
338 | transition_covariance = 0.1
339 |
340 | kf = KalmanFilter(
341 | initial_state_mean=initial_value_guess,
342 | initial_state_covariance=observation_covariance,
343 | observation_covariance=observation_covariance,
344 | transition_covariance=transition_covariance,
345 | transition_matrices=transition_matrix
346 | )
347 | pred_state, state_cov = kf.smooth(observations)
348 | return pred_state
349 |
350 | def Kalman3D(observations,damping=1):
351 | '''
352 | In:
353 | observation: Nx3
354 | Out:
355 | pred_state: Nx3
356 | '''
357 | # To return the smoothed time series data
358 | observation_covariance = damping
359 | transition_matrix = 1
360 | transition_covariance = 0.1
361 | initial_value_guess_x = observations[0,0]
362 | initial_value_guess_y = observations[0,1] # ?
363 | initial_value_guess_z = observations[0,2] # ?
364 |
365 | # perform 1D smooth for each axis
366 | kfx = KalmanFilter(
367 | initial_state_mean=initial_value_guess_x,
368 | initial_state_covariance=observation_covariance,
369 | observation_covariance=observation_covariance,
370 | transition_covariance=transition_covariance,
371 | transition_matrices=transition_matrix
372 | )
373 | pred_state_x, state_cov_x = kfx.smooth(observations[:, 0])
374 |
375 | kfy = KalmanFilter(
376 | initial_state_mean=initial_value_guess_y,
377 | initial_state_covariance=observation_covariance,
378 | observation_covariance=observation_covariance,
379 | transition_covariance=transition_covariance,
380 | transition_matrices=transition_matrix
381 | )
382 | pred_state_y, state_cov_y = kfy.smooth(observations[:, 1])
383 |
384 | kfz = KalmanFilter(
385 | initial_state_mean=initial_value_guess_z,
386 | initial_state_covariance=observation_covariance,
387 | observation_covariance=observation_covariance,
388 | transition_covariance=transition_covariance,
389 | transition_matrices=transition_matrix
390 | )
391 |     pred_state_z, state_cov_z = kfz.smooth(observations[:, 2])
392 |
393 | pred_state = np.concatenate((pred_state_x, pred_state_y, pred_state_z), axis=1)
394 | return pred_state
--------------------------------------------------------------------------------
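
A minimal end-to-end sketch of the homography utilities above. This is a hedged example: it assumes it is run from the `script/` directory and follows the SuperPoint-style convention of sampling the homography on the normalized [-1, 1] grid via `shape=np.array([2, 2])` and `shift=-1`, which matches the [-1, 1] sampling grid built inside `inv_warp_image_batch`; the image sizes are illustrative.

# hedged usage sketch for sample_homography_np / inv_warp_image_batch / compute_valid_mask
import numpy as np
import torch
from utils.utils import sample_homography_np, inv_warp_image_batch, compute_valid_mask

H, W = 240, 320
homo = sample_homography_np(np.array([2, 2]), shift=-1)        # 3x3 homography in [-1, 1] coords
homo = torch.tensor(homo, dtype=torch.float32)
inv_homo = torch.inverse(homo)

img = torch.rand(1, 1, H, W)                                   # dummy single-channel image batch
warped = inv_warp_image_batch(img, inv_homo)                   # (1, 1, H, W)
mask = compute_valid_mask([H, W], inv_homo, erosion_radius=3)  # (1, H, W) validity mask

assert warped.shape == (1, 1, H, W) and tuple(mask.shape) == (1, H, W)
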