├── LICENSE
├── README.md
├── configs
│   ├── Replica
│   │   ├── office0.yaml
│   │   ├── office1.yaml
│   │   ├── office2.yaml
│   │   ├── office3.yaml
│   │   ├── office4.yaml
│   │   ├── replica.yaml
│   │   ├── room0.yaml
│   │   ├── room1.yaml
│   │   └── room2.yaml
│   ├── ScanNet
│   │   ├── scannet.yaml
│   │   ├── scene0000.yaml
│   │   ├── scene0002.yaml
│   │   ├── scene0005.yaml
│   │   ├── scene0012.yaml
│   │   ├── scene0050.yaml
│   │   ├── scene0059.yaml
│   │   ├── scene0084.yaml
│   │   ├── scene0106.yaml
│   │   ├── scene0169.yaml
│   │   ├── scene0181.yaml
│   │   ├── scene0207.yaml
│   │   ├── scene0472.yaml
│   │   ├── scene0580.yaml
│   │   └── scene0616.yaml
│   └── df_prior.yaml
├── environment.yaml
├── get_tsdf.py
├── pretrained
│   └── low_high.pt
├── run.py
├── scripts
│   ├── download_cull_replica_mesh.sh
│   └── download_replica.sh
└── src
    ├── DF_Prior.py
    ├── Mapper.py
    ├── Tracker.py
    ├── __init__.py
    ├── __pycache__
    │   ├── DF_Prior.cpython-37.pyc
    │   ├── Mapper.cpython-37.pyc
    │   ├── Mapper.cpython-38.pyc
    │   ├── NICE_SLAM.cpython-37.pyc
    │   ├── NICE_SLAM.cpython-38.pyc
    │   ├── Tracker.cpython-37.pyc
    │   ├── Tracker.cpython-38.pyc
    │   ├── __init__.cpython-310.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── common.cpython-310.pyc
    │   ├── common.cpython-37.pyc
    │   ├── common.cpython-38.pyc
    │   ├── config.cpython-310.pyc
    │   ├── config.cpython-37.pyc
    │   ├── config.cpython-38.pyc
    │   ├── depth2pointcloud.cpython-37.pyc
    │   ├── fusion.cpython-310.pyc
    │   ├── fusion.cpython-37.pyc
    │   └── fusion.cpython-38.pyc
    ├── common.py
    ├── config.py
    ├── conv_onet
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── __init__.cpython-38.pyc
    │   │   ├── config.cpython-310.pyc
    │   │   ├── config.cpython-37.pyc
    │   │   └── config.cpython-38.pyc
    │   ├── config.py
    │   └── models
    │       ├── __init__.py
    │       ├── __pycache__
    │       │   ├── __init__.cpython-310.pyc
    │       │   ├── __init__.cpython-37.pyc
    │       │   ├── __init__.cpython-38.pyc
    │       │   ├── decoder.cpython-310.pyc
    │       │   ├── decoder.cpython-37.pyc
    │       │   └── decoder.cpython-38.pyc
    │       └── decoder.py
    ├── fusion.py
    ├── tools
    │   ├── cull_mesh.py
    │   ├── eval_ate.py
    │   ├── eval_recon.py
    │   └── evaluate_scannet.py
    └── utils
        ├── Logger.py
        ├── Mesher.py
        ├── Renderer.py
        ├── Visualizer.py
        ├── __pycache__
        │   ├── Logger.cpython-37.pyc
        │   ├── Logger.cpython-38.pyc
        │   ├── Mesher.cpython-37.pyc
        │   ├── Mesher.cpython-38.pyc
        │   ├── Renderer.cpython-37.pyc
        │   ├── Renderer.cpython-38.pyc
        │   ├── Visualizer.cpython-37.pyc
        │   ├── Visualizer.cpython-38.pyc
        │   ├── datasets.cpython-37.pyc
        │   └── datasets.cpython-38.pyc
        └── datasets.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MachinePerceptionLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Learning Neural Implicit through Volume Rendering with Attentive Depth Fusion Priors
4 |
5 | Pengchong Hu
6 | ·
7 | Zhizhong Han
8 |
9 |
10 | NeurIPS 2023
11 |
12 |
13 |
14 |
15 |
16 |
17 | Table of Contents
18 |
19 | - Installation
20 |
21 | - Dataset
22 |
23 | - Run
24 |
25 | - Evaluation
26 |
27 | - Acknowledgement
28 |
29 | - Citation
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ## Installation
42 | Please install all dependencies by following the instructions here. You can use [anaconda](https://www.anaconda.com/) to finish the installation easily.
43 |
44 | You can build a conda environment called `df-prior`. Note that Linux users need to install `libopenexr-dev` before building the environment.
45 |
46 | ```bash
47 | git clone https://github.com/MachinePerceptionLab/Attentive_DFPrior.git
48 | cd Attentive_DFPrior
49 |
50 | sudo apt-get install libopenexr-dev
51 |
52 | conda env create -f environment.yaml
53 | conda activate df-prior
54 | ```
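
As an optional quick check (not part of the original instructions), you can verify that PyTorch and Open3D import correctly inside the new environment:

```python
# Optional sanity check for the df-prior environment (assumes a CUDA-capable GPU is present).
import torch
import open3d as o3d

print(torch.__version__, 'CUDA available:', torch.cuda.is_available())
print('Open3D', o3d.__version__)
```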
55 |
56 | ## Dataset
57 |
58 | ### Replica
59 |
60 | Please download the Replica dataset generated by the authors of iMAP into the `./Datasets/Replica` folder.
61 |
62 | ```bash
63 | bash scripts/download_replica.sh # Released by authors of NICE-SLAM
64 | ```
65 |
66 | ### ScanNet
67 |
68 | Please follow the data downloading procedure on the [ScanNet](http://www.scan-net.org/) website, and extract color/depth frames from the `.sens` file using this [code](https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/reader.py).
69 |
70 |
71 | Directory structure of ScanNet:
72 |
73 | DATAROOT is `./Datasets` by default. If a sequence (`sceneXXXX_XX`) is stored elsewhere, please change the `input_folder` path in the config file or on the command line.
74 |
75 | ```
76 | DATAROOT
77 | └── scannet
78 | └── scans
79 | └── scene0000_00
80 | └── frames
81 | ├── color
82 | │ ├── 0.jpg
83 | │ ├── 1.jpg
84 | │ ├── ...
85 | │ └── ...
86 | ├── depth
87 | │ ├── 0.png
88 | │ ├── 1.png
89 | │ ├── ...
90 | │ └── ...
91 | ├── intrinsic
92 | └── pose
93 | ├── 0.txt
94 | ├── 1.txt
95 | ├── ...
96 | └── ...
97 |
98 | ```
99 |
100 |
101 | ## Run
102 |
103 | To run our code, you first need to generate the TSDF volume and corresponding bounds. We provide the generated TSDF volumes and bounds for Replica and ScanNet: replica_tsdf_volume.tar and scannet_tsdf_volume.tar.
104 |
105 | You can also generate the TSDF volume and corresponding bounds with the following commands:
106 |
107 | ```bash
108 | CUDA_VISIBLE_DEVICES=0 python get_tsdf.py configs/Replica/room0.yaml --space 1 # For Replica
109 | CUDA_VISIBLE_DEVICES=0 python get_tsdf.py configs/ScanNet/scene0050.yaml --space 10 # For ScanNet
110 | ```
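
After `get_tsdf.py` finishes, `run.py` (via `src/DF_Prior.py`) expects the volume and bounds under `{dataset}_tsdf_volume/`, named by scene id. A minimal sketch for sanity-checking the generated files, using the Replica room0 naming from the command above:

```python
# Minimal sanity check of get_tsdf.py output (Replica room0; ScanNet scenes use scene{id}_*.pt).
import torch

tsdf_volume = torch.load('replica_tsdf_volume/room0_tsdf_volume.pt', map_location='cpu')
bounds = torch.load('replica_tsdf_volume/room0_bounds.pt', map_location='cpu')

# The volume is stored as a 5-D tensor (1, 1, *grid_dims) after the reshape/permute in get_tsdf.py.
print(tsdf_volume.shape)
print(bounds)
```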
111 |
112 | You can run DF-Prior with the following commands:
113 |
114 | ```bash
115 | CUDA_VISIBLE_DEVICES=0 python -W ignore run.py configs/Replica/room0.yaml # For Replica
116 | CUDA_VISIBLE_DEVICES=0 python -W ignore run.py configs/ScanNet/scene0050.yaml # For ScanNet
117 | ```
118 |
119 | The mesh for evaluation is saved as `$OUTPUT_FOLDER/mesh/final_mesh_eval_rec.ply`, where the unseen regions are culled using all frames.
120 |
121 |
122 |
123 |
124 | ## Evaluation
125 | ### Average Trajectory Error
126 | To evaluate the average trajectory error, run the command below with the corresponding config file:
127 | ```bash
128 | python src/tools/eval_ate.py configs/Replica/room0.yaml
129 | ```
130 |
131 | ### Reconstruction Error
132 | #### Replica
133 | To evaluate the reconstruction error on Replica, first download the ground-truth Replica meshes where unseen regions have been culled.
134 | ```bash
135 | bash scripts/download_cull_replica_mesh.sh # Released by authors of NICE-SLAM
136 | ```
137 | Then run the command below. The 2D metric requires rendering 1000 depth images, which takes some time (~9 minutes). Use `-2d` to enable the 2D metric and `-3d` to enable the 3D metric.
138 | ```bash
139 | # assign any output_folder and gt mesh you like, here is just an example
140 | OUTPUT_FOLDER=output/Replica/room0
141 | GT_MESH=cull_replica_mesh/room0.ply
142 | python src/tools/eval_recon.py --rec_mesh $OUTPUT_FOLDER/mesh/final_mesh_eval_rec.ply --gt_mesh $GT_MESH -2d -3d
143 | ```
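
Before running `eval_recon.py`, it can be useful to confirm that both meshes load; a small sketch using `trimesh` (already listed in `environment.yaml`), with the example paths from above:

```python
# Quick check that the reconstructed and GT meshes load (paths from the example above).
import trimesh

rec = trimesh.load('output/Replica/room0/mesh/final_mesh_eval_rec.ply')
gt = trimesh.load('cull_replica_mesh/room0.ply')
print('rec:', len(rec.vertices), 'vertices /', len(rec.faces), 'faces')
print('gt :', len(gt.vertices), 'vertices /', len(gt.faces), 'faces')
```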
144 |
145 | We also provide code to cull a mesh given camera poses. Here we take culling the ground-truth mesh of Replica room0 as an example.
146 | ```bash
147 | python src/tools/cull_mesh.py --input_mesh Datasets/Replica/room0_mesh.ply --traj Datasets/Replica/room0/traj.txt --output_mesh cull_replica_mesh/room0.ply
148 | ```
149 | #### ScanNet
150 | To evaluate the reconstruction error on ScanNet, first download the ground-truth ScanNet meshes.zip into the `./Datasets/scannet` folder, then run the command below.
151 | ```bash
152 | python src/tools/evaluate_scannet.py configs/ScanNet/scene0050.yaml
153 | ```
154 | We also provide our reconstructed meshes on Replica and ScanNet for evaluation purposes: meshes.zip.
155 |
156 | ## Acknowledgement
157 | We adapt code from several awesome repositories, including [NICE-SLAM](https://github.com/cvg/nice-slam), [NeuralRGBD](https://github.com/dazinovic/neural-rgbd-surface-reconstruction), [tsdf-fusion](https://github.com/andyzeng/tsdf-fusion-python), [manhattan-sdf](https://github.com/zju3dv/manhattan_sdf), and [MonoSDF](https://github.com/autonomousvision/monosdf). Thanks for making the code available. We also thank [Zihan Zhu](https://zzh2000.github.io/) of [NICE-SLAM](https://github.com/cvg/nice-slam) for prompt responses to our inquiries regarding the details of their method.
158 |
159 | ## Citation
160 | If you find our code or paper useful, please cite
161 | ```bibtex
162 | @inproceedings{Hu2023LNI-ADFP,
163 | title = {Learning Neural Implicit through Volume Rendering with Attentive Depth Fusion Priors},
164 | author = {Hu, Pengchong and Han, Zhizhong},
165 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
166 | year = {2023}
167 | }
168 | ```
169 |
--------------------------------------------------------------------------------
/configs/Replica/office0.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-5.5,5.9],[-6.7,5.4],[-4.7,5.3]]
4 | marching_cubes_bound: [[-5.5,5.9],[-6.7,5.4],[-4.7,5.3]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/office0
8 | output: output/Replica/office0
9 | id: office0
--------------------------------------------------------------------------------
/configs/Replica/office1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-5.3,6.5],[-5.1,6.0],[-4.5,5.2]]
4 | marching_cubes_bound: [[-5.3,6.5],[-5.1,6.0],[-4.5,5.2]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/office1
8 | output: output/Replica/office1
9 | id: office1
--------------------------------------------------------------------------------
/configs/Replica/office2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-5.0,4.6],[-4.4,6.9],[-2.8,3.1]]
4 | marching_cubes_bound: [[-5.0,4.6],[-4.4,6.9],[-2.8,3.1]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/office2
8 | output: output/Replica/office2
9 | id: office2
--------------------------------------------------------------------------------
/configs/Replica/office3.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-6.7,5.1],[-7.5,4.9],[-2.8,3.5]]
4 | marching_cubes_bound: [[-6.7,5.1],[-7.5,4.9],[-2.8,3.5]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/office3
8 | output: output/Replica/office3
9 | id: office3
10 |
--------------------------------------------------------------------------------
/configs/Replica/office4.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-3.7,7.8],[-4.8,6.7],[-3.7,4.1]]
4 | marching_cubes_bound: [[-3.7,7.8],[-4.8,6.7],[-3.7,4.1]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/office4
8 | output: output/Replica/office4
9 | id: office4
--------------------------------------------------------------------------------
/configs/Replica/replica.yaml:
--------------------------------------------------------------------------------
1 | dataset: 'replica'
2 | meshing:
3 | eval_rec: True
4 | tracking:
5 | vis_freq: 50
6 | vis_inside_freq: 25
7 | ignore_edge_W: 100
8 | ignore_edge_H: 100
9 | seperate_LR: False
10 | const_speed_assumption: True
11 | lr: 0.001
12 | pixels: 200
13 | iters: 10
14 | gt_camera: False
15 | mapping:
16 | every_frame: 5
17 | vis_freq: 50
18 | vis_inside_freq: 30
19 | mesh_freq: 50
20 | ckpt_freq: 500
21 | keyframe_every: 50
22 | mapping_window_size: 5
23 | pixels: 1000
24 | iters_first: 1500
25 | iters: 60
26 | stage:
27 | low:
28 | mlp_lr: 0.0
29 | decoders_lr: 0.0
30 | low_lr: 0.1
31 | high_lr: 0.0
32 | color_lr: 0.0
33 | high:
34 | mlp_lr: 0.0
35 | decoders_lr: 0.0
36 | low_lr: 0.005
37 | high_lr: 0.005
38 | color_lr: 0.0
39 | color:
40 | mlp_lr: 0.005
41 | decoders_lr: 0.005
42 | low_lr: 0.005
43 | high_lr: 0.005
44 | color_lr: 0.005
45 | cam:
46 | H: 680
47 | W: 1200
48 | fx: 600.0
49 | fy: 600.0
50 | cx: 599.5
51 | cy: 339.5
52 | png_depth_scale: 6553.5 #for depth image in png format
53 | crop_edge: 0
--------------------------------------------------------------------------------
/configs/Replica/room0.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-2.9,8.9],[-3.2,5.5],[-3.5,3.3]]
4 | marching_cubes_bound: [[-2.9,8.9],[-3.2,5.5],[-3.5,3.3]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/room0
8 | output: output/Replica/room0
9 | id: room0
--------------------------------------------------------------------------------
/configs/Replica/room1.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-7.0,2.8],[-4.6,4.3],[-3.0,2.9]]
4 | marching_cubes_bound: [[-7.0,2.8],[-4.6,4.3],[-3.0,2.9]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/room1
8 | output: output/Replica/room1
9 | id: room1
--------------------------------------------------------------------------------
/configs/Replica/room2.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/Replica/replica.yaml
2 | mapping:
3 | bound: [[-4.3,9.5],[-6.7,5.2],[-6.4,4.2]]
4 | marching_cubes_bound: [[-4.3,9.5],[-6.7,5.2],[-6.4,4.2]]
5 | data:
6 | dataset: replica
7 | input_folder: Datasets/Replica/room2
8 | output: output/Replica/room2
9 | id: room2
--------------------------------------------------------------------------------
/configs/ScanNet/scannet.yaml:
--------------------------------------------------------------------------------
1 | dataset: 'scannet'
2 | tracking:
3 | vis_freq: 50
4 | vis_inside_freq: 25
5 | ignore_edge_W: 20
6 | ignore_edge_H: 20
7 | seperate_LR: False
8 | const_speed_assumption: True
9 | lr: 0.0005
10 | pixels: 1000
11 | iters: 50
12 | gt_camera: True #False
13 | mapping:
14 | every_frame: 5
15 | vis_freq: 50
16 | vis_inside_freq: 30
17 | mesh_freq: 50
18 | ckpt_freq: 500
19 | keyframe_every: 5
20 | mapping_window_size: 10
21 | pixels: 5000
22 | iters_first: 1500
23 | iters: 60
24 | cam:
25 | H: 480
26 | W: 640
27 | fx: 577.590698
28 | fy: 578.729797
29 | cx: 318.905426
30 | cy: 242.683609
31 | png_depth_scale: 1000. #for depth image in png format
32 | crop_edge: 10
--------------------------------------------------------------------------------
/configs/ScanNet/scene0000.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-2.0,11.0],[-2.0,11.5],[-2.0,5.5]]
4 | marching_cubes_bound: [[-2.0,11.0],[-2.0,11.5],[-2.0,5.5]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0000_00
8 | output: output/scannet/scans/scene0000_00
9 | id: 00
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0002.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[0.6,5.2],[0.0,5.6],[0.1,3.4]] #[[0.0,5.5],[0.0,6.0],[-0.5,3.5]]
4 | marching_cubes_bound: [[0.6,5.2],[0.0,5.6],[0.1,3.4]] #[[0.0,5.5],[0.0,6.0],[-0.5,3.5]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0002_00
8 | output: output/scannet/scans/scene0002_00
9 | id: 02
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0005.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.1,5.5],[0.4,5.4],[-0.1,2.5]] # [[-0.5,5.5],[0.0,5.5],[-0.5,3.0]]
4 | marching_cubes_bound: [[-0.1,5.5],[0.4,5.4],[-0.1,2.5]] # [[-0.5,5.5],[0.0,5.5],[-0.5,3.0]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0005_00
8 | output: output/scannet/scans/scene0005_00
9 | id: 05
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0012.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.5,5.5],[-0.5,5.5],[-0.5,3.0]]
4 | marching_cubes_bound: [[-0.5,5.5],[-0.5,5.5],[-0.5,3.0]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0012_00
8 | output: output/scannet/scans/scene0012_00
9 | id: 12
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0050.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[0.5,7.0],[0.0,4.5],[-0.5,3.0]]
4 | marching_cubes_bound: [[0.5,7.0],[0.0,4.5],[-0.5,3.0]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0050_00
8 | output: output/scannet/scans/scene0050_00
9 | id: 50
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0059.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.9,7.3],[-1.0,9.6],[-1.0,3.7]]
4 | marching_cubes_bound: [[-0.9,7.3],[-1.0,9.6],[-1.0,3.7]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0059_00
8 | output: output/scannet/scans/scene0059_00
9 | id: 59
--------------------------------------------------------------------------------
/configs/ScanNet/scene0084.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.5,3.0],[-0.5,7.5],[0.0,2.5]]
4 | marching_cubes_bound: [[-0.5,3.0],[-0.5,7.5],[0.0,2.5]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0084_00
8 | output: output/scannet/scans/scene0084_00
9 | id: 84
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0106.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-1.1,9.8],[-1.0,10.0],[-1.0,4.3]]
4 | marching_cubes_bound: [[-1.1,9.8],[-1.0,10.0],[-1.0,4.3]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0106_00
8 | output: output/scannet/scans/scene0106_00
9 | id: 106
--------------------------------------------------------------------------------
/configs/ScanNet/scene0169.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.2,9.8],[-1.0,8.5],[-1.0,3.4]]
4 | marching_cubes_bound: [[-0.2,9.8],[-1.0,8.5],[-1.0,3.4]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0169_00
8 | output: output/scannet/scans/scene0169_00
9 | id: 169
--------------------------------------------------------------------------------
/configs/ScanNet/scene0181.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-1.0,8.9],[-0.9,8.0],[-1.0,3.6]]
4 | marching_cubes_bound: [[-1.0,8.9],[-0.9,8.0],[-1.0,3.6]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0181_00
8 | output: output/scannet/scans/scene0181_00
9 | id: 181
--------------------------------------------------------------------------------
/configs/ScanNet/scene0207.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[0.3,9.9],[-1.0,8.0],[-1.0,3.8]]
4 | marching_cubes_bound: [[0.3,9.9],[-1.0,8.0],[-1.0,3.8]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0207_00
8 | output: output/scannet/scans/scene0207_00
9 | id: 207
--------------------------------------------------------------------------------
/configs/ScanNet/scene0472.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.6,9.5],[-1.5,9.5],[-1.5,3.5]]
4 | marching_cubes_bound: [[-0.6,9.5],[-1.5,9.5],[-1.5,3.5]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0472_00
8 | output: output/scannet/scans/scene0472_00
9 | id: 472
--------------------------------------------------------------------------------
/configs/ScanNet/scene0580.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[0.0,5.5],[-0.5,3.5],[0.0,4.0]]
4 | marching_cubes_bound: [[0.0,5.5],[-0.5,3.5],[0.0,4.0]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0580_00
8 | output: output/scannet/scans/scene0580_00
9 | id: 580
10 |
--------------------------------------------------------------------------------
/configs/ScanNet/scene0616.yaml:
--------------------------------------------------------------------------------
1 | inherit_from: configs/ScanNet/scannet.yaml
2 | mapping:
3 | bound: [[-0.5,6.0],[-0.5,5.0],[-0.5,3.0]]
4 | marching_cubes_bound: [[-0.5,6.0],[-0.5,5.0],[-0.5,3.0]]
5 | data:
6 | dataset: scannet
7 | input_folder: Datasets/scannet/scans/scene0616_00
8 | output: output/scannet/scans/scene0616_00
9 | id: 616
10 |
--------------------------------------------------------------------------------
/configs/df_prior.yaml:
--------------------------------------------------------------------------------
1 | sync_method: strict
2 | scale: 1
3 | verbose: True
4 | occupancy: True
5 | low_gpu_mem: True
6 | grid_len:
7 | low: 0.32
8 | high: 0.16
9 | color: 0.16
10 | bound_divisible: 0.32
11 | pretrained_decoders:
12 | low_high: pretrained/low_high.pt # one ckpt contains both low and high
13 | meshing:
14 | level_set: 0
15 | resolution: 256 # change to 512 for higher resolution geometry
16 | eval_rec: False
17 | clean_mesh: True
18 | depth_test: False
19 | clean_mesh_bound_scale: 1.02
20 | get_largest_components: False
21 | color_mesh_extraction_method: direct_point_query
22 | remove_small_geometry_threshold: 0.2
23 | tracking:
24 | ignore_edge_W: 20
25 | ignore_edge_H: 20
26 | use_color_in_tracking: True
27 | device: "cuda:0"
28 | handle_dynamic: True
29 | vis_freq: 50
30 | vis_inside_freq: 25
31 | w_color_loss: 0.5
32 | seperate_LR: False
33 | const_speed_assumption: True
34 | no_vis_on_first_frame: True
35 | gt_camera: True #False
36 | lr: 0.001
37 | pixels: 200
38 | iters: 10
39 | mapping:
40 | device: "cuda:0"
41 | color_refine: True
42 | low_iter_ratio: 0.4
43 | high_iter_ratio: 0.6
44 | every_frame: 5
45 | fix_high: True
46 | fix_color: False
47 | no_vis_on_first_frame: True
48 | no_mesh_on_first_frame: True
49 | no_log_on_first_frame: True
50 | vis_freq: 50
51 | vis_inside_freq: 25 #each iteration
52 | mesh_freq: 50
53 | ckpt_freq: 500
54 | keyframe_every: 50
55 | mapping_window_size: 5
56 | w_color_loss: 0.2
57 | frustum_feature_selection: True
58 | keyframe_selection_method: 'overlap'
59 | save_selected_keyframes_info: False
60 | lr_first_factor: 5
61 | lr_factor: 1
62 | pixels: 1000
63 | iters_first: 1500
64 | iters: 60
65 | stage:
66 | low:
67 | mlp_lr: 0.0
68 | decoders_lr: 0.0
69 | low_lr: 0.1
70 | high_lr: 0.0
71 | color_lr: 0.0
72 | high:
73 | mlp_lr: 0.005
74 | decoders_lr: 0.0
75 | low_lr: 0.005
76 | high_lr: 0.005
77 | color_lr: 0.0
78 | color:
79 | mlp_lr: 0.005
80 | decoders_lr: 0.005
81 | low_lr: 0.005
82 | high_lr: 0.005
83 | color_lr: 0.005
84 | cam:
85 | H: 680
86 | W: 1200
87 | fx: 600.0
88 | fy: 600.0
89 | cx: 599.5
90 | cy: 339.5
91 | png_depth_scale: 6553.5 #for depth image in png format
92 | crop_edge: 0
93 | rendering:
94 | N_samples: 32
95 | N_surface: 16
96 | N_importance: 0
97 | lindisp: False
98 | perturb: 0.0
99 | data:
100 | dim: 3
101 | model:
102 | c_dim: 32
103 | pos_embedding_method: fourier
104 |
--------------------------------------------------------------------------------
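Note on how `configs/df_prior.yaml` above relates to the scene configs: each scene file (e.g. `configs/Replica/room0.yaml`) sets `inherit_from` and overrides only scene-specific keys, while `run.py` and `get_tsdf.py` pass `configs/df_prior.yaml` to `config.load_config` as the default. The merge itself is implemented in `src/config.py` (not shown here), so the sketch below is only an assumption of that behaviour: a recursive dict update where more specific configs win.

```python
# Hypothetical sketch of hierarchical config loading (defaults < inherit_from < scene config).
import yaml

def merge(base: dict, override: dict) -> dict:
    """Recursively update base with override; override wins on conflicts."""
    out = dict(base)
    for key, val in override.items():
        if isinstance(val, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], val)
        else:
            out[key] = val
    return out

def load_config_sketch(path: str, default_path: str = 'configs/df_prior.yaml') -> dict:
    with open(default_path) as f:
        cfg = yaml.safe_load(f)
    with open(path) as f:
        scene_cfg = yaml.safe_load(f)
    parent = scene_cfg.get('inherit_from')
    if parent is not None:
        with open(parent) as f:
            cfg = merge(cfg, yaml.safe_load(f))
    return merge(cfg, scene_cfg)

# e.g. cfg = load_config_sketch('configs/Replica/room0.yaml')
```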
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: df-prior
2 | channels:
3 | - pytorch-nightly
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=1_llvm
11 | - blas=1.0=mkl
12 | - brotli=1.0.9=he6710b0_2
13 | - brotlipy=0.7.0=py37h27cfd23_1003
14 | - bzip2=1.0.8=h7f98852_4
15 | - ca-certificates=2023.05.30=h06a4308_0
16 | - certifi=2022.12.7=py37h06a4308_0
17 | - cffi=1.15.0=py37hd667e15_1
18 | - charset-normalizer=2.0.12=pyhd8ed1ab_0
19 | - cryptography=36.0.0=py37h9ce1e76_0
20 | - cuda-cudart=12.1.105=0
21 | - cuda-cupti=12.1.105=0
22 | - cuda-libraries=12.1.0=0
23 | - cuda-nvrtc=12.1.105=0
24 | - cuda-nvtx=12.1.105=0
25 | - cuda-opencl=12.1.105=0
26 | - cuda-runtime=12.1.0=0
27 | - cudatoolkit=11.3.1=h2bc3f7f_2
28 | - cycler=0.11.0=pyhd3eb1b0_0
29 | - dbus=1.13.18=hb2f20db_0
30 | - embree=2.17.7=ha770c72_1
31 | - expat=2.4.4=h295c915_0
32 | - ffmpeg=4.3=hf484d3e_0
33 | - fontconfig=2.13.1=h6c09931_0
34 | - fonttools=4.25.0=pyhd3eb1b0_0
35 | - freetype=2.10.4=h0708190_1
36 | - giflib=5.2.1=h7b6447c_0
37 | - glib=2.69.1=h4ff587b_1
38 | - gmp=6.2.1=h58526e2_0
39 | - gnutls=3.6.13=h85f3911_1
40 | - gst-plugins-base=1.14.0=h8213a91_2
41 | - gstreamer=1.14.0=h28cd5cc_2
42 | - icu=58.2=he6710b0_3
43 | - idna=3.3=pyhd3eb1b0_0
44 | - jpeg=9b=h024ee3a_2
45 | - kiwisolver=1.3.2=py37h295c915_0
46 | - lame=3.100=h7f98852_1001
47 | - lcms2=2.12=h3be6417_0
48 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
49 | - libcublas=12.1.0.26=0
50 | - libcufft=11.0.2.4=0
51 | - libcufile=1.6.1.9=0
52 | - libcurand=10.3.2.106=0
53 | - libcusolver=11.4.4.55=0
54 | - libcusparse=12.0.2.55=0
55 | - libffi=3.3=he6710b0_2
56 | - libgcc-ng=11.2.0=h1d223b6_14
57 | - libiconv=1.16=h516909a_0
58 | - libnpp=12.0.2.50=0
59 | - libnsl=2.0.0=h7f98852_0
60 | - libnvjitlink=12.1.105=0
61 | - libnvjpeg=12.1.0.39=0
62 | - libpng=1.6.37=h21135ba_2
63 | - libstdcxx-ng=11.2.0=he4da1e4_14
64 | - libtiff=4.2.0=h85742a9_0
65 | - libuuid=1.0.3=h7f8727e_2
66 | - libuv=1.43.0=h7f98852_0
67 | - libwebp=1.2.0=h89dd481_0
68 | - libwebp-base=1.2.0=h27cfd23_0
69 | - libxcb=1.14=h7b6447c_0
70 | - libxml2=2.9.12=h03d6c58_0
71 | - libzlib=1.2.11=h166bdaf_1014
72 | - llvm-openmp=13.0.1=he0ac6c6_1
73 | - lz4-c=1.9.3=h295c915_1
74 | - matplotlib=3.4.3=py37h06a4308_0
75 | - matplotlib-base=3.4.3=py37hbbc1b5f_0
76 | - mkl=2021.4.0=h8d4b97c_729
77 | - mkl-service=2.4.0=py37h402132d_0
78 | - mkl_fft=1.3.1=py37h3e078e5_1
79 | - mkl_random=1.2.2=py37h219a48f_0
80 | - munkres=1.1.4=py_0
81 | - ncurses=6.3=h9c3ff4c_0
82 | - nettle=3.6=he412f7d_0
83 | - ninja=1.10.2=h4bd325d_1
84 | - numpy-base=1.21.2=py37h79a1101_0
85 | - olefile=0.46=pyh9f0ad1d_1
86 | - openh264=2.1.1=h780b84a_0
87 | - openssl=1.1.1n=h166bdaf_0
88 | - pcre=8.45=h295c915_0
89 | - pip=22.0.4=pyhd8ed1ab_0
90 | - pycparser=2.21=pyhd3eb1b0_0
91 | - pyembree=0.1.6=py37h0da4684_1
92 | - pyopenssl=22.0.0=pyhd3eb1b0_0
93 | - pyparsing=3.0.4=pyhd3eb1b0_0
94 | - pyqt=5.9.2=py37h05f1152_2
95 | - pysocks=1.7.1=py37_1
96 | - python=3.7.11=h12debd9_0
97 | - python-dateutil=2.8.2=pyhd3eb1b0_0
98 | - python_abi=3.7=2_cp37m
99 | - pytorch-cuda=12.1=ha16c6d3_5
100 | - pytorch-mutex=1.0=cuda
101 | - qt=5.9.7=h5867ecd_1
102 | - readline=8.1=h46c0cb4_0
103 | - requests=2.27.1=pyhd3eb1b0_0
104 | - setuptools=61.2.0=py37h89c1867_3
105 | - sip=4.19.8=py37hf484d3e_0
106 | - six=1.16.0=pyh6c4a22f_0
107 | - sqlite=3.37.1=h4ff8645_0
108 | - tbb=2021.5.0=h4bd325d_0
109 | - tk=8.6.12=h27826a3_0
110 | - torchaudio=0.11.0=py37_cu113
111 | - tornado=6.1=py37h27cfd23_0
112 | - typing_extensions=4.3.0=py37h06a4308_0
113 | - wheel=0.37.1=pyhd8ed1ab_0
114 | - xz=5.2.5=h516909a_1
115 | - zlib=1.2.11=h166bdaf_1014
116 | - zstd=1.4.9=haebb681_0
117 | - pip:
118 | - addict==2.4.0
119 | - ansi2html==1.8.0
120 | - anyio==3.5.0
121 | - appdirs==1.4.4
122 | - argon2-cffi==21.3.0
123 | - argon2-cffi-bindings==21.2.0
124 | - argparse==1.4.0
125 | - arrow==1.2.3
126 | - attrs==21.4.0
127 | - babel==2.9.1
128 | - backcall==0.2.0
129 | - beautifulsoup4==4.10.0
130 | - bleach==4.1.0
131 | - click==8.1.7
132 | - cloudpickle==2.0.0
133 | - colorama==0.4.4
134 | - comm==0.1.4
135 | - configargparse==1.7
136 | - dash==2.14.1
137 | - dash-core-components==2.0.0
138 | - dash-html-components==2.0.0
139 | - dash-table==5.0.0
140 | - dask==2022.2.0
141 | - debugpy==1.6.0
142 | - decorator==5.1.1
143 | - defusedxml==0.7.1
144 | - deprecation==2.1.0
145 | - docopt==0.6.2
146 | - entrypoints==0.4
147 | - fastjsonschema==2.19.0
148 | - filelock==3.6.0
149 | - flask==2.2.5
150 | - freetype-py==2.3.0
151 | - fsspec==2022.2.0
152 | - gdown==4.4.0
153 | - imageio==2.16.1
154 | - imath==0.0.1
155 | - importlib-metadata==4.11.3
156 | - importlib-resources==5.6.0
157 | - inform==1.28
158 | - ipykernel==6.10.0
159 | - ipython==7.32.0
160 | - ipython-genutils==0.2.0
161 | - ipywidgets==8.1.1
162 | - itsdangerous==2.1.2
163 | - jedi==0.18.1
164 | - jinja2==3.1.1
165 | - joblib==1.1.0
166 | - json5==0.9.6
167 | - jsonschema==4.4.0
168 | - jupyter-client==7.2.0
169 | - jupyter-core==4.9.2
170 | - jupyter-packaging==0.12.0
171 | - jupyter-server==1.16.0
172 | - jupyterlab==3.3.2
173 | - jupyterlab-pygments==0.1.2
174 | - jupyterlab-server==2.12.0
175 | - jupyterlab-widgets==3.0.9
176 | - llvmlite==0.39.1
177 | - locket==0.2.1
178 | - mako==1.2.4
179 | - markupsafe==2.1.1
180 | - mathutils==2.81.2
181 | - matplotlib-inline==0.1.3
182 | - mistune==0.8.4
183 | - nbclassic==0.3.7
184 | - nbclient==0.5.13
185 | - nbconvert==6.4.5
186 | - nbformat==5.7.0
187 | - nest-asyncio==1.5.4
188 | - networkx==2.6.3
189 | - notebook==6.4.10
190 | - notebook-shim==0.1.0
191 | - numba==0.56.4
192 | - numpy==1.21.5
193 | - open3d==0.17.0
194 | - opencv-python==4.5.5.64
195 | - openexr==1.3.7
196 | - packaging==21.3
197 | - pandas==1.3.5
198 | - pandocfilters==1.5.0
199 | - parso==0.8.3
200 | - partd==1.2.0
201 | - pexpect==4.8.0
202 | - pickleshare==0.7.5
203 | - pillow==9.5.0
204 | - platformdirs==3.2.0
205 | - plotly==5.18.0
206 | - plyfile==0.9
207 | - prometheus-client==0.13.1
208 | - prompt-toolkit==3.0.28
209 | - psutil==5.9.0
210 | - ptyprocess==0.7.0
211 | - pycuda==2022.1
212 | - pyglet==2.0.6
213 | - pygments==2.11.2
214 | - pyopengl==3.1.0
215 | - pypng==0.20220715.0
216 | - pyquaternion==0.9.9
217 | - pyrender==0.1.45
218 | - pyrsistent==0.18.1
219 | - pytools==2022.1.12
220 | - pytz==2022.1
221 | - pywavelets==1.3.0
222 | - pyyaml==6.0
223 | - pyzmq==22.3.0
224 | - quantiphy==2.19
225 | - retrying==1.3.4
226 | - rtree==0.9.7
227 | - scikit-image==0.19.2
228 | - scikit-learn==1.0.2
229 | - scipy==1.7.3
230 | - seaborn==0.11.2
231 | - send2trash==1.8.0
232 | - sniffio==1.2.0
233 | - soupsieve==2.3.1
234 | - tenacity==8.2.3
235 | - terminado==0.13.3
236 | - testpath==0.6.0
237 | - threadpoolctl==3.1.0
238 | - tifffile==2021.11.2
239 | - tomlkit==0.10.1
240 | - toolz==0.11.2
241 | - torch==1.11.0+cu113
242 | - torchvision==0.12.0+cu113
243 | - tqdm==4.63.1
244 | - traitlets==5.1.1
245 | - trimesh==3.10.7
246 | - tvm==1.0.0
247 | - typing-extensions==4.5.0
248 | - urllib3==1.26.9
249 | - wcwidth==0.2.5
250 | - webencodings==0.5.1
251 | - websocket-client==1.3.2
252 | - werkzeug==2.2.3
253 | - widgetsnbextension==4.0.9
254 | - zipp==3.7.0
255 |
--------------------------------------------------------------------------------
/get_tsdf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 |
6 | from src import config
7 | import src.fusion as fusion
8 | import open3d as o3d
9 | from src.utils.datasets import get_dataset
10 |
11 |
12 | def update_cam(cfg):
13 | """
14 | Update the camera intrinsics according to pre-processing config,
15 | such as resize or edge crop.
16 | """
17 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][
18 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy']
19 | # resize the input images to crop_size (variable name used in lietorch)
20 | if 'crop_size' in cfg['cam']:
21 | crop_size = cfg['cam']['crop_size']
22 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][
23 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy']
24 | sx = crop_size[1] / W
25 | sy = crop_size[0] / H
26 | fx = sx*fx
27 | fy = sy*fy
28 | cx = sx*cx
29 | cy = sy*cy
30 | W = crop_size[1]
31 | H = crop_size[0]
32 |
33 |
34 | # cropping will change H, W, cx, cy, so they need to be updated here
35 | if cfg['cam']['crop_edge'] > 0:
36 | H -= cfg['cam']['crop_edge']*2
37 | W -= cfg['cam']['crop_edge']*2
38 | cx -= cfg['cam']['crop_edge']
39 | cy -= cfg['cam']['crop_edge']
40 |
41 | return H, W, fx, fy, cx, cy
42 |
43 |
44 | def init_tsdf_volume(cfg, args, space=10):
45 | """
46 | Initialize the TSDF volume.
47 | Get the TSDF volume and bounds.
48 |
49 | space: frame stride; every space-th frame is integrated into the TSDF volume.
50 |
51 | """
52 | # scale the bound if there is a global scaling factor
53 | scale = cfg['scale']
54 | bound = torch.from_numpy(
55 | np.array(cfg['mapping']['bound'])*scale)
56 | bound_divisible = cfg['grid_len']['bound_divisible']
57 | # enlarge the bound a bit so that it is divisible by bound_divisible
58 | bound[:, 1] = (((bound[:, 1]-bound[:, 0]) /
59 | bound_divisible).int()+1)*bound_divisible+bound[:, 0]
60 |
61 | # TSDF volume
62 | H, W, fx, fy, cx, cy = update_cam(cfg)
63 | intrinsic = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy).intrinsic_matrix # (3, 3)
64 |
65 | print("Initializing voxel volume...")
66 | vol_bnds = np.array(bound)
67 | tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=4.0/256)
68 |
69 | frame_reader = get_dataset(cfg, args, scale)
70 |
71 | for idx in range(len(frame_reader)):
72 | if idx % space != 0: continue
73 | print(f'frame: {idx}')
74 | _, gt_color, gt_depth, gt_c2w = frame_reader[idx]
75 |
76 | # convert to open3d camera pose
77 | c2w = gt_c2w.cpu().numpy()
78 |
79 | if np.isfinite(c2w).any():
80 | c2w[:3, 1] *= -1.0
81 | c2w[:3, 2] *= -1.0
82 |
83 | depth = gt_depth.cpu().numpy() # (368, 496, 3)
84 | color = gt_color.cpu().numpy()
85 | depth = depth.astype(np.float32)
86 | color = np.array((color * 255).astype(np.uint8))
87 | tsdf_vol.integrate(color, depth, intrinsic, c2w, obs_weight=1.)
88 |
89 | print('Getting TSDF volume')
90 | tsdf_volume, _, bounds = tsdf_vol.get_volume()
91 |
92 | print("Getting mesh")
93 | verts, faces, norms, colors = tsdf_vol.get_mesh()
94 |
95 | tsdf_volume = torch.tensor(tsdf_volume)
96 | tsdf_volume = tsdf_volume.reshape(1, 1, tsdf_volume.shape[0], tsdf_volume.shape[1], tsdf_volume.shape[2])
97 | tsdf_volume = tsdf_volume.permute(0, 1, 4, 3, 2)
98 |
99 | return tsdf_volume, bounds, verts, faces, norms, colors
100 |
101 | def get_tsdf():
102 | """
103 | Save the TSDF volume and bounds to a file.
104 |
105 | """
106 | parser = argparse.ArgumentParser(
107 | description='Arguments for running the code.'
108 | )
109 | parser.add_argument('config', type=str, help='Path to config file.')
110 | parser.add_argument('--input_folder', type=str,
111 | help='input folder; this has higher priority and overwrites the one in the config file')
112 | parser.add_argument('--output', type=str,
113 | help='output folder; this has higher priority and overwrites the one in the config file')
114 | parser.add_argument('--space', type=int, default=10, help='frame stride; every space-th frame is integrated into the TSDF volume')
115 |
116 | args = parser.parse_args()
117 | cfg = config.load_config(args.config, 'configs/df_prior.yaml')
118 |
119 | dataset = cfg['data']['dataset']
120 | scene_id = cfg['data']['id']
121 |
122 |
123 | path = f'{dataset}_tsdf_volume'
124 | os.makedirs(path, exist_ok=True)
125 |
126 | tsdf_volume, bounds, verts, faces, norms, colors = init_tsdf_volume(cfg, args, space=args.space)
127 |
128 | if dataset == 'scannet':
129 | tsdf_volume_path = os.path.join(path, f'scene{scene_id}_tsdf_volume.pt')
130 | bounds_path = os.path.join(path, f'scene{scene_id}_bounds.pt')
131 |
132 | elif dataset == 'replica':
133 | tsdf_volume_path = os.path.join(path, f'{scene_id}_tsdf_volume.pt')
134 | bounds_path = os.path.join(path, f'{scene_id}_bounds.pt')
135 |
136 |
137 | torch.save(tsdf_volume, tsdf_volume_path)
138 | torch.save(bounds, bounds_path)
139 |
140 |
141 |
142 | if __name__ == '__main__':
143 | get_tsdf()
144 |
--------------------------------------------------------------------------------
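A small worked example of the bound handling in `init_tsdf_volume` above (the same rounding appears in `load_bound` in `src/DF_Prior.py`): with the room0 bound from `configs/Replica/room0.yaml` and `bound_divisible: 0.32` from `configs/df_prior.yaml`, each upper bound is pushed up so that the side length becomes a multiple of 0.32.

```python
# Worked example of the bound enlargement used in get_tsdf.py / DF_Prior.load_bound.
import numpy as np
import torch

bound = torch.from_numpy(np.array([[-2.9, 8.9], [-3.2, 5.5], [-3.5, 3.3]]))  # room0
bound_divisible = 0.32

# Round each side length up to the next multiple of bound_divisible.
bound[:, 1] = (((bound[:, 1] - bound[:, 0]) / bound_divisible).int() + 1) * bound_divisible + bound[:, 0]
print(bound)  # x: -2.9 + 37 * 0.32 = 8.94, y: -3.2 + 28 * 0.32 = 5.76, z: -3.5 + 22 * 0.32 = 3.54
```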
/pretrained/low_high.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/pretrained/low_high.pt
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from src import config
8 | from src.DF_Prior import DF_Prior
9 |
10 |
11 | def setup_seed(seed):
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed_all(seed)
14 | np.random.seed(seed)
15 | random.seed(seed)
16 | torch.backends.cudnn.deterministic = True
17 |
18 |
19 | def main():
20 |
21 | parser = argparse.ArgumentParser(
22 | description='Arguments for running the code.'
23 | )
24 | parser.add_argument('config', type=str, help='Path to config file.')
25 | parser.add_argument('--input_folder', type=str,
26 | help='input folder; this has higher priority and overwrites the one in the config file')
27 | parser.add_argument('--output', type=str,
28 | help='output folder; this has higher priority and overwrites the one in the config file')
29 | args = parser.parse_args()
30 | cfg = config.load_config(args.config, 'configs/df_prior.yaml')
31 |
32 | slam = DF_Prior(cfg, args)
33 | slam.run()
34 |
35 |
36 | if __name__ == '__main__':
37 | main()
38 |
--------------------------------------------------------------------------------
/scripts/download_cull_replica_mesh.sh:
--------------------------------------------------------------------------------
1 | # you can also download the cull_replica_mesh.zip manually through
2 | # link: https://caiyun.139.com/m/i?1A5CvGNQmSDsI password: Dhhp
3 | wget https://cvg-data.inf.ethz.ch/nice-slam/cull_replica_mesh.zip
4 | unzip cull_replica_mesh.zip
--------------------------------------------------------------------------------
/scripts/download_replica.sh:
--------------------------------------------------------------------------------
1 | mkdir -p Datasets
2 | cd Datasets
3 | # you can also download the Replica.zip manually through
4 | # link: https://caiyun.139.com/m/i?1A5Ch5C3abNiL password: v3fY (the zip is split into smaller zips because of the size limitation of caiyun)
5 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip
6 | unzip Replica.zip
7 |
--------------------------------------------------------------------------------
/src/DF_Prior.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import torch.multiprocessing
7 | import torch.multiprocessing as mp
8 |
9 | from src import config
10 | from src.Mapper import Mapper
11 | from src.Tracker import Tracker
12 | from src.utils.datasets import get_dataset
13 | from src.utils.Logger import Logger
14 | from src.utils.Mesher import Mesher
15 | from src.utils.Renderer import Renderer
16 |
17 | # import src.fusion as fusion
18 | # import open3d as o3d
19 |
20 | torch.multiprocessing.set_sharing_strategy('file_system')
21 |
22 |
23 | class DF_Prior():
24 | """
25 | DF_Prior main class.
26 | Mainly allocates shared resources and dispatches the mapping and tracking processes.
27 | """
28 |
29 | def __init__(self, cfg, args):
30 |
31 | self.cfg = cfg
32 | self.args = args
33 |
34 | self.occupancy = cfg['occupancy']
35 | self.low_gpu_mem = cfg['low_gpu_mem']
36 | self.verbose = cfg['verbose']
37 | self.dataset = cfg['dataset']
38 | if args.output is None:
39 | self.output = cfg['data']['output']
40 | else:
41 | self.output = args.output
42 | self.ckptsdir = os.path.join(self.output, 'ckpts')
43 | os.makedirs(self.output, exist_ok=True)
44 | os.makedirs(self.ckptsdir, exist_ok=True)
45 | os.makedirs(f'{self.output}/mesh', exist_ok=True)
46 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = cfg['cam']['H'], cfg['cam'][
47 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy']
48 | self.update_cam()
49 |
50 | model = config.get_model(cfg)
51 | self.shared_decoders = model
52 |
53 | self.scale = cfg['scale']
54 |
55 | self.load_bound(cfg)
56 | self.load_pretrain(cfg)
57 | self.grid_init(cfg)
58 |
59 | # need to use spawn
60 | try:
61 | mp.set_start_method('spawn', force=True)
62 | except RuntimeError:
63 | pass
64 |
65 | self.frame_reader = get_dataset(cfg, args, self.scale)
66 | self.n_img = len(self.frame_reader)
67 | self.estimate_c2w_list = torch.zeros((self.n_img, 4, 4))
68 | self.estimate_c2w_list.share_memory_()
69 |
70 | dataset = self.cfg['data']['dataset']
71 | scene_id = self.cfg['data']['id']
72 | self.scene_id = scene_id
73 | print(scene_id)
74 | # load tsdf grid
75 | if dataset == 'scannet':
76 | self.tsdf_volume_shared = torch.load(f'scannet_tsdf_volume/scene{scene_id}_tsdf_volume.pt')
77 | elif dataset == 'replica':
78 | self.tsdf_volume_shared = torch.load(f'replica_tsdf_volume/{scene_id}_tsdf_volume.pt')
79 | self.tsdf_volume_shared = self.tsdf_volume_shared.to(self.cfg['mapping']['device'])
80 | self.tsdf_volume_shared.share_memory_()
81 |
82 | # load tsdf grid bound
83 | if dataset == 'scannet':
84 | self.tsdf_bnds = torch.load(f'scannet_tsdf_volume/scene{scene_id}_bounds.pt')
85 | elif dataset == 'replica':
86 | self.tsdf_bnds = torch.load(f'replica_tsdf_volume/{scene_id}_bounds.pt')
87 | self.tsdf_bnds = torch.tensor(self.tsdf_bnds).to(self.cfg['mapping']['device'])
88 | self.tsdf_bnds.share_memory_()
89 |
90 | self.vol_bnds = self.tsdf_bnds
91 | self.vol_bnds.share_memory_()
92 |
93 | self.gt_c2w_list = torch.zeros((self.n_img, 4, 4))
94 | self.gt_c2w_list.share_memory_()
95 | self.idx = torch.zeros((1)).int()
96 | self.idx.share_memory_()
97 | self.mapping_first_frame = torch.zeros((1)).int()
98 | self.mapping_first_frame.share_memory_()
99 | # the id of the newest frame Mapper is processing
100 | self.mapping_idx = torch.zeros((1)).int()
101 | self.mapping_idx.share_memory_()
102 | self.mapping_cnt = torch.zeros((1)).int() # counter for mapping
103 | self.mapping_cnt.share_memory_()
104 | for key, val in self.shared_c.items():
105 | val = val.to(self.cfg['mapping']['device'])
106 | val.share_memory_()
107 | self.shared_c[key] = val
108 | self.shared_decoders = self.shared_decoders.to(
109 | self.cfg['mapping']['device'])
110 | self.shared_decoders.share_memory()
111 | self.renderer = Renderer(cfg, args, self)
112 | self.mesher = Mesher(cfg, args, self)
113 | self.logger = Logger(cfg, args, self)
114 | self.mapper = Mapper(cfg, args, self)
115 | self.tracker = Tracker(cfg, args, self)
116 | self.print_output_desc()
117 |
118 |
119 | def print_output_desc(self):
120 | print(f"INFO: The output folder is {self.output}")
121 | if 'Demo' in self.output:
122 | print(
123 | f"INFO: The GT, generated and residual depth/color images can be found under " +
124 | f"{self.output}/vis/")
125 | else:
126 | print(
127 | f"INFO: The GT, generated and residual depth/color images can be found under " +
128 | f"{self.output}/tracking_vis/ and {self.output}/mapping_vis/")
129 | print(f"INFO: The mesh can be found under {self.output}/mesh/")
130 | print(f"INFO: The checkpoint can be found under {self.output}/ckpt/")
131 |
132 |
133 | def update_cam(self):
134 | """
135 | Update the camera intrinsics according to pre-processing config,
136 | such as resize or edge crop.
137 | """
138 | # resize the input images to crop_size (variable name used in lietorch)
139 | if 'crop_size' in self.cfg['cam']:
140 | crop_size = self.cfg['cam']['crop_size']
141 | sx = crop_size[1] / self.W
142 | sy = crop_size[0] / self.H
143 | self.fx = sx*self.fx
144 | self.fy = sy*self.fy
145 | self.cx = sx*self.cx
146 | self.cy = sy*self.cy
147 | self.W = crop_size[1]
148 | self.H = crop_size[0]
149 |
150 | # cropping will change H, W, cx, cy, so they need to be updated here
151 | if self.cfg['cam']['crop_edge'] > 0:
152 | self.H -= self.cfg['cam']['crop_edge']*2
153 | self.W -= self.cfg['cam']['crop_edge']*2
154 | self.cx -= self.cfg['cam']['crop_edge']
155 | self.cy -= self.cfg['cam']['crop_edge']
156 |
157 |
158 | # def init_tsdf_volume(self, cfg, args):
159 | # # scale the bound if there is a global scaling factor
160 | # scale = cfg['scale']
161 | # bound = torch.from_numpy(
162 | # np.array(cfg['mapping']['bound'])*scale)
163 | # bound_divisible = cfg['grid_len']['bound_divisible']
164 | # # enlarge the bound a bit to allow it divisible by bound_divisible
165 | # bound[:, 1] = (((bound[:, 1]-bound[:, 0]) /
166 | # bound_divisible).int()+1)*bound_divisible+bound[:, 0]
167 | # intrinsic = o3d.camera.PinholeCameraIntrinsic(self.W, self.H, self.fx, self.fy, self.cx, self.cy).intrinsic_matrix # (3, 3)
168 |
169 | # print("Initializing voxel volume...")
170 | # vol_bnds = np.array(bound)
171 | # tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=4/256)
172 |
173 |
174 | # return tsdf_vol, intrinsic, vol_bnds
175 |
176 |
177 | def load_bound(self, cfg):
178 | """
179 | Pass the scene bound parameters to different decoders and self.
180 |
181 | Args:
182 | cfg (dict): parsed config dict.
183 | """
184 | # scale the bound if there is a global scaling factor
185 | self.bound = torch.from_numpy(
186 | np.array(cfg['mapping']['bound'])*self.scale)
187 | bound_divisible = cfg['grid_len']['bound_divisible']
188 | # enlarge the bound a bit so that it is divisible by bound_divisible
189 | self.bound[:, 1] = (((self.bound[:, 1]-self.bound[:, 0]) /
190 | bound_divisible).int()+1)*bound_divisible+self.bound[:, 0]
191 | self.shared_decoders.bound = self.bound
192 | self.shared_decoders.low_decoder.bound = self.bound
193 | self.shared_decoders.high_decoder.bound = self.bound
194 | self.shared_decoders.color_decoder.bound = self.bound
195 |
196 |
197 | def load_pretrain(self, cfg):
198 | """
199 | Load parameters of pretrained ConvOnet checkpoints to the decoders.
200 |
201 | Args:
202 | cfg (dict): parsed config dict
203 | """
204 |
205 | ckpt = torch.load(cfg['pretrained_decoders']['low_high'],
206 | map_location=cfg['mapping']['device'])
207 | low_dict = {}
208 | high_dict = {}
209 | for key, val in ckpt['model'].items():
210 | if ('decoder' in key) and ('encoder' not in key):
211 | if 'coarse' in key:
212 | key = key[8+7:]
213 | low_dict[key] = val
214 | elif 'fine' in key:
215 | key = key[8+5:]
216 | high_dict[key] = val
217 | self.shared_decoders.low_decoder.load_state_dict(low_dict)
218 | self.shared_decoders.high_decoder.load_state_dict(high_dict)
219 |
220 |
221 | def grid_init(self, cfg):
222 | """
223 | Initialize the hierarchical feature grids.
224 |
225 | Args:
226 | cfg (dict): parsed config dict.
227 | """
228 |
229 | low_grid_len = cfg['grid_len']['low']
230 | self.low_grid_len = low_grid_len
231 | high_grid_len = cfg['grid_len']['high']
232 | self.high_grid_len = high_grid_len
233 | color_grid_len = cfg['grid_len']['color']
234 | self.color_grid_len = color_grid_len
235 |
236 | c = {}
237 | c_dim = cfg['model']['c_dim']
238 | xyz_len = self.bound[:, 1]-self.bound[:, 0]
239 |
240 |
241 |
242 | low_key = 'grid_low'
243 | low_val_shape = list(map(int, (xyz_len/low_grid_len).tolist()))
244 | low_val_shape[0], low_val_shape[2] = low_val_shape[2], low_val_shape[0]
245 | self.low_val_shape = low_val_shape
246 | val_shape = [1, c_dim, *low_val_shape]
247 | low_val = torch.zeros(val_shape).normal_(mean=0, std=0.01)
248 | c[low_key] = low_val
249 |
250 | high_key = 'grid_high'
251 | high_val_shape = list(map(int, (xyz_len/high_grid_len).tolist()))
252 | high_val_shape[0], high_val_shape[2] = high_val_shape[2], high_val_shape[0]
253 | self.high_val_shape = high_val_shape
254 | val_shape = [1, c_dim, *high_val_shape]
255 | high_val = torch.zeros(val_shape).normal_(mean=0, std=0.0001)
256 | c[high_key] = high_val
257 |
258 | color_key = 'grid_color'
259 | color_val_shape = list(map(int, (xyz_len/color_grid_len).tolist()))
260 | color_val_shape[0], color_val_shape[2] = color_val_shape[2], color_val_shape[0]
261 | self.color_val_shape = color_val_shape
262 | val_shape = [1, c_dim, *color_val_shape]
263 | color_val = torch.zeros(val_shape).normal_(mean=0, std=0.01)
264 | c[color_key] = color_val
265 |
266 | self.shared_c = c
267 |
268 |
269 | def tracking(self, rank):
270 | """
271 | Tracking Thread.
272 |
273 | Args:
274 | rank (int): Thread ID.
275 | """
276 |
277 | # should wait until the mapping of first frame is finished
278 | while (1):
279 | if self.mapping_first_frame[0] == 1:
280 | break
281 | time.sleep(1)
282 |
283 | self.tracker.run()
284 |
285 |
286 | def mapping(self, rank):
287 | """
288 | Mapping Thread. (updates low, high, and color level)
289 |
290 | Args:
291 | rank (int): Thread ID.
292 | """
293 |
294 | self.mapper.run()
295 |
296 |
297 | def run(self):
298 | """
299 | Dispatch Threads.
300 | """
301 |
302 | processes = []
303 | for rank in range(2):
304 | if rank == 0:
305 | p = mp.Process(target=self.tracking, args=(rank, ))
306 | elif rank == 1:
307 | p = mp.Process(target=self.mapping, args=(rank, ))
308 | p.start()
309 | processes.append(p)
310 | for p in processes:
311 | p.join()
312 |
313 |
314 | # This part is required by torch.multiprocessing
315 | if __name__ == '__main__':
316 | pass
317 |
--------------------------------------------------------------------------------
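Continuing the room0 example, `grid_init` above sizes the low/high/color feature grids from the enlarged bound and the `grid_len` values in `configs/df_prior.yaml`. A small sketch that reproduces that shape computation (it repeats the same operations as `load_bound` and `grid_init`, so it prints the shapes the code would actually allocate):

```python
# Reproduces the DF_Prior.load_bound + grid_init shape computation for Replica room0.
import numpy as np
import torch

c_dim = 32                                       # model.c_dim in configs/df_prior.yaml
grid_len = {'low': 0.32, 'high': 0.16, 'color': 0.16}
bound_divisible = 0.32

bound = torch.from_numpy(np.array([[-2.9, 8.9], [-3.2, 5.5], [-3.5, 3.3]]))  # room0
bound[:, 1] = (((bound[:, 1] - bound[:, 0]) / bound_divisible).int() + 1) * bound_divisible + bound[:, 0]
xyz_len = bound[:, 1] - bound[:, 0]

for name, length in grid_len.items():
    shape = list(map(int, (xyz_len / length).tolist()))
    shape[0], shape[2] = shape[2], shape[0]      # grid_init swaps x and z, storing grids as (z, y, x)
    print(name, [1, c_dim, *shape])              # full feature-grid shape: [1, c_dim, Dz, Dy, Dx]
```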
/src/Tracker.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from colorama import Fore, Style
8 | from torch.autograd import Variable
9 | from torch.utils.data import DataLoader
10 | from tqdm import tqdm
11 |
12 | from src.common import (get_camera_from_tensor, get_samples,
13 | get_tensor_from_camera)
14 | from src.utils.datasets import get_dataset
15 | from src.utils.Visualizer import Visualizer
16 |
17 |
18 |
19 |
20 | class Tracker(object):
21 | def __init__(self, cfg, args, slam
22 | ):
23 | self.cfg = cfg
24 | self.args = args
25 |
26 | self.scale = cfg['scale']
27 | self.occupancy = cfg['occupancy']
28 | self.sync_method = cfg['sync_method']
29 |
30 | self.idx = slam.idx
31 | self.bound = slam.bound
32 | self.mesher = slam.mesher
33 | self.output = slam.output
34 | self.verbose = slam.verbose
35 | self.shared_c = slam.shared_c
36 | self.renderer = slam.renderer
37 | self.gt_c2w_list = slam.gt_c2w_list
38 | self.low_gpu_mem = slam.low_gpu_mem
39 | self.mapping_idx = slam.mapping_idx
40 | self.mapping_cnt = slam.mapping_cnt
41 | self.shared_decoders = slam.shared_decoders
42 | self.estimate_c2w_list = slam.estimate_c2w_list
43 | with torch.no_grad():
44 | self.tsdf_volume_shared = slam.tsdf_volume_shared
45 | self.tsdf_bnds = slam.tsdf_bnds
46 |
47 |
48 | self.cam_lr = cfg['tracking']['lr']
49 | self.device = cfg['tracking']['device']
50 | self.num_cam_iters = cfg['tracking']['iters']
51 | self.gt_camera = cfg['tracking']['gt_camera']
52 | self.tracking_pixels = cfg['tracking']['pixels']
53 | self.seperate_LR = cfg['tracking']['seperate_LR']
54 | self.w_color_loss = cfg['tracking']['w_color_loss']
55 | self.ignore_edge_W = cfg['tracking']['ignore_edge_W']
56 | self.ignore_edge_H = cfg['tracking']['ignore_edge_H']
57 | self.handle_dynamic = cfg['tracking']['handle_dynamic']
58 | self.use_color_in_tracking = cfg['tracking']['use_color_in_tracking']
59 | self.const_speed_assumption = cfg['tracking']['const_speed_assumption']
60 |
61 | self.every_frame = cfg['mapping']['every_frame']
62 | self.no_vis_on_first_frame = cfg['mapping']['no_vis_on_first_frame'] # ori mapping
63 |
64 | self.prev_mapping_idx = -1
65 | self.frame_reader = get_dataset(
66 | cfg, args, self.scale, device=self.device)
67 | self.n_img = len(self.frame_reader)
68 | self.frame_loader = DataLoader(
69 | self.frame_reader, batch_size=1, shuffle=False, num_workers=1)
70 | self.visualizer = Visualizer(freq=cfg['tracking']['vis_freq'], inside_freq=cfg['tracking']['vis_inside_freq'],
71 | vis_dir=os.path.join(self.output, 'vis' if 'Demo' in self.output else 'tracking_vis'),
72 | renderer=self.renderer, verbose=self.verbose, device=self.device)
73 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = slam.H, slam.W, slam.fx, slam.fy, slam.cx, slam.cy
74 |
75 | def optimize_cam_in_batch(self, camera_tensor, gt_color, gt_depth, batch_size, optimizer, tsdf_volume):
76 | """
77 | Do one iteration of camera pose optimization: sample pixels, render depth/color, compute the loss, and backpropagate.
78 |
79 | Args:
80 | camera_tensor (tensor): camera tensor.
81 | gt_color (tensor): ground truth color image of the current frame.
82 | gt_depth (tensor): ground truth depth image of the current frame.
83 | batch_size (int): batch size, number of sampling rays.
84 | optimizer (torch.optim): camera optimizer.
85 | tsdf_volume (tensor): tsdf volume
86 |
87 | Returns:
88 | loss (float): The value of loss.
89 | """
90 | device = self.device
91 | H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy
92 | optimizer.zero_grad()
93 | c2w = get_camera_from_tensor(camera_tensor)
94 | tsdf_bnds = self.tsdf_bnds.to(device)
95 | Wedge = self.ignore_edge_W
96 | Hedge = self.ignore_edge_H
97 | batch_rays_o, batch_rays_d, batch_gt_depth, batch_gt_color = get_samples(
98 | Hedge, H-Hedge, Wedge, W-Wedge, batch_size, H, W, fx, fy, cx, cy, c2w, gt_depth, gt_color, self.device)
99 |
100 | # pre-filter rays whose depth values fall outside the bounding box
101 | with torch.no_grad():
102 | det_rays_o = batch_rays_o.clone().detach().unsqueeze(-1) # (N, 3, 1)
103 | det_rays_d = batch_rays_d.clone().detach().unsqueeze(-1) # (N, 3, 1)
104 | t = (self.bound.unsqueeze(0).to(device)-det_rays_o)/det_rays_d
105 | t, _ = torch.min(torch.max(t, dim=2)[0], dim=1)
106 | inside_mask = t >= batch_gt_depth
107 | batch_rays_d = batch_rays_d[inside_mask]
108 | batch_rays_o = batch_rays_o[inside_mask]
109 | batch_gt_depth = batch_gt_depth[inside_mask]
110 | batch_gt_color = batch_gt_color[inside_mask]
111 |
112 | ret = self.renderer.render_batch_ray(
113 | self.c, self.decoders, batch_rays_d, batch_rays_o, self.device, tsdf_volume, tsdf_bnds, stage='color', gt_depth=batch_gt_depth) #color
114 | depth, uncertainty, color, _ = ret
115 |
116 | uncertainty = uncertainty.detach()
117 | if self.handle_dynamic:
118 | tmp = torch.abs(batch_gt_depth-depth)/torch.sqrt(uncertainty+1e-10)
119 | mask = (tmp < 10*tmp.median()) & (batch_gt_depth > 0)
120 | else:
121 | mask = batch_gt_depth > 0
122 |
123 | loss = (torch.abs(batch_gt_depth-depth) /
124 | torch.sqrt(uncertainty+1e-10))[mask].sum()
125 |
126 | if self.use_color_in_tracking:
127 | color_loss = torch.abs(
128 | batch_gt_color - color)[mask].sum()
129 | loss += self.w_color_loss*color_loss
130 |
131 | loss.backward(retain_graph=False)
132 | optimizer.step()
133 | optimizer.zero_grad()
134 | return loss.item()
135 |
136 | def update_para_from_mapping(self):
137 | """
138 | Update the parameters of scene representation from the mapping thread.
139 |
140 | """
141 | if self.mapping_idx[0] != self.prev_mapping_idx:
142 | if self.verbose:
143 | print('Tracking: update the parameters from mapping')
144 | self.decoders = copy.deepcopy(self.shared_decoders).to(self.device)
145 | for key, val in self.shared_c.items():
146 | val = val.clone().to(self.device)
147 | self.c[key] = val
148 | self.prev_mapping_idx = self.mapping_idx[0].clone()
149 |
150 | def run(self):
151 | device = self.device
152 | tsdf_volume = self.tsdf_volume_shared
153 | tsdf_bnds = self.tsdf_bnds.to(device)
154 |
155 | self.c = {}
156 | if self.verbose:
157 | pbar = self.frame_loader
158 | else:
159 | pbar = tqdm(self.frame_loader)
160 |
161 | for idx, gt_color, gt_depth, gt_c2w in pbar:
162 | if not self.verbose:
163 | pbar.set_description(f"Tracking Frame {idx[0]}")
164 |
165 | idx = idx[0]
166 | gt_depth = gt_depth[0]
167 | gt_color = gt_color[0]
168 | gt_c2w = gt_c2w[0]
169 |
170 | if self.sync_method == 'strict':
171 |                 # strict: wait until the previous frame has been mapped before tracking
172 |                 # mapping is triggered every self.every_frame frames
173 | if idx > 0 and (idx % self.every_frame == 1 or self.every_frame == 1):
174 | while self.mapping_idx[0] != idx-1:
175 | time.sleep(0.1)
176 | pre_c2w = self.estimate_c2w_list[idx-1].to(device)
177 | elif self.sync_method == 'loose':
178 |                 # the mapping idx may lag behind the tracking idx, within the bound of
179 | # [-self.every_frame-self.every_frame//2, -self.every_frame+self.every_frame//2]
180 | while self.mapping_idx[0] < idx-self.every_frame-self.every_frame//2:
181 | time.sleep(0.1)
182 | elif self.sync_method == 'free':
183 |                 # purely parallel; meshing/visualization may cause load imbalance
184 | pass
185 |
186 | self.update_para_from_mapping()
187 |
188 | if self.verbose:
189 | print(Fore.MAGENTA)
190 | print("Tracking Frame ", idx.item())
191 | print(Style.RESET_ALL)
192 |
193 |
194 |
195 | if idx == 0 or self.gt_camera:
196 | c2w = gt_c2w
197 | if not self.no_vis_on_first_frame:
198 | self.visualizer.vis(
199 | idx, 0, gt_depth, gt_color, c2w, self.c, self.decoders, tsdf_volume, tsdf_bnds)
200 |
201 | else:
202 | gt_camera_tensor = get_tensor_from_camera(gt_c2w)
203 | if self.const_speed_assumption and idx-2 >= 0:
204 | pre_c2w = pre_c2w.float()
205 | delta = pre_c2w@self.estimate_c2w_list[idx-2].to(
206 | device).float().inverse()
207 | estimated_new_cam_c2w = delta@pre_c2w
208 | else:
209 | estimated_new_cam_c2w = pre_c2w
210 |
211 | camera_tensor = get_tensor_from_camera(
212 | estimated_new_cam_c2w.detach())
213 | if self.seperate_LR:
214 | camera_tensor = camera_tensor.to(device).detach()
215 | T = camera_tensor[-3:]
216 | quad = camera_tensor[:4]
217 | cam_para_list_quad = [quad]
218 | quad = Variable(quad, requires_grad=True)
219 | T = Variable(T, requires_grad=True)
220 | camera_tensor = torch.cat([quad, T], 0)
221 | cam_para_list_T = [T]
222 | cam_para_list_quad = [quad]
223 | optimizer_camera = torch.optim.Adam([{'params': cam_para_list_T, 'lr': self.cam_lr},
224 | {'params': cam_para_list_quad, 'lr': self.cam_lr*0.2}])
225 | else:
226 | camera_tensor = Variable(
227 | camera_tensor.to(device), requires_grad=True)
228 | cam_para_list = [camera_tensor]
229 | optimizer_camera = torch.optim.Adam(
230 | cam_para_list, lr=self.cam_lr)
231 |
232 | initial_loss_camera_tensor = torch.abs(
233 | gt_camera_tensor.to(device)-camera_tensor).mean().item()
234 | candidate_cam_tensor = None
235 | current_min_loss = 10000000000.
236 |
237 |
238 |
239 | for cam_iter in range(self.num_cam_iters):
240 | if self.seperate_LR:
241 | camera_tensor = torch.cat([quad, T], 0).to(self.device)
242 |
243 | self.visualizer.vis(
244 | idx, cam_iter, gt_depth, gt_color, camera_tensor, self.c, self.decoders, tsdf_volume, tsdf_bnds)
245 |
246 | loss = self.optimize_cam_in_batch(
247 | camera_tensor, gt_color, gt_depth, self.tracking_pixels, optimizer_camera, tsdf_volume)
248 |
249 | if cam_iter == 0:
250 | initial_loss = loss
251 |
252 | loss_camera_tensor = torch.abs(
253 | gt_camera_tensor.to(device)-camera_tensor).mean().item()
254 | if self.verbose:
255 | if cam_iter == self.num_cam_iters-1:
256 | print(
257 | f'Re-rendering loss: {initial_loss:.2f}->{loss:.2f} ' +
258 | f'camera tensor error: {initial_loss_camera_tensor:.4f}->{loss_camera_tensor:.4f}')
259 | if loss < current_min_loss:
260 | current_min_loss = loss
261 | candidate_cam_tensor = camera_tensor.clone().detach()
262 | bottom = torch.from_numpy(np.array([0, 0, 0, 1.]).reshape(
263 | [1, 4])).type(torch.float32).to(self.device)
264 | c2w = get_camera_from_tensor(
265 | candidate_cam_tensor.clone().detach())
266 | c2w = torch.cat([c2w, bottom], dim=0)
267 |
268 |
269 | self.estimate_c2w_list[idx] = c2w.clone().cpu()
270 | self.gt_c2w_list[idx] = gt_c2w.clone().cpu()
271 | pre_c2w = c2w.clone()
272 | self.idx[0] = idx
273 | if self.low_gpu_mem:
274 | torch.cuda.empty_cache()
275 |
276 |
277 |
--------------------------------------------------------------------------------
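For reference, the objective assembled in `optimize_cam_in_batch` above is an uncertainty-weighted L1 depth term plus an optional L1 color term weighted by `self.w_color_loss`. Below is a minimal, self-contained sketch of just that loss on random stand-in tensors; it illustrates the formula and is not code from this repository.

import torch

def tracking_loss(gt_depth, depth, uncertainty, gt_color, color,
                  w_color=0.5, handle_dynamic=True):
    # uncertainty-weighted absolute depth residual
    res = torch.abs(gt_depth - depth) / torch.sqrt(uncertainty + 1e-10)
    if handle_dynamic:
        # drop pixels whose residual is far above the median (likely dynamic objects / outliers)
        mask = (res < 10 * res.median()) & (gt_depth > 0)
    else:
        mask = gt_depth > 0
    loss = res[mask].sum()
    # optional photometric term, as when use_color_in_tracking is enabled
    loss = loss + w_color * torch.abs(gt_color - color)[mask].sum()
    return loss

# toy call with 200 randomly sampled rays
N = 200
print(tracking_loss(torch.rand(N), torch.rand(N), torch.rand(N) + 1e-3,
                    torch.rand(N, 3), torch.rand(N, 3)).item())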
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__init__.py
--------------------------------------------------------------------------------
/src/__pycache__/DF_Prior.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/DF_Prior.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/Mapper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Mapper.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/Mapper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Mapper.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/NICE_SLAM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/NICE_SLAM.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/NICE_SLAM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/NICE_SLAM.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/Tracker.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Tracker.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/Tracker.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Tracker.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/common.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/common.cpython-310.pyc
--------------------------------------------------------------------------------
/src/__pycache__/common.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/common.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/common.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/common.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/src/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/depth2pointcloud.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/depth2pointcloud.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/fusion.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/fusion.cpython-310.pyc
--------------------------------------------------------------------------------
/src/__pycache__/fusion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/fusion.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/fusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/fusion.cpython-38.pyc
--------------------------------------------------------------------------------
/src/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 |
8 | def as_intrinsics_matrix(intrinsics):
9 | """
10 | Get matrix representation of intrinsics.
11 |
12 | """
13 | K = np.eye(3)
14 | K[0, 0] = intrinsics[0]
15 | K[1, 1] = intrinsics[1]
16 | K[0, 2] = intrinsics[2]
17 | K[1, 2] = intrinsics[3]
18 | return K
19 |
20 |
21 | def sample_pdf(bins, weights, N_samples, det=False, device='cuda:0'):
22 | """
23 | Hierarchical sampling in NeRF paper (section 5.2).
24 |
25 | """
26 | # Get pdf - probability density function
27 | weights = weights + 1e-5 # prevent nans
28 | pdf = weights / torch.sum(weights, -1, keepdim=True)
29 | cdf = torch.cumsum(pdf, -1)
30 | # (batch, len(bins))
31 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
32 |
33 | # Take uniform samples
34 | if det:
35 | u = torch.linspace(0., 1., steps=N_samples)
36 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
37 | else:
38 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
39 |
40 | u = u.to(device)
41 | # Invert CDF
42 | u = u.contiguous()
43 | try:
44 | # this should work fine with the provided environment.yaml
45 | inds = torch.searchsorted(cdf, u, right=True)
46 | except:
47 | # for lower version torch that does not have torch.searchsorted,
48 | # you need to manually install from
49 | # https://github.com/aliutkus/torchsearchsorted
50 | from torchsearchsorted import searchsorted
51 | inds = searchsorted(cdf, u, side='right')
52 | below = torch.max(torch.zeros_like(inds-1), inds-1)
53 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
54 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
55 |
56 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
57 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
58 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
59 |
60 | denom = (cdf_g[..., 1]-cdf_g[..., 0])
61 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
62 | t = (u-cdf_g[..., 0])/denom
63 | samples = bins_g[..., 0] + t * (bins_g[..., 1]-bins_g[..., 0])
64 |
65 | return samples
66 |
67 |
68 | def random_select(l, k):
69 | """
70 |     Randomly select up to k values from 0..l-1.
71 |
72 | """
73 | return list(np.random.permutation(np.array(range(l)))[:min(l, k)])
74 |
75 |
76 | def get_rays_from_uv(i, j, c2w, H, W, fx, fy, cx, cy, device):
77 | """
78 | Get corresponding rays from input uv.
79 |
80 | """
81 | if isinstance(c2w, np.ndarray):
82 | c2w = torch.from_numpy(c2w).to(device)
83 |
84 | dirs = torch.stack(
85 | [(i-cx)/fx, -(j-cy)/fy, -torch.ones_like(i)], -1).to(device)
86 | dirs = dirs.reshape(-1, 1, 3)
87 | # Rotate ray directions from camera frame to the world frame
88 | # dot product, equals to: [c2w.dot(dir) for dir in dirs]
89 | rays_d = torch.sum(dirs * c2w[:3, :3], -1)
90 | rays_o = c2w[:3, -1].expand(rays_d.shape)
91 | return rays_o, rays_d
92 |
93 |
94 | def select_uv(i, j, n, depth, color, device='cuda:0'):
95 | """
96 | Select n uv from dense uv.
97 |
98 | """
99 | i = i.reshape(-1)
100 | j = j.reshape(-1)
101 | indices = torch.randint(i.shape[0], (n,), device=device)
102 |     indices = indices.clamp(0, i.shape[0]-1)  # torch.randint already stays in range; clamp defensively
103 | i = i[indices] # (n)
104 | j = j[indices] # (n)
105 | depth = depth.reshape(-1)
106 | color = color.reshape(-1, 3)
107 | depth = depth[indices] # (n)
108 | color = color[indices] # (n,3)
109 | return i, j, depth, color
110 |
111 |
112 | def get_sample_uv(H0, H1, W0, W1, n, depth, color, device='cuda:0'):
113 | """
114 | Sample n uv coordinates from an image region H0..H1, W0..W1
115 |
116 | """
117 | depth = depth[H0:H1, W0:W1]
118 | color = color[H0:H1, W0:W1]
119 | i, j = torch.meshgrid(torch.linspace(
120 | W0, W1-1, W1-W0).to(device), torch.linspace(H0, H1-1, H1-H0).to(device))
121 | i = i.t() # transpose
122 | j = j.t()
123 | i, j, depth, color = select_uv(i, j, n, depth, color, device=device)
124 | return i, j, depth, color
125 |
126 |
127 | def get_samples(H0, H1, W0, W1, n, H, W, fx, fy, cx, cy, c2w, depth, color, device):
128 | """
129 | Get n rays from the image region H0..H1, W0..W1.
130 | c2w is its camera pose and depth/color is the corresponding image tensor.
131 |
132 | """
133 | i, j, sample_depth, sample_color = get_sample_uv(
134 | H0, H1, W0, W1, n, depth, color, device=device)
135 | rays_o, rays_d = get_rays_from_uv(i, j, c2w, H, W, fx, fy, cx, cy, device)
136 | return rays_o, rays_d, sample_depth, sample_color
137 |
138 |
139 | def quad2rotation(quad):
140 | """
141 |     Convert quaternions to rotation matrices in batch. All operations are in PyTorch, so gradients can propagate.
142 |
143 | Args:
144 | quad (tensor, batch_size*4): quaternion.
145 |
146 | Returns:
147 | rot_mat (tensor, batch_size*3*3): rotation.
148 | """
149 | bs = quad.shape[0]
150 | qr, qi, qj, qk = quad[:, 0], quad[:, 1], quad[:, 2], quad[:, 3]
151 | two_s = 2.0 / (quad * quad).sum(-1)
152 | rot_mat = torch.zeros(bs, 3, 3).to(quad.get_device())
153 | rot_mat[:, 0, 0] = 1 - two_s * (qj ** 2 + qk ** 2)
154 | rot_mat[:, 0, 1] = two_s * (qi * qj - qk * qr)
155 | rot_mat[:, 0, 2] = two_s * (qi * qk + qj * qr)
156 | rot_mat[:, 1, 0] = two_s * (qi * qj + qk * qr)
157 | rot_mat[:, 1, 1] = 1 - two_s * (qi ** 2 + qk ** 2)
158 | rot_mat[:, 1, 2] = two_s * (qj * qk - qi * qr)
159 | rot_mat[:, 2, 0] = two_s * (qi * qk - qj * qr)
160 | rot_mat[:, 2, 1] = two_s * (qj * qk + qi * qr)
161 | rot_mat[:, 2, 2] = 1 - two_s * (qi ** 2 + qj ** 2)
162 | return rot_mat
163 |
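# Note on quad2rotation above: quaternions are expected in (w, x, y, z) order
# (qr, qi, qj, qk), and two_s = 2 / |q|^2 keeps the formula valid for
# unnormalized quaternions. Sanity check: quad = [[1., 0., 0., 0.]] yields the
# identity matrix. The .to(quad.get_device()) call effectively assumes `quad`
# lives on a GPU, since get_device() returns -1 for CPU tensors.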
164 |
165 | def get_camera_from_tensor(inputs):
166 | """
167 | Convert quaternion and translation to transformation matrix.
168 |
169 | """
170 | N = len(inputs.shape)
171 | if N == 1:
172 | inputs = inputs.unsqueeze(0)
173 | quad, T = inputs[:, :4], inputs[:, 4:]
174 | R = quad2rotation(quad)
175 | RT = torch.cat([R, T[:, :, None]], 2)
176 | if N == 1:
177 | RT = RT[0]
178 | return RT
179 |
180 |
181 | def get_tensor_from_camera(RT, Tquad=False):
182 | """
183 | Convert transformation matrix to quaternion and translation.
184 |
185 | """
186 | gpu_id = -1
187 | if type(RT) == torch.Tensor:
188 | if RT.get_device() != -1:
189 |             gpu_id = RT.get_device()  # record the source GPU before moving to CPU
190 |             RT = RT.detach().cpu()
191 | RT = RT.numpy()
192 | from mathutils import Matrix
193 | R, T = RT[:3, :3], RT[:3, 3]
194 | rot = Matrix(R)
195 | quad = rot.to_quaternion()
196 | if Tquad:
197 | tensor = np.concatenate([T, quad], 0)
198 | else:
199 | tensor = np.concatenate([quad, T], 0)
200 | tensor = torch.from_numpy(tensor).float()
201 | if gpu_id != -1:
202 | tensor = tensor.to(gpu_id)
203 | return tensor
204 |
205 |
206 | def raw2outputs_nerf_color(raw, z_vals, rays_d, occupancy=False, device='cuda:0'):
207 | """
208 | Transforms model's predictions to semantically meaningful values.
209 |
210 | Args:
211 | raw (tensor, N_rays*N_samples*4): prediction from model.
212 | z_vals (tensor, N_rays*N_samples): integration time.
213 | rays_d (tensor, N_rays*3): direction of each ray.
214 | occupancy (bool, optional): occupancy or volume density. Defaults to False.
215 | device (str, optional): device. Defaults to 'cuda:0'.
216 |
217 | Returns:
218 | depth_map (tensor, N_rays): estimated distance to object.
219 | depth_var (tensor, N_rays): depth variance/uncertainty.
220 | rgb_map (tensor, N_rays*3): estimated RGB color of a ray.
221 | weights (tensor, N_rays*N_samples): weights assigned to each sampled color.
222 | """
223 |
224 | def raw2alpha(raw, dists, act_fn=F.relu): return 1. - \
225 | torch.exp(-act_fn(raw)*dists)
226 | dists = z_vals[..., 1:] - z_vals[..., :-1]
227 | dists = dists.float()
228 | dists = torch.cat([dists, torch.Tensor([1e10]).float().to(
229 | device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
230 |
231 | # different ray angle corresponds to different unit length
232 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
233 | rgb = raw[..., :-1]
234 | if occupancy:
235 |
236 |         raw[..., 3] = torch.sigmoid(10 * raw[..., -1])  # map raw occupancy to (0, 1) with a scaled sigmoid
237 | alpha = raw[..., -1]
238 |
239 | alpha_theta = 0
240 | else:
241 | # original nerf, volume density
242 | alpha = raw2alpha(raw[..., -1], dists) # (N_rays, N_samples)
243 |
244 | weights = alpha.float() * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)).to(
245 | device).float(), (1.-alpha + 1e-10).float()], -1).float(), -1)[:, :-1]
246 |
247 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # (N_rays, 3)
248 | depth_map = torch.sum(weights * z_vals, -1) # (N_rays)
249 | tmp = (z_vals-depth_map.unsqueeze(-1)) # (N_rays, N_samples)
250 | depth_var = torch.sum(weights*tmp*tmp, dim=1) # (N_rays)
251 | return depth_map, depth_var, rgb_map, weights
252 |
253 |
254 | def get_rays(H, W, fx, fy, cx, cy, c2w, device):
255 | """
256 | Get rays for a whole image.
257 |
258 | """
259 | if isinstance(c2w, np.ndarray):
260 | c2w = torch.from_numpy(c2w)
261 | # pytorch's meshgrid has indexing='ij'
262 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))
263 | i = i.t() # transpose
264 | j = j.t()
265 | dirs = torch.stack(
266 | [(i-cx)/fx, -(j-cy)/fy, -torch.ones_like(i)], -1).to(device)
267 | dirs = dirs.reshape(H, W, 1, 3)
268 | # Rotate ray directions from camera frame to the world frame
269 | # dot product, equals to: [c2w.dot(dir) for dir in dirs]
270 | rays_d = torch.sum(dirs * c2w[:3, :3], -1)
271 | rays_o = c2w[:3, -1].expand(rays_d.shape)
272 | return rays_o, rays_d
273 |
274 |
275 | def normalize_3d_coordinate(p, bound):
276 | """
277 | Normalize coordinate to [-1, 1], corresponds to the bounding box given.
278 |
279 | Args:
280 | p (tensor, N*3): coordinate.
281 | bound (tensor, 3*2): the scene bound.
282 |
283 | Returns:
284 | p (tensor, N*3): normalized coordinate.
285 | """
286 | p = p.reshape(-1, 3)
287 | p[:, 0] = ((p[:, 0]-bound[0, 0])/(bound[0, 1]-bound[0, 0]))*2-1.0
288 | p[:, 1] = ((p[:, 1]-bound[1, 0])/(bound[1, 1]-bound[1, 0]))*2-1.0
289 | p[:, 2] = ((p[:, 2]-bound[2, 0])/(bound[2, 1]-bound[2, 0]))*2-1.0
290 | return p
291 |
--------------------------------------------------------------------------------
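As a quick illustration of the hierarchical (inverse-CDF) sampling helper in `common.py`, the sketch below draws deterministic samples from a toy piecewise-constant PDF on the CPU. It assumes the repository root is on PYTHONPATH so that `src.common` is importable; the bin edges and weights are made up for the example.

import torch
from src.common import sample_pdf

bins = torch.linspace(0.0, 1.0, steps=6).unsqueeze(0)    # (1, 6) bin edges for one ray
weights = torch.tensor([[0.05, 0.1, 0.6, 0.2, 0.05]])    # (1, 5) weights, most mass in the middle bin

# evenly spaced u values inverted through the CDF (det=True)
samples = sample_pdf(bins, weights, N_samples=8, det=True, device='cpu')
print(samples.shape)   # torch.Size([1, 8]); most samples fall inside the heavy bin [0.4, 0.6]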
/src/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from src import conv_onet
3 |
4 |
5 | method_dict = {
6 | 'conv_onet': conv_onet
7 | }
8 |
9 |
10 | def load_config(path, default_path=None):
11 | """
12 | Loads config file.
13 |
14 | Args:
15 | path (str): path to config file.
16 | default_path (str, optional): whether to use default path. Defaults to None.
17 |
18 | Returns:
19 | cfg (dict): config dict.
20 |
21 | """
22 | # load configuration from file itself
23 | with open(path, 'r') as f:
24 | cfg_special = yaml.full_load(f)
25 |
26 | # check if we should inherit from a config
27 | inherit_from = cfg_special.get('inherit_from')
28 |
29 | # if yes, load this config first as default
30 | # if no, use the default_path
31 | if inherit_from is not None:
32 | cfg = load_config(inherit_from, default_path)
33 | elif default_path is not None:
34 | with open(default_path, 'r') as f:
35 | cfg = yaml.full_load(f)
36 | else:
37 | cfg = dict()
38 |
39 | # include main configuration
40 | update_recursive(cfg, cfg_special)
41 |
42 | return cfg
43 |
44 |
45 | def update_recursive(dict1, dict2):
46 | """
47 | Update two config dictionaries recursively.
48 |
49 | Args:
50 | dict1 (dict): first dictionary to be updated.
51 | dict2 (dict): second dictionary which entries should be used.
52 | """
53 | for k, v in dict2.items():
54 | if k not in dict1:
55 | dict1[k] = dict()
56 | if isinstance(v, dict):
57 | update_recursive(dict1[k], v)
58 | else:
59 | dict1[k] = v
60 |
61 |
62 | # Models
63 | def get_model(cfg):
64 | """
65 | Returns the model instance.
66 |
67 | Args:
68 | cfg (dict): config dictionary.
69 |
70 |
71 | Returns:
72 | model (nn.module): network model.
73 | """
74 |
75 | method = 'conv_onet'
76 | model = method_dict[method].config.get_model(cfg)
77 |
78 | return model
79 |
--------------------------------------------------------------------------------
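The `inherit_from` mechanism above relies on `update_recursive` to let a scene config override only a few nested keys of its parent. A tiny standalone illustration with toy dictionaries (not actual config contents), assuming the repository root is on PYTHONPATH and its dependencies are installed:

from src.config import update_recursive

defaults = {'mapping': {'every_frame': 5, 'iters': 60}, 'tracking': {'lr': 1e-3}}
scene = {'mapping': {'iters': 100}}   # a scene file overrides a single nested key

update_recursive(defaults, scene)
print(defaults['mapping'])    # {'every_frame': 5, 'iters': 100}  -- untouched keys survive
print(defaults['tracking'])   # {'lr': 0.001}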
/src/conv_onet/__init__.py:
--------------------------------------------------------------------------------
1 | from src.conv_onet import (
2 | config, models
3 | )
4 |
5 | __all__ = [
6 | config, models
7 | ]
8 |
--------------------------------------------------------------------------------
/src/conv_onet/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/conv_onet/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/conv_onet/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/conv_onet/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/src/conv_onet/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/src/conv_onet/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/src/conv_onet/config.py:
--------------------------------------------------------------------------------
1 | from src.conv_onet import models
2 |
3 |
4 | def get_model(cfg):
5 | """
6 | Return the network model.
7 |
8 | Args:
9 | cfg (dict): imported yaml config.
10 |
11 | Returns:
12 | decoder (nn.module): the network model.
13 | """
14 |
15 | dim = cfg['data']['dim']
16 | low_grid_len = cfg['grid_len']['low']
17 | high_grid_len = cfg['grid_len']['high']
18 | color_grid_len = cfg['grid_len']['color']
19 | c_dim = cfg['model']['c_dim'] # feature dimensions
20 | pos_embedding_method = cfg['model']['pos_embedding_method']
21 |
22 | decoder = models.decoder_dict['dfprior'](
23 | dim=dim, c_dim=c_dim,
24 | low_grid_len=low_grid_len, high_grid_len=high_grid_len,
25 | color_grid_len=color_grid_len, pos_embedding_method=pos_embedding_method)
26 |
27 | return decoder
28 |
--------------------------------------------------------------------------------
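`get_model` only reads a handful of config keys, so a minimal dictionary is enough to build the decoder in isolation. The values below are illustrative placeholders rather than the contents of any particular YAML file in `configs/`:

from src.conv_onet.config import get_model

cfg = {
    'data':     {'dim': 3},
    'grid_len': {'low': 0.32, 'high': 0.16, 'color': 0.16},
    'model':    {'c_dim': 32, 'pos_embedding_method': 'fourier'},
}
decoder = get_model(cfg)   # the DF decoder: low/high/color MLPs plus the TSDF-attention MLP
print(sum(p.numel() for p in decoder.parameters()))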
/src/conv_onet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from src.conv_onet.models import decoder
2 |
3 | # Decoder dictionary
4 | decoder_dict = {
5 | 'dfprior': decoder.DF
6 | }
--------------------------------------------------------------------------------
/src/conv_onet/models/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/conv_onet/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/conv_onet/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/conv_onet/models/__pycache__/decoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/decoder.cpython-310.pyc
--------------------------------------------------------------------------------
/src/conv_onet/models/__pycache__/decoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/decoder.cpython-37.pyc
--------------------------------------------------------------------------------
/src/conv_onet/models/__pycache__/decoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/decoder.cpython-38.pyc
--------------------------------------------------------------------------------
/src/conv_onet/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from src.common import normalize_3d_coordinate
5 |
6 |
7 | class GaussianFourierFeatureTransform(torch.nn.Module):
8 | """
9 | Modified based on the implementation of Gaussian Fourier feature mapping.
10 |
11 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
12 | https://arxiv.org/abs/2006.10739
13 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
14 |
15 | """
16 |
17 | def __init__(self, num_input_channels, mapping_size=93, scale=25, learnable=True):
18 | super().__init__()
19 |
20 | if learnable:
21 | self._B = nn.Parameter(torch.randn(
22 | (num_input_channels, mapping_size)) * scale)
23 | else:
24 | self._B = torch.randn((num_input_channels, mapping_size)) * scale
25 |
26 | def forward(self, x):
27 | x = x.squeeze(0)
28 | assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(x.dim())
29 | x = x @ self._B.to(x.device)
30 | return torch.sin(x)
31 |
32 |
33 | class Nerf_positional_embedding(torch.nn.Module):
34 | """
35 | Nerf positional embedding.
36 |
37 | """
38 |
39 | def __init__(self, multires, log_sampling=True):
40 | super().__init__()
41 | self.log_sampling = log_sampling
42 | self.include_input = True
43 | self.periodic_fns = [torch.sin, torch.cos]
44 | self.max_freq_log2 = multires-1
45 | self.num_freqs = multires
46 | self.max_freq = self.max_freq_log2
47 | self.N_freqs = self.num_freqs
48 |
49 | def forward(self, x):
50 | x = x.squeeze(0)
51 | assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(
52 | x.dim())
53 |
54 | if self.log_sampling:
55 | freq_bands = 2.**torch.linspace(0.,
56 | self.max_freq, steps=self.N_freqs)
57 | else:
58 | freq_bands = torch.linspace(
59 | 2.**0., 2.**self.max_freq, steps=self.N_freqs)
60 | output = []
61 | if self.include_input:
62 | output.append(x)
63 | for freq in freq_bands:
64 | for p_fn in self.periodic_fns:
65 | output.append(p_fn(x * freq))
66 | ret = torch.cat(output, dim=1)
67 | return ret
68 |
69 |
70 | class DenseLayer(nn.Linear):
71 | def __init__(self, in_dim: int, out_dim: int, activation: str = "relu", *args, **kwargs) -> None:
72 | self.activation = activation
73 | super().__init__(in_dim, out_dim, *args, **kwargs)
74 |
75 | def reset_parameters(self) -> None:
76 | torch.nn.init.xavier_uniform_(
77 | self.weight, gain=torch.nn.init.calculate_gain(self.activation))
78 | if self.bias is not None:
79 | torch.nn.init.zeros_(self.bias)
80 |
81 |
82 | class Same(nn.Module):
83 | def __init__(self):
84 | super().__init__()
85 |
86 | def forward(self, x):
87 | x = x.squeeze(0)
88 | return x
89 |
90 |
91 | class MLP(nn.Module):
92 | """
93 |     Decoder. Point coordinates are not only used to sample the feature grids, but also serve as MLP input.
94 |
95 | Args:
96 | name (str): name of this decoder.
97 | dim (int): input dimension.
98 | c_dim (int): feature dimension.
99 | hidden_size (int): hidden size of Decoder network.
100 | n_blocks (int): number of layers.
101 | leaky (bool): whether to use leaky ReLUs.
102 | sample_mode (str): sampling feature strategy, bilinear|nearest.
103 | color (bool): whether or not to output color.
104 | skips (list): list of layers to have skip connections.
105 | grid_len (float): voxel length of its corresponding feature grid.
106 | pos_embedding_method (str): positional embedding method.
107 |         concat_feature (bool): whether to fetch the low-level feature and concatenate it with the current feature.
108 | """
109 |
110 | def __init__(self, name='', dim=3, c_dim=128,
111 | hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear',
112 | color=False, skips=[2], grid_len=0.16, pos_embedding_method='fourier', concat_feature=False):
113 | super().__init__()
114 | self.name = name
115 | self.color = color
116 | self.no_grad_feature = False
117 | self.c_dim = c_dim
118 | self.grid_len = grid_len
119 | self.concat_feature = concat_feature
120 | self.n_blocks = n_blocks
121 | self.skips = skips
122 |
123 | if c_dim != 0:
124 | self.fc_c = nn.ModuleList([
125 | nn.Linear(c_dim, hidden_size) for i in range(n_blocks)
126 | ])
127 |
128 | if pos_embedding_method == 'fourier':
129 | embedding_size = 93
130 | self.embedder = GaussianFourierFeatureTransform(
131 | dim, mapping_size=embedding_size, scale=25)
132 | elif pos_embedding_method == 'same':
133 | embedding_size = 3
134 | self.embedder = Same()
135 | elif pos_embedding_method == 'nerf':
136 | if 'color' in name:
137 | multires = 10
138 | self.embedder = Nerf_positional_embedding(
139 | multires, log_sampling=True)
140 | else:
141 | multires = 5
142 | self.embedder = Nerf_positional_embedding(
143 | multires, log_sampling=False)
144 | embedding_size = multires*6+3
145 | elif pos_embedding_method == 'fc_relu':
146 | embedding_size = 93
147 | self.embedder = DenseLayer(dim, embedding_size, activation='relu')
148 |
149 | self.pts_linears = nn.ModuleList(
150 | [DenseLayer(embedding_size, hidden_size, activation="relu")] +
151 | [DenseLayer(hidden_size, hidden_size, activation="relu") if i not in self.skips
152 | else DenseLayer(hidden_size + embedding_size, hidden_size, activation="relu") for i in range(n_blocks-1)])
153 |
154 | if self.color:
155 | self.output_linear = DenseLayer(
156 | hidden_size, 4, activation="linear")
157 | else:
158 | self.output_linear = DenseLayer(
159 | hidden_size, 1, activation="linear")
160 |
161 | if not leaky:
162 | self.actvn = F.relu
163 | else:
164 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
165 |
166 | self.sample_mode = sample_mode
167 |
168 | def sample_grid_feature(self, p, c):
169 | p_nor = normalize_3d_coordinate(p.clone(), self.bound)
170 | p_nor = p_nor.unsqueeze(0)
171 | vgrid = p_nor[:, :, None, None].float()
172 |         # actually trilinear interpolation when mode == 'bilinear'
173 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True,
174 | mode=self.sample_mode).squeeze(-1).squeeze(-1)
175 | return c
176 |
177 | def forward(self, p, c_grid=None):
178 | if self.c_dim != 0:
179 | c = self.sample_grid_feature(
180 | p, c_grid['grid_' + self.name]).transpose(1, 2).squeeze(0)
181 |
182 | if self.concat_feature:
183 |                 # only happens for the high decoder: fetch the low-level feature and concatenate it with the current feature
184 | with torch.no_grad():
185 | c_low = self.sample_grid_feature(
186 | p, c_grid['grid_low']).transpose(1, 2).squeeze(0)
187 | c = torch.cat([c, c_low], dim=1)
188 |
189 | p = p.float()
190 |
191 | embedded_pts = self.embedder(p)
192 | h = embedded_pts
193 | for i, l in enumerate(self.pts_linears):
194 | h = self.pts_linears[i](h)
195 | h = F.relu(h)
196 | if self.c_dim != 0:
197 | h = h + self.fc_c[i](c)
198 | if i in self.skips:
199 | h = torch.cat([embedded_pts, h], -1)
200 | out = self.output_linear(h)
201 | if not self.color:
202 | out = out.squeeze(-1)
203 | return out
204 |
205 |
206 | class mlp_tsdf(nn.Module):
207 | """
208 | Attention-based MLP.
209 |
210 | """
211 |
212 | def __init__(self):
213 | super().__init__()
214 |
215 | self.no_grad_feature = False
216 | self.sample_mode = 'bilinear'
217 |
218 | self.pts_linears = nn.ModuleList(
219 | [DenseLayer(2, 64, activation="relu")] +
220 | [DenseLayer(64, 128, activation="relu")] +
221 | [DenseLayer(128, 128, activation="relu")] +
222 | [DenseLayer(128, 64, activation="relu")])
223 |
224 | self.output_linear = DenseLayer(
225 | 64, 2, activation="linear") #linear
226 |
227 | self.softmax = nn.Softmax(dim=1)
228 | self.sigmoid = nn.Sigmoid()
229 |
230 | def sample_grid_tsdf(self, p, tsdf_volume, tsdf_bnds, device='cuda:0'):
231 | p_nor = normalize_3d_coordinate(p.clone(), tsdf_bnds)
232 | p_nor = p_nor.unsqueeze(0)
233 | vgrid = p_nor[:, :, None, None].float()
234 |         # actually trilinear interpolation when mode == 'bilinear'
235 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True,
236 | mode='bilinear').squeeze(-1).squeeze(-1)
237 |
238 | return tsdf_value
239 |
240 | def forward(self, p, occ, tsdf_volume, tsdf_bnds, **kwargs):
241 | tsdf_val = self.sample_grid_tsdf(p, tsdf_volume, tsdf_bnds, device='cuda:0')
242 | tsdf_val = tsdf_val.squeeze(0)
243 |
244 |         tsdf_val = 1. - (tsdf_val + 1.) / 2.  # flip TSDF from [-1, 1] to [0, 1]
245 | tsdf_val = torch.clamp(tsdf_val, 0.0, 1.0)
246 | occ = occ.reshape(tsdf_val.shape)
247 |         inv_tsdf = -0.1 * torch.log((1 / (tsdf_val + 1e-8)) - 1 + 1e-7)  # scaled inverse sigmoid of the flipped TSDF
248 | inv_tsdf = torch.clamp(inv_tsdf, -100.0, 100.0)
249 | input = torch.cat([occ, inv_tsdf], dim=0)
250 | h = input.t()
251 | for i, l in enumerate(self.pts_linears):
252 | h = self.pts_linears[i](h)
253 | h = F.relu(h)
254 | weight = self.output_linear(h)
255 | weight = self.softmax(weight)
256 | out = weight.mul(input.t()).sum(dim=1)
257 |
258 | return out, weight[:, 1]
259 |
260 |
261 |
262 | class DF(nn.Module):
263 | """
264 |     DF decoder: hierarchical low/high/color MLP decoders fused with a TSDF prior via an attention MLP.
265 |
266 | Args:
267 | dim (int): input dimension.
268 | c_dim (int): feature dimension.
269 | low_grid_len (float): voxel length in low grid.
270 | high_grid_len (float): voxel length in high grid.
271 | color_grid_len (float): voxel length in color grid.
272 | hidden_size (int): hidden size of decoder network
273 | pos_embedding_method (str): positional embedding method.
274 | """
275 |
276 | def __init__(self, dim=3, c_dim=32,
277 | low_grid_len=0.16, high_grid_len=0.16,
278 | color_grid_len=0.16, hidden_size=32, pos_embedding_method='fourier'):
279 | super().__init__()
280 |
281 |
282 | self.low_decoder = MLP(name='low', dim=dim, c_dim=c_dim, color=False,
283 | skips=[2], n_blocks=5, hidden_size=hidden_size,
284 | grid_len=low_grid_len, pos_embedding_method=pos_embedding_method)
285 | self.high_decoder = MLP(name='high', dim=dim, c_dim=c_dim*2, color=False,
286 | skips=[2], n_blocks=5, hidden_size=hidden_size,
287 | grid_len=high_grid_len, concat_feature=True, pos_embedding_method=pos_embedding_method)
288 | self.color_decoder = MLP(name='color', dim=dim, c_dim=c_dim, color=True,
289 | skips=[2], n_blocks=5, hidden_size=hidden_size,
290 | grid_len=color_grid_len, pos_embedding_method=pos_embedding_method)
291 |
292 | self.mlp = mlp_tsdf()
293 |
294 |
295 | def sample_grid_tsdf(self, p, tsdf_volume, tsdf_bnds, device='cuda:0'):
296 |
297 | p_nor = normalize_3d_coordinate(p.clone(), tsdf_bnds)
298 | p_nor = p_nor.unsqueeze(0)
299 | vgrid = p_nor[:, :, None, None].float()
300 |         # actually trilinear interpolation when mode == 'bilinear'
301 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True,
302 | mode='bilinear').squeeze(-1).squeeze(-1)
303 | return tsdf_value
304 |
305 |
306 |
307 | def forward(self, p, c_grid, tsdf_volume, tsdf_bnds, stage='low', **kwargs):
308 | """
309 |         Output occupancy/color depending on the stage.
310 | """
311 |
312 | device = f'cuda:{p.get_device()}'
313 | if stage == 'low':
314 | low_occ = self.low_decoder(p, c_grid)
315 | low_occ = low_occ.squeeze(0)
316 |
317 | w = torch.ones(low_occ.shape[0]).to(device)
318 | raw = torch.zeros(low_occ.shape[0], 4).to(device).float()
319 | raw[..., -1] = low_occ # new_occ
320 | return raw, w
321 | elif stage == 'high':
322 | high_occ = self.high_decoder(p, c_grid)
323 | raw = torch.zeros(high_occ.shape[0], 4).to(device).float()
324 | low_occ = self.low_decoder(p, c_grid)
325 | low_occ = low_occ.squeeze(0)
326 | f_add_m_occ = high_occ + low_occ
327 |
328 | eval_tsdf = self.sample_grid_tsdf(p, tsdf_volume, tsdf_bnds, device)
329 | eval_tsdf_mask = ((eval_tsdf > -1.0+1e-4) & (eval_tsdf < 1.0-1e-4))
330 | eval_tsdf_mask = eval_tsdf_mask.squeeze()
331 |
332 | w = torch.ones(low_occ.shape[0]).to(device)
333 | low_occ[eval_tsdf_mask], w[eval_tsdf_mask] = self.mlp(p[:, eval_tsdf_mask, :], f_add_m_occ[eval_tsdf_mask], tsdf_volume, tsdf_bnds)
334 | new_occ = low_occ
335 | new_occ = new_occ.squeeze(-1)
336 | raw[..., -1] = new_occ
337 | return raw, w
338 | elif stage == 'color':
339 | high_occ = self.high_decoder(p, c_grid)
340 | raw = self.color_decoder(p, c_grid)
341 | low_occ = self.low_decoder(p, c_grid)
342 | low_occ = low_occ.squeeze(0)
343 | f_add_m_occ = high_occ + low_occ
344 |
345 | eval_tsdf = self.sample_grid_tsdf(p, tsdf_volume, tsdf_bnds, device)
346 | eval_tsdf_mask = ((eval_tsdf > -1.0+1e-4) & (eval_tsdf < 1.0-1e-4))
347 | eval_tsdf_mask = eval_tsdf_mask.squeeze()
348 |
349 | w = torch.ones(low_occ.shape[0]).to(device)
350 | low_occ[eval_tsdf_mask], w[eval_tsdf_mask] = self.mlp(p[:, eval_tsdf_mask, :], f_add_m_occ[eval_tsdf_mask], tsdf_volume, tsdf_bnds)
351 | new_occ = low_occ
352 | raw[..., -1] = new_occ
353 | return raw, w
354 |
--------------------------------------------------------------------------------
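The attention MLP in `decoder.py` blends the decoders' occupancy with a value derived from the TSDF prior: the sampled TSDF in [-1, 1] is flipped to [0, 1] and passed through a scaled logit before the learned softmax weighting. The sketch below reproduces only that value transform on a few illustrative TSDF samples (no network weights involved):

import torch

tsdf = torch.tensor([-0.9, -0.1, 0.0, 0.4, 0.95])   # raw TSDF samples in [-1, 1]

# same value transform as in mlp_tsdf.forward
val = torch.clamp(1. - (tsdf + 1.) / 2., 0.0, 1.0)
inv_tsdf = torch.clamp(-0.1 * torch.log((1. / (val + 1e-8)) - 1. + 1e-7), -100.0, 100.0)
print(inv_tsdf)   # ~0 at the zero crossing, positive for negative TSDF, negative for positive TSDF

# the MLP then predicts per-point softmax weights (w_occ, w_tsdf) and outputs
#   out = w_occ * occ + w_tsdf * inv_tsdf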
/src/fusion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018 Andy Zeng
2 |
3 | import numpy as np
4 |
5 | from numba import njit, prange
6 | from skimage import measure
7 |
8 |
9 | try:
10 | import pycuda.driver as cuda
11 | import pycuda.autoinit
12 | from pycuda.compiler import SourceModule
13 | FUSION_GPU_MODE = 1
14 | except Exception as err:
15 | print('Warning: {}'.format(err))
16 | print('Failed to import PyCUDA. Running fusion in CPU mode.')
17 | FUSION_GPU_MODE = 0
18 |
19 | # FUSION_GPU_MODE = 0
20 |
21 | class TSDFVolume:
22 | """Volumetric TSDF Fusion of RGB-D Images.
23 | """
24 | def __init__(self, vol_bnds, voxel_size, use_gpu=True):
25 | """Constructor.
26 |
27 | Args:
28 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the
29 | xyz bounds (min/max) in meters.
30 | voxel_size (float): The volume discretization in meters.
31 | """
32 | vol_bnds = np.asarray(vol_bnds)
33 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)."
34 |
35 | # Define voxel volume parameters
36 | self._vol_bnds = vol_bnds
37 | self._voxel_size = float(voxel_size)
38 |         self._trunc_margin = 5 * self._voxel_size  # truncation margin on the SDF (5 voxels)
39 | self._color_const = 256 * 256
40 |
41 | # Adjust volume bounds and ensure C-order contiguous
42 | self._vol_dim = np.ceil((self._vol_bnds[:,1]-self._vol_bnds[:,0])/self._voxel_size).copy(order='C').astype(int)
43 | self._vol_bnds[:,1] = self._vol_bnds[:,0]+self._vol_dim*self._voxel_size
44 | self._vol_origin = self._vol_bnds[:,0].copy(order='C').astype(np.float32)
45 |
46 | print("Voxel volume size: {} x {} x {} - # points: {:,}".format(
47 | self._vol_dim[0], self._vol_dim[1], self._vol_dim[2],
48 | self._vol_dim[0]*self._vol_dim[1]*self._vol_dim[2])
49 | )
50 |
51 | # Initialize pointers to voxel volume in CPU memory
52 |         self._tsdf_vol_cpu = -1 * np.ones(self._vol_dim).astype(np.float32)  # initialize TSDF to -1 everywhere
53 | # for computing the cumulative moving average of observations per voxel
54 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
55 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
56 |
57 | self.gpu_mode = use_gpu and FUSION_GPU_MODE
58 |
59 | # Copy voxel volumes to GPU
60 | if self.gpu_mode:
61 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes)
62 | cuda.memcpy_htod(self._tsdf_vol_gpu,self._tsdf_vol_cpu)
63 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes)
64 | cuda.memcpy_htod(self._weight_vol_gpu,self._weight_vol_cpu)
65 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes)
66 | cuda.memcpy_htod(self._color_vol_gpu,self._color_vol_cpu)
67 |
68 | # Cuda kernel function (C++)
69 | self._cuda_src_mod = SourceModule("""
70 | __global__ void integrate(float * tsdf_vol,
71 | float * weight_vol,
72 | float * color_vol,
73 | float * vol_dim,
74 | float * vol_origin,
75 | float * cam_intr,
76 | float * cam_pose,
77 | float * other_params,
78 | float * color_im,
79 | float * depth_im) {
80 | // Get voxel index
81 | int gpu_loop_idx = (int) other_params[0];
82 | int max_threads_per_block = blockDim.x;
83 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x;
84 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x;
85 | int vol_dim_x = (int) vol_dim[0];
86 | int vol_dim_y = (int) vol_dim[1];
87 | int vol_dim_z = (int) vol_dim[2];
88 |           if (voxel_idx >= vol_dim_x*vol_dim_y*vol_dim_z)
89 | return;
90 | // Get voxel grid coordinates (note: be careful when casting)
91 | float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z)));
92 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z));
93 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z);
94 | // Voxel grid coordinates to world coordinates
95 | float voxel_size = other_params[1];
96 | float pt_x = vol_origin[0]+voxel_x*voxel_size;
97 | float pt_y = vol_origin[1]+voxel_y*voxel_size;
98 | float pt_z = vol_origin[2]+voxel_z*voxel_size;
99 | // World coordinates to camera coordinates
100 | float tmp_pt_x = pt_x-cam_pose[0*4+3];
101 | float tmp_pt_y = pt_y-cam_pose[1*4+3];
102 | float tmp_pt_z = pt_z-cam_pose[2*4+3];
103 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z;
104 | float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z;
105 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z;
106 | // Camera coordinates to image pixels
107 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]);
108 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]);
109 | // Skip if outside view frustum
110 | int im_h = (int) other_params[2];
111 | int im_w = (int) other_params[3];
112 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0)
113 | return;
114 | // Skip invalid depth
115 | float depth_value = depth_im[pixel_y*im_w+pixel_x];
116 | if (depth_value == 0)
117 | return;
118 | // Integrate TSDF
119 | float trunc_margin = other_params[4];
120 | float depth_diff = depth_value-cam_pt_z;
121 | if (depth_diff < -trunc_margin)
122 | return;
123 | float dist = fmin(1.0f,depth_diff/trunc_margin);
124 | float w_old = weight_vol[voxel_idx];
125 | float obs_weight = other_params[5];
126 | float w_new = w_old + obs_weight;
127 | weight_vol[voxel_idx] = w_new;
128 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new;
129 | // Integrate color
130 | float old_color = color_vol[voxel_idx];
131 | float old_b = floorf(old_color/(256*256));
132 | float old_g = floorf((old_color-old_b*256*256)/256);
133 | float old_r = old_color-old_b*256*256-old_g*256;
134 | float new_color = color_im[pixel_y*im_w+pixel_x];
135 | float new_b = floorf(new_color/(256*256));
136 | float new_g = floorf((new_color-new_b*256*256)/256);
137 | float new_r = new_color-new_b*256*256-new_g*256;
138 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f);
139 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f);
140 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f);
141 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r;
142 | }""")
143 |
144 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate")
145 |
146 | # Determine block/grid size on GPU
147 | gpu_dev = cuda.Device(0)
148 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK
149 | n_blocks = int(np.ceil(float(np.prod(self._vol_dim))/float(self._max_gpu_threads_per_block)))
150 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X,int(np.floor(np.cbrt(n_blocks))))
151 | grid_dim_y = min(gpu_dev.MAX_GRID_DIM_Y,int(np.floor(np.sqrt(n_blocks/grid_dim_x))))
152 | grid_dim_z = min(gpu_dev.MAX_GRID_DIM_Z,int(np.ceil(float(n_blocks)/float(grid_dim_x*grid_dim_y))))
153 | self._max_gpu_grid_dim = np.array([grid_dim_x,grid_dim_y,grid_dim_z]).astype(int)
154 | self._n_gpu_loops = int(np.ceil(float(np.prod(self._vol_dim))/float(np.prod(self._max_gpu_grid_dim)*self._max_gpu_threads_per_block)))
155 |
156 | else:
157 | # Get voxel grid coordinates
158 | xv, yv, zv = np.meshgrid(
159 | range(self._vol_dim[0]),
160 | range(self._vol_dim[1]),
161 | range(self._vol_dim[2]),
162 | indexing='ij'
163 | )
164 | self.vox_coords = np.concatenate([
165 | xv.reshape(1,-1),
166 | yv.reshape(1,-1),
167 | zv.reshape(1,-1)
168 | ], axis=0).astype(int).T
169 |
170 | @staticmethod
171 | @njit(parallel=True)
172 | def vox2world(vol_origin, vox_coords, vox_size):
173 | """Convert voxel grid coordinates to world coordinates.
174 | """
175 | vol_origin = vol_origin.astype(np.float32)
176 | vox_coords = vox_coords.astype(np.float32)
177 | cam_pts = np.empty_like(vox_coords, dtype=np.float32)
178 | for i in prange(vox_coords.shape[0]):
179 | for j in range(3):
180 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j])
181 | return cam_pts
182 |
183 | @staticmethod
184 | @njit(parallel=True)
185 | def cam2pix(cam_pts, intr):
186 | """Convert camera coordinates to pixel coordinates.
187 | """
188 | intr = intr.astype(np.float32)
189 | fx, fy = intr[0, 0], intr[1, 1]
190 | cx, cy = intr[0, 2], intr[1, 2]
191 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
192 | for i in prange(cam_pts.shape[0]):
193 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
194 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
195 | return pix
196 |
197 | @staticmethod
198 | @njit(parallel=True)
199 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight):
200 | """Integrate the TSDF volume.
201 | """
202 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32)
203 | w_new = np.empty_like(w_old, dtype=np.float32)
204 | for i in prange(len(tsdf_vol)):
205 | w_new[i] = w_old[i] + obs_weight
206 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i]
207 | return tsdf_vol_int, w_new
208 |
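  # The update in integrate_tsdf above is the standard running weighted average:
  #     tsdf_new = (w_old * tsdf_old + obs_weight * dist) / (w_old + obs_weight)
  # e.g. a voxel observed twice at tsdf 0.4 (w_old = 2), updated with dist = 1.0
  # and obs_weight = 1, moves to (2 * 0.4 + 1.0) / 3 = 0.6.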
209 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.):
210 | """Integrate an RGB-D frame into the TSDF volume.
211 |
212 | Args:
213 | color_im (ndarray): An RGB image of shape (H, W, 3).
214 | depth_im (ndarray): A depth image of shape (H, W).
215 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3).
216 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4).
217 |       obs_weight (float): The weight to assign to the current observation. A higher
218 |         value gives the current observation more influence on the fused volume.
219 | """
220 | im_h, im_w = depth_im.shape
221 |
222 | # Fold RGB color image into a single channel image
223 | color_im = color_im.astype(np.float32)
224 | color_im = np.floor(color_im[...,2]*self._color_const + color_im[...,1]*256 + color_im[...,0])
225 |
226 | if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel)
227 | for gpu_loop_idx in range(self._n_gpu_loops):
228 | self._cuda_integrate(self._tsdf_vol_gpu,
229 | self._weight_vol_gpu,
230 | self._color_vol_gpu,
231 | cuda.InOut(self._vol_dim.astype(np.float32)),
232 | cuda.InOut(self._vol_origin.astype(np.float32)),
233 | cuda.InOut(cam_intr.reshape(-1).astype(np.float32)),
234 | cuda.InOut(cam_pose.reshape(-1).astype(np.float32)),
235 | cuda.InOut(np.asarray([
236 | gpu_loop_idx,
237 | self._voxel_size,
238 | im_h,
239 | im_w,
240 | self._trunc_margin,
241 | obs_weight
242 | ], np.float32)),
243 | cuda.InOut(color_im.reshape(-1).astype(np.float32)),
244 | cuda.InOut(depth_im.reshape(-1).astype(np.float32)),
245 | block=(self._max_gpu_threads_per_block,1,1),
246 | grid=(
247 | int(self._max_gpu_grid_dim[0]),
248 | int(self._max_gpu_grid_dim[1]),
249 | int(self._max_gpu_grid_dim[2]),
250 | )
251 | )
252 | else: # CPU mode: integrate voxel volume (vectorized implementation)
253 | # Convert voxel grid coordinates to pixel coordinates
254 | cam_pts = self.vox2world(self._vol_origin, self.vox_coords, self._voxel_size)
255 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose))
256 | pix_z = cam_pts[:, 2]
257 | pix = self.cam2pix(cam_pts, cam_intr)
258 | pix_x, pix_y = pix[:, 0], pix[:, 1]
259 |
260 | # Eliminate pixels outside view frustum
261 | valid_pix = np.logical_and(pix_x >= 0,
262 | np.logical_and(pix_x < im_w,
263 | np.logical_and(pix_y >= 0,
264 | np.logical_and(pix_y < im_h,
265 | pix_z > 0))))
266 | depth_val = np.zeros(pix_x.shape)
267 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]]
268 |
269 | # Integrate TSDF
270 | depth_diff = depth_val - pix_z
271 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin)
272 | dist = np.minimum(1, depth_diff / self._trunc_margin)
273 | valid_vox_x = self.vox_coords[valid_pts, 0]
274 | valid_vox_y = self.vox_coords[valid_pts, 1]
275 | valid_vox_z = self.vox_coords[valid_pts, 2]
276 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
277 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
278 | valid_dist = dist[valid_pts]
279 | tsdf_vol_new, w_new = self.integrate_tsdf(tsdf_vals, valid_dist, w_old, obs_weight)
280 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
281 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new
282 |
283 | # Integrate color
284 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
285 | old_b = np.floor(old_color / self._color_const)
286 | old_g = np.floor((old_color-old_b*self._color_const)/256)
287 | old_r = old_color - old_b*self._color_const - old_g*256
288 | new_color = color_im[pix_y[valid_pts],pix_x[valid_pts]]
289 | new_b = np.floor(new_color / self._color_const)
290 | new_g = np.floor((new_color - new_b*self._color_const) /256)
291 | new_r = new_color - new_b*self._color_const - new_g*256
292 | new_b = np.minimum(255., np.round((w_old*old_b + obs_weight*new_b) / w_new))
293 | new_g = np.minimum(255., np.round((w_old*old_g + obs_weight*new_g) / w_new))
294 | new_r = np.minimum(255., np.round((w_old*old_r + obs_weight*new_r) / w_new))
295 | self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = new_b*self._color_const + new_g*256 + new_r
296 |
297 | def get_volume(self):
298 | if self.gpu_mode:
299 | cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu)
300 | cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu)
301 | return self._tsdf_vol_cpu, self._color_vol_cpu, self._vol_bnds
302 |
303 | def get_point_cloud(self):
304 | """Extract a point cloud from the voxel volume.
305 | """
306 |     tsdf_vol, color_vol, _ = self.get_volume()
307 |
308 | # Marching cubes
309 | verts = measure.marching_cubes(tsdf_vol, level=0)[0]
310 | verts_ind = np.round(verts).astype(int)
311 | verts = verts*self._voxel_size + self._vol_origin
312 |
313 | # Get vertex colors
314 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
315 | colors_b = np.floor(rgb_vals / self._color_const)
316 | colors_g = np.floor((rgb_vals - colors_b*self._color_const) / 256)
317 | colors_r = rgb_vals - colors_b*self._color_const - colors_g*256
318 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
319 | colors = colors.astype(np.uint8)
320 |
321 | pc = np.hstack([verts, colors])
322 | return pc
323 |
324 | def get_mesh(self):
325 | """Compute a mesh from the voxel volume using marching cubes.
326 | """
327 | tsdf_vol, color_vol, bnds = self.get_volume()
328 |
329 | # Marching cubes
330 | verts, faces, norms, vals = measure.marching_cubes(tsdf_vol, level=0)
331 | verts_ind = np.round(verts).astype(int)
332 | verts = verts*self._voxel_size+self._vol_origin # voxel grid coordinates to world coordinates
333 |
334 | # Get vertex colors
335 | rgb_vals = color_vol[verts_ind[:,0], verts_ind[:,1], verts_ind[:,2]]
336 | colors_b = np.floor(rgb_vals/self._color_const)
337 | colors_g = np.floor((rgb_vals-colors_b*self._color_const)/256)
338 | colors_r = rgb_vals-colors_b*self._color_const-colors_g*256
339 | colors = np.floor(np.asarray([colors_r,colors_g,colors_b])).T
340 | colors = colors.astype(np.uint8)
341 | return verts, faces, norms, colors
342 |
343 |
344 | def rigid_transform(xyz, transform):
345 | """Applies a rigid transform to an (N, 3) pointcloud.
346 | """
347 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
348 | xyz_t_h = np.dot(transform, xyz_h.T).T
349 | return xyz_t_h[:, :3]
350 |
351 |
352 | def get_view_frustum(depth_im, cam_intr, cam_pose):
353 | """Get corners of 3D camera view frustum of depth image
354 | """
355 | im_h = depth_im.shape[0]
356 | im_w = depth_im.shape[1]
357 | max_depth = np.max(depth_im)
358 | view_frust_pts = np.array([
359 | (np.array([0,0,0,im_w,im_w])-cam_intr[0,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[0,0],
360 | (np.array([0,0,im_h,0,im_h])-cam_intr[1,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[1,1],
361 | np.array([0,max_depth,max_depth,max_depth,max_depth])
362 | ])
363 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T
364 | return view_frust_pts
365 |
366 |
367 | def meshwrite(filename, verts, faces, norms, colors):
368 | """Save a 3D mesh to a polygon .ply file.
369 | """
370 | # Write header
371 | ply_file = open(filename,'w')
372 | ply_file.write("ply\n")
373 | ply_file.write("format ascii 1.0\n")
374 | ply_file.write("element vertex %d\n"%(verts.shape[0]))
375 | ply_file.write("property float x\n")
376 | ply_file.write("property float y\n")
377 | ply_file.write("property float z\n")
378 | ply_file.write("property float nx\n")
379 | ply_file.write("property float ny\n")
380 | ply_file.write("property float nz\n")
381 | ply_file.write("property uchar red\n")
382 | ply_file.write("property uchar green\n")
383 | ply_file.write("property uchar blue\n")
384 | ply_file.write("element face %d\n"%(faces.shape[0]))
385 | ply_file.write("property list uchar int vertex_index\n")
386 | ply_file.write("end_header\n")
387 |
388 | # Write vertex list
389 | for i in range(verts.shape[0]):
390 | ply_file.write("%f %f %f %f %f %f %d %d %d\n"%(
391 | verts[i,0], verts[i,1], verts[i,2],
392 | norms[i,0], norms[i,1], norms[i,2],
393 | colors[i,0], colors[i,1], colors[i,2],
394 | ))
395 |
396 | # Write face list
397 | for i in range(faces.shape[0]):
398 | ply_file.write("3 %d %d %d\n"%(faces[i,0], faces[i,1], faces[i,2]))
399 |
400 | ply_file.close()
401 |
402 |
403 | def pcwrite(filename, xyzrgb):
404 | """Save a point cloud to a polygon .ply file.
405 | """
406 | xyz = xyzrgb[:, :3]
407 | rgb = xyzrgb[:, 3:].astype(np.uint8)
408 |
409 | # Write header
410 | ply_file = open(filename,'w')
411 | ply_file.write("ply\n")
412 | ply_file.write("format ascii 1.0\n")
413 | ply_file.write("element vertex %d\n"%(xyz.shape[0]))
414 | ply_file.write("property float x\n")
415 | ply_file.write("property float y\n")
416 | ply_file.write("property float z\n")
417 | ply_file.write("property uchar red\n")
418 | ply_file.write("property uchar green\n")
419 | ply_file.write("property uchar blue\n")
420 | ply_file.write("end_header\n")
421 |
422 | # Write vertex list
423 | for i in range(xyz.shape[0]):
424 | ply_file.write("%f %f %f %d %d %d\n"%(
425 | xyz[i, 0], xyz[i, 1], xyz[i, 2],
426 | rgb[i, 0], rgb[i, 1], rgb[i, 2],
427 | ))
428 | 
429 |   ply_file.close()
--------------------------------------------------------------------------------
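
A minimal usage sketch of the volume class and the helpers defined in fusion.py above. The constructor and `integrate` signatures follow the upstream tsdf-fusion-python implementation this file is based on and are assumptions here; the intrinsics, bounds, and `frames` iterable are placeholders.

    import numpy as np
    from src import fusion  # assumes the repository root is on sys.path

    # Hypothetical intrinsics and scene bounds (meters); replace with real values.
    cam_intr = np.array([[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]])
    vol_bnds = np.array([[-4.0, 4.0], [-4.0, 4.0], [-2.0, 2.0]])

    tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=0.04)  # assumed constructor signature

    for color_im, depth_im, cam_pose in frames:  # `frames` is a placeholder iterable of (H,W,3), (H,W), (4,4) arrays
        tsdf_vol.integrate(color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0)

    verts, faces, norms, colors = tsdf_vol.get_mesh()
    fusion.meshwrite('mesh.ply', verts, faces, norms, colors)
    fusion.pcwrite('pc.ply', tsdf_vol.get_point_cloud())
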
/src/tools/cull_mesh.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 | import trimesh
6 | from tqdm import tqdm
7 |
8 |
9 | def load_poses(path):
10 | poses = []
11 | with open(path, "r") as f:
12 | lines = f.readlines()
13 | for line in lines:
14 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
15 | c2w[:3, 1] *= -1
16 | c2w[:3, 2] *= -1
17 | c2w = torch.from_numpy(c2w).float()
18 | poses.append(c2w)
19 | return poses
20 |
21 |
22 | parser = argparse.ArgumentParser(
23 | description='Arguments to cull the mesh.'
24 | )
25 |
26 | parser.add_argument('--input_mesh', type=str,
27 | help='path to the mesh to be culled')
28 | parser.add_argument('--traj', type=str, help='path to the trajectory')
29 | parser.add_argument('--output_mesh', type=str, help='path to the output mesh')
30 | args = parser.parse_args()
31 |
32 | H = 680
33 | W = 1200
34 | fx = 600.0
35 | fy = 600.0
36 | 
37 | cx = 599.5
38 | cy = 339.5
39 | scale = 6553.5
40 |
41 | poses = load_poses(args.traj)
42 | n_imgs = len(poses)
43 | mesh = trimesh.load(args.input_mesh, process=False)
44 | pc = mesh.vertices
45 | faces = mesh.faces
46 |
47 | # delete mesh vertices that are not inside any camera's viewing frustum
48 | whole_mask = np.ones(pc.shape[0]).astype(bool)
49 | for i in tqdm(range(0, n_imgs, 1)):
50 | c2w = poses[i]
51 | points = pc.copy()
52 | points = torch.from_numpy(points).cuda()
53 | w2c = np.linalg.inv(c2w)
54 | w2c = torch.from_numpy(w2c).cuda().float()
55 | K = torch.from_numpy(
56 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda()
57 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
58 | homo_points = torch.cat(
59 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float()
60 | cam_cord_homo = w2c@homo_points
61 | cam_cord = cam_cord_homo[:, :3]
62 |
63 | cam_cord[:, 0] *= -1
64 | uv = K.float()@cam_cord.float()
65 | z = uv[:, -1:]+1e-5
66 | uv = uv[:, :2]/z
67 | uv = uv.float().squeeze(-1).cpu().numpy()
68 | edge = 0
69 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W -
70 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge)
71 | whole_mask &= ~mask
72 | pc = mesh.vertices
73 | faces = mesh.faces
74 | face_mask = whole_mask[mesh.faces].all(axis=1)
75 | mesh.update_faces(~face_mask)
76 | mesh.export(args.output_mesh)
77 |
--------------------------------------------------------------------------------
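
The culling loop above projects every mesh vertex into each camera and drops faces whose vertices never fall inside any view frustum. A self-contained NumPy sketch of that projection test, using the hard-coded Replica intrinsics from the script and a made-up identity pose:

    import numpy as np

    H, W = 680, 1200
    K = np.array([[600.0, 0.0, 599.5],
                  [0.0, 600.0, 339.5],
                  [0.0, 0.0, 1.0]])

    points = np.array([[0.0, 0.0, -2.0],    # 2 m in front of the camera (-z is forward here)
                       [10.0, 0.0, -2.0]])  # far off to the side
    c2w = np.eye(4)                         # made-up pose, for illustration only
    w2c = np.linalg.inv(c2w)

    homo = np.concatenate([points, np.ones((len(points), 1))], axis=1)  # (N, 4)
    cam = (w2c @ homo.T).T[:, :3]
    cam[:, 0] *= -1                         # same x-axis flip as in the script
    uv = (K @ cam.T).T
    z = uv[:, 2:3] + 1e-5
    uv = uv[:, :2] / z
    in_frustum = (0 <= -z[:, 0]) & (uv[:, 0] > 0) & (uv[:, 0] < W) & (uv[:, 1] > 0) & (uv[:, 1] < H)
    print(in_frustum)                       # [ True False]
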
/src/tools/eval_ate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy
4 | import torch
5 | import sys
6 | import numpy as np
7 | sys.path.append('.')
8 | from src import config
9 | from src.common import get_tensor_from_camera
10 |
11 | def associate(first_list, second_list, offset=0.0, max_difference=0.02):
12 | """
13 | Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim
14 | to find the closest match for every input tuple.
15 |
16 | Input:
17 | first_list -- first dictionary of (stamp,data) tuples
18 | second_list -- second dictionary of (stamp,data) tuples
19 | offset -- time offset between both dictionaries (e.g., to model the delay between the sensors)
20 | max_difference -- search radius for candidate generation
21 |
22 | Output:
23 | matches -- list of matched tuples ((stamp1,data1),(stamp2,data2))
24 |
25 | """
26 | first_keys = list(first_list.keys())
27 | second_keys = list(second_list.keys())
28 | potential_matches = [(abs(a - (b + offset)), a, b)
29 | for a in first_keys
30 | for b in second_keys
31 | if abs(a - (b + offset)) < max_difference]
32 | potential_matches.sort()
33 | matches = []
34 | for diff, a, b in potential_matches:
35 | if a in first_keys and b in second_keys:
36 | first_keys.remove(a)
37 | second_keys.remove(b)
38 | matches.append((a, b))
39 |
40 | matches.sort()
41 | return matches
42 |
43 |
44 | def align(model, data):
45 | """Align two trajectories using the method of Horn (closed-form).
46 |
47 | Input:
48 | model -- first trajectory (3xn)
49 | data -- second trajectory (3xn)
50 |
51 | Output:
52 | rot -- rotation matrix (3x3)
53 | trans -- translation vector (3x1)
54 | trans_error -- translational error per point (1xn)
55 |
56 | """
57 | numpy.set_printoptions(precision=3, suppress=True)
58 | model_zerocentered = model - model.mean(1)
59 | data_zerocentered = data - data.mean(1)
60 |
61 | W = numpy.zeros((3, 3))
62 | for column in range(model.shape[1]):
63 | W += numpy.outer(model_zerocentered[:,
64 | column], data_zerocentered[:, column])
65 |     U, d, Vh = numpy.linalg.svd(W.transpose())
66 | S = numpy.matrix(numpy.identity(3))
67 | if(numpy.linalg.det(U) * numpy.linalg.det(Vh) < 0):
68 | S[2, 2] = -1
69 | rot = U*S*Vh
70 | trans = data.mean(1) - rot * model.mean(1)
71 |
72 | model_aligned = rot * model + trans
73 | alignment_error = model_aligned - data
74 |
75 | trans_error = numpy.sqrt(numpy.sum(numpy.multiply(
76 | alignment_error, alignment_error), 0)).A[0]
77 |
78 | return rot, trans, trans_error
79 |
80 |
81 | def plot_traj(ax, stamps, traj, style, color, label):
82 | """
83 | Plot a trajectory using matplotlib.
84 |
85 | Input:
86 | ax -- the plot
87 | stamps -- time stamps (1xn)
88 | traj -- trajectory (3xn)
89 | style -- line style
90 | color -- line color
91 | label -- plot legend
92 |
93 | """
94 | stamps.sort()
95 | interval = numpy.median([s-t for s, t in zip(stamps[1:], stamps[:-1])])
96 | x = []
97 | y = []
98 | last = stamps[0]
99 | for i in range(len(stamps)):
100 | if stamps[i]-last < 2*interval:
101 | x.append(traj[i][0])
102 | y.append(traj[i][1])
103 | elif len(x) > 0:
104 | ax.plot(x, y, style, color=color, label=label)
105 | label = ""
106 | x = []
107 | y = []
108 | last = stamps[i]
109 | if len(x) > 0:
110 | ax.plot(x, y, style, color=color, label=label)
111 |
112 |
113 | def evaluate_ate(first_list, second_list, plot="", _args=""):
114 | # parse command line
115 | parser = argparse.ArgumentParser(
116 | description='This script computes the absolute trajectory error from the ground truth trajectory and the estimated trajectory.')
117 | # parser.add_argument('first_file', help='ground truth trajectory (format: timestamp tx ty tz qx qy qz qw)')
118 | # parser.add_argument('second_file', help='estimated trajectory (format: timestamp tx ty tz qx qy qz qw)')
119 | parser.add_argument(
120 | '--offset', help='time offset added to the timestamps of the second file (default: 0.0)', default=0.0)
121 | parser.add_argument(
122 | '--scale', help='scaling factor for the second trajectory (default: 1.0)', default=1.0)
123 | parser.add_argument(
124 | '--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)', default=0.02)
125 | parser.add_argument(
126 | '--save', help='save aligned second trajectory to disk (format: stamp2 x2 y2 z2)')
127 | parser.add_argument('--save_associations',
128 | help='save associated first and aligned second trajectory to disk (format: stamp1 x1 y1 z1 stamp2 x2 y2 z2)')
129 | parser.add_argument(
130 | '--plot', help='plot the first and the aligned second trajectory to an image (format: png)')
131 | parser.add_argument(
132 | '--verbose', help='print all evaluation data (otherwise, only the RMSE absolute translational error in meters after alignment will be printed)', action='store_true')
133 | args = parser.parse_args(_args)
134 | args.plot = plot
135 | # first_list = associate.read_file_list(args.first_file)
136 | # second_list = associate.read_file_list(args.second_file)
137 |
138 | matches = associate(first_list, second_list, float(
139 | args.offset), float(args.max_difference))
140 | if len(matches) < 2:
141 | raise ValueError(
142 | "Couldn't find matching timestamp pairs between groundtruth and estimated trajectory! \
143 | Did you choose the correct sequence?")
144 |
145 | first_xyz = numpy.matrix(
146 | [[float(value) for value in first_list[a][0:3]] for a, b in matches]).transpose()
147 | second_xyz = numpy.matrix([[float(value)*float(args.scale)
148 | for value in second_list[b][0:3]] for a, b in matches]).transpose()
149 |
150 | rot, trans, trans_error = align(second_xyz, first_xyz)
151 |
152 | second_xyz_aligned = rot * second_xyz + trans
153 |
154 | first_stamps = list(first_list.keys())
155 | first_stamps.sort()
156 | first_xyz_full = numpy.matrix(
157 | [[float(value) for value in first_list[b][0:3]] for b in first_stamps]).transpose()
158 |
159 | second_stamps = list(second_list.keys())
160 | second_stamps.sort()
161 | second_xyz_full = numpy.matrix([[float(value)*float(args.scale)
162 | for value in second_list[b][0:3]] for b in second_stamps]).transpose()
163 | second_xyz_full_aligned = rot * second_xyz_full + trans
164 |
165 | if args.verbose:
166 | print("compared_pose_pairs %d pairs" % (len(trans_error)))
167 |
168 | print("absolute_translational_error.rmse %f m" % numpy.sqrt(
169 | numpy.dot(trans_error, trans_error) / len(trans_error)))
170 | print("absolute_translational_error.mean %f m" %
171 | numpy.mean(trans_error))
172 | print("absolute_translational_error.median %f m" %
173 | numpy.median(trans_error))
174 | print("absolute_translational_error.std %f m" % numpy.std(trans_error))
175 | print("absolute_translational_error.min %f m" % numpy.min(trans_error))
176 | print("absolute_translational_error.max %f m" % numpy.max(trans_error))
177 |
178 | if args.save_associations:
179 | file = open(args.save_associations, "w")
180 | file.write("\n".join(["%f %f %f %f %f %f %f %f" % (a, x1, y1, z1, b, x2, y2, z2) for (
181 | a, b), (x1, y1, z1), (x2, y2, z2) in zip(matches, first_xyz.transpose().A, second_xyz_aligned.transpose().A)]))
182 | file.close()
183 |
184 | if args.save:
185 | file = open(args.save, "w")
186 | file.write("\n".join(["%f " % stamp+" ".join(["%f" % d for d in line])
187 | for stamp, line in zip(second_stamps, second_xyz_full_aligned.transpose().A)]))
188 | file.close()
189 |
190 | if args.plot:
191 | import matplotlib
192 | matplotlib.use('Agg')
193 | import matplotlib.pylab as pylab
194 | import matplotlib.pyplot as plt
195 | from matplotlib.patches import Ellipse
196 | fig = plt.figure()
197 | ax = fig.add_subplot(111)
198 | ATE = numpy.sqrt(
199 | numpy.dot(trans_error, trans_error) / len(trans_error))
200 | ax.set_title(f'len:{len(trans_error)} ATE RMSE:{ATE} {args.plot[:-3]}')
201 | plot_traj(ax, first_stamps, first_xyz_full.transpose().A,
202 | '-', "black", "ground truth")
203 | plot_traj(ax, second_stamps, second_xyz_full_aligned.transpose(
204 | ).A, '-', "blue", "estimated")
205 |
206 | label = "difference"
207 | for (a, b), (x1, y1, z1), (x2, y2, z2) in zip(matches, first_xyz.transpose().A, second_xyz_aligned.transpose().A):
208 | # ax.plot([x1,x2],[y1,y2],'-',color="red",label=label)
209 | label = ""
210 | ax.legend()
211 | ax.set_xlabel('x [m]')
212 | ax.set_ylabel('y [m]')
213 | plt.savefig(args.plot, dpi=90)
214 |
215 | return {
216 | "compared_pose_pairs": (len(trans_error)),
217 | "absolute_translational_error.rmse": numpy.sqrt(numpy.dot(trans_error, trans_error) / len(trans_error)),
218 | "absolute_translational_error.mean": numpy.mean(trans_error),
219 | "absolute_translational_error.median": numpy.median(trans_error),
220 | "absolute_translational_error.std": numpy.std(trans_error),
221 | "absolute_translational_error.min": numpy.min(trans_error),
222 | "absolute_translational_error.max": numpy.max(trans_error),
223 | }
224 |
225 |
226 | def evaluate(poses_gt, poses_est, plot):
227 |
228 | poses_gt = poses_gt.cpu().numpy()
229 | poses_est = poses_est.cpu().numpy()
230 |
231 | N = poses_gt.shape[0]
232 | poses_gt = dict([(i, poses_gt[i]) for i in range(N)])
233 | poses_est = dict([(i, poses_est[i]) for i in range(N)])
234 |
235 | results = evaluate_ate(poses_gt, poses_est, plot)
236 | print(results)
237 |
238 |
239 | def convert_poses(c2w_list, N, scale, gt=True):
240 | poses = []
241 | mask = torch.ones(N+1).bool()
242 | for idx in range(0, N+1):
243 | if gt:
244 |             # some frames have `nan` or `inf` in the ScanNet gt poses,
245 |             # but our system estimates a camera pose for every frame,
246 |             # so we mask out invalid gt poses when computing the pose error
247 | if torch.isinf(c2w_list[idx]).any():
248 | mask[idx] = 0
249 | continue
250 | if torch.isnan(c2w_list[idx]).any():
251 | mask[idx] = 0
252 | continue
253 | c2w_list[idx][:3, 3] /= scale
254 | poses.append(get_tensor_from_camera(c2w_list[idx], Tquad=True))
255 | poses = torch.stack(poses)
256 | return poses, mask
257 |
258 |
259 | if __name__ == '__main__':
260 | """
261 |     This ATE evaluation code is adapted from the evaluation code in lie-torch.
262 | """
263 |
264 | parser = argparse.ArgumentParser(
265 | description='Arguments to eval the tracking ATE.'
266 | )
267 | parser.add_argument('config', type=str, help='Path to config file.')
268 | parser.add_argument('--output', type=str,
269 |                         help='output folder; this has higher priority and can overwrite the one in the config file')
270 | nice_parser = parser.add_mutually_exclusive_group(required=False)
271 | nice_parser.add_argument('--nice', dest='nice', action='store_true')
272 | nice_parser.add_argument('--imap', dest='nice', action='store_false')
273 | parser.set_defaults(nice=True)
274 |
275 | args = parser.parse_args()
276 |     cfg = config.load_config(
277 |         args.config, 'configs/df_prior.yaml')
278 | scale = cfg['scale']
279 | output = cfg['data']['output'] if args.output is None else args.output
280 | cofusion = ('cofusion' in args.config) or ('CoFusion' in args.config)
281 | ckptsdir = f'{output}/ckpts'
282 | if os.path.exists(ckptsdir):
283 | ckpts = [os.path.join(ckptsdir, f)
284 | for f in sorted(os.listdir(ckptsdir)) if 'tar' in f]
285 | if len(ckpts) > 0:
286 | ckpt_path = ckpts[-1]
287 | print('Get ckpt :', ckpt_path)
288 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
289 | estimate_c2w_list = ckpt['estimate_c2w_list']
290 | gt_c2w_list = ckpt['gt_c2w_list']
291 | N = ckpt['idx']
292 | if cofusion:
293 | poses_gt = np.loadtxt(
294 | 'Datasets/CoFusion/room4/trajectories/gt-cam-0.txt')
295 | poses_gt = torch.from_numpy(poses_gt[:, 1:])
296 | else:
297 | poses_gt, mask = convert_poses(gt_c2w_list, N, scale)
298 | poses_est, _ = convert_poses(estimate_c2w_list, N, scale)
299 | poses_est = poses_est[mask]
300 | evaluate(poses_gt, poses_est,
301 | plot=f'{output}/eval_ate_plot.png')
302 |
--------------------------------------------------------------------------------
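
A quick sanity check of the Horn alignment in `align` above, on toy data (run with the functions from this file in scope); a pure translation between otherwise identical trajectories should be absorbed entirely into `trans`:

    import numpy

    gt = numpy.matrix([[0., 1., 2., 3.],    # toy 3xN ground-truth trajectory
                       [0., 0., 1., 1.],
                       [0., 0., 0., 1.]])
    est = numpy.matrix([[1., 2., 3., 4.],   # the same trajectory shifted by 1 m along x
                        [0., 0., 1., 1.],
                        [0., 0., 0., 1.]])

    rot, trans, trans_error = align(est, gt)
    ate_rmse = numpy.sqrt(numpy.dot(trans_error, trans_error) / len(trans_error))
    print(trans_error, ate_rmse)            # both ~0: the 1 m offset is removed by the alignment
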
/src/tools/eval_recon.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 |
4 | import numpy as np
5 | import open3d as o3d
6 | import torch
7 | import trimesh
8 | from scipy.spatial import cKDTree as KDTree
9 |
10 | def setup_seed(seed):
11 | torch.manual_seed(seed)
12 | torch.cuda.manual_seed_all(seed)
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.backends.cudnn.deterministic = True
16 |
17 |
18 |
19 | def normalize(x):
20 | return x / np.linalg.norm(x)
21 |
22 |
23 | def viewmatrix(z, up, pos):
24 | vec2 = normalize(z)
25 | vec1_avg = up
26 | vec0 = normalize(np.cross(vec1_avg, vec2))
27 | vec1 = normalize(np.cross(vec2, vec0))
28 | m = np.stack([vec0, vec1, vec2, pos], 1)
29 | return m
30 |
31 |
32 | def completion_ratio(gt_points, rec_points, dist_th=0.05):
33 | gen_points_kd_tree = KDTree(rec_points)
34 | distances, _ = gen_points_kd_tree.query(gt_points)
35 |     comp_ratio = np.mean((distances < dist_th).astype(float))
36 | return comp_ratio
37 |
38 |
39 | def accuracy(gt_points, rec_points):
40 | gt_points_kd_tree = KDTree(gt_points)
41 | distances, _ = gt_points_kd_tree.query(rec_points)
42 | acc = np.mean(distances)
43 | return acc
44 |
45 |
46 | def completion(gt_points, rec_points):
47 | gt_points_kd_tree = KDTree(rec_points)
48 | distances, _ = gt_points_kd_tree.query(gt_points)
49 | comp = np.mean(distances)
50 | return comp
51 |
52 |
53 | def get_align_transformation(rec_meshfile, gt_meshfile):
54 | """
55 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh.
56 | """
57 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
58 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
59 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices)
60 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices)
61 | trans_init = np.eye(4)
62 | threshold = 0.1
63 | reg_p2p = o3d.pipelines.registration.registration_icp(
64 | o3d_rec_pc, o3d_gt_pc, threshold, trans_init,
65 | o3d.pipelines.registration.TransformationEstimationPointToPoint())
66 | transformation = reg_p2p.transformation
67 | return transformation
68 |
69 |
70 | def check_proj(points, W, H, fx, fy, cx, cy, c2w):
71 | """
72 | Check if points can be projected into the camera view.
73 |
74 | """
75 | c2w = c2w.copy()
76 | c2w[:3, 1] *= -1.0
77 | c2w[:3, 2] *= -1.0
78 | points = torch.from_numpy(points).cuda().clone()
79 | w2c = np.linalg.inv(c2w)
80 | w2c = torch.from_numpy(w2c).cuda().float()
81 | K = torch.from_numpy(
82 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda()
83 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
84 | homo_points = torch.cat(
85 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float() # (N, 4)
86 | cam_cord_homo = w2c@homo_points # (N, 4, 1)=(4,4)*(N, 4, 1)
87 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1)
88 | cam_cord[:, 0] *= -1
89 | uv = K.float()@cam_cord.float()
90 | z = uv[:, -1:]+1e-5
91 | uv = uv[:, :2]/z
92 | uv = uv.float().squeeze(-1).cpu().numpy()
93 | edge = 0
94 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W -
95 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge)
96 | return mask.sum() > 0
97 |
98 |
99 | def calc_3d_metric(rec_meshfile, gt_meshfile, align=True):
100 | """
101 | 3D reconstruction metric.
102 |
103 | """
104 | mesh_rec = trimesh.load(rec_meshfile, process=False)
105 | mesh_gt = trimesh.load(gt_meshfile, process=False)
106 |
107 | if align:
108 | transformation = get_align_transformation(rec_meshfile, gt_meshfile)
109 | mesh_rec = mesh_rec.apply_transform(transformation)
110 |
111 | rec_pc = trimesh.sample.sample_surface(mesh_rec, 200000)
112 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0])
113 |
114 | gt_pc = trimesh.sample.sample_surface(mesh_gt, 200000)
115 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0])
116 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices)
117 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices)
118 | completion_ratio_rec = completion_ratio(
119 | gt_pc_tri.vertices, rec_pc_tri.vertices)
120 | accuracy_rec *= 100 # convert to cm
121 | completion_rec *= 100 # convert to cm
122 | completion_ratio_rec *= 100 # convert to %
123 | print('accuracy: ', accuracy_rec)
124 | print('completion: ', completion_rec)
125 | print('completion ratio: ', completion_ratio_rec)
126 |
127 |
128 | def get_cam_position(gt_meshfile):
129 | mesh_gt = trimesh.load(gt_meshfile)
130 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt)
131 | extents[2] *= 0.7
132 | extents[1] *= 0.7
133 | extents[0] *= 0.3
134 | transform = np.linalg.inv(to_origin)
135 | transform[2, 3] += 0.4
136 | return extents, transform
137 |
138 |
139 | def calc_2d_metric(rec_meshfile, gt_meshfile, align=True, n_imgs=1000):
140 | """
141 | 2D reconstruction metric, depth L1 loss.
142 |
143 | """
144 | H = 500
145 | W = 500
146 | focal = 300
147 | fx = focal
148 | fy = focal
149 | cx = H/2.0-0.5
150 | cy = W/2.0-0.5
151 |
152 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
153 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
154 | unseen_gt_pointcloud_file = gt_meshfile.replace('.ply', '_pc_unseen.npy')
155 | pc_unseen = np.load(unseen_gt_pointcloud_file)
156 | if align:
157 | transformation = get_align_transformation(rec_meshfile, gt_meshfile)
158 | rec_mesh = rec_mesh.transform(transformation)
159 |
160 | # get vacant area inside the room
161 | extents, transform = get_cam_position(gt_meshfile)
162 |
163 | vis = o3d.visualization.Visualizer()
164 | vis.create_window(width=W, height=H)
165 | vis.get_render_option().mesh_show_back_face = True
166 | errors = []
167 | for i in range(n_imgs):
168 | while True:
169 | # sample view, and check if unseen region is not inside the camera view
170 | # if inside, then needs to resample
171 | up = [0, 0, -1]
172 | origin = trimesh.sample.volume_rectangular(
173 | extents, 1, transform=transform)
174 | origin = origin.reshape(-1)
175 | tx = round(random.uniform(-10000, +10000), 2)
176 | ty = round(random.uniform(-10000, +10000), 2)
177 | tz = round(random.uniform(-10000, +10000), 2)
178 | target = [tx, ty, tz]
179 | target = np.array(target)-np.array(origin)
180 | c2w = viewmatrix(target, up, origin)
181 | tmp = np.eye(4)
182 | tmp[:3, :] = c2w
183 | c2w = tmp
184 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w)
185 | if (~seen):
186 | break
187 |
188 | param = o3d.camera.PinholeCameraParameters()
189 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array
190 |
191 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic(
192 | W, H, fx, fy, cx, cy)
193 |
194 | ctr = vis.get_view_control()
195 | ctr.set_constant_z_far(20)
196 | ctr.convert_from_pinhole_camera_parameters(param)
197 |
198 | vis.add_geometry(gt_mesh, reset_bounding_box=True,)
199 | ctr.convert_from_pinhole_camera_parameters(param)
200 | vis.poll_events()
201 | vis.update_renderer()
202 | gt_depth = vis.capture_depth_float_buffer(True)
203 | gt_depth = np.asarray(gt_depth)
204 | vis.remove_geometry(gt_mesh, reset_bounding_box=True,)
205 |
206 | vis.add_geometry(rec_mesh, reset_bounding_box=True,)
207 | ctr.convert_from_pinhole_camera_parameters(param)
208 | vis.poll_events()
209 | vis.update_renderer()
210 | ours_depth = vis.capture_depth_float_buffer(True)
211 | ours_depth = np.asarray(ours_depth)
212 | vis.remove_geometry(rec_mesh, reset_bounding_box=True,)
213 |
214 | errors += [np.abs(gt_depth-ours_depth).mean()]
215 |
216 | errors = np.array(errors)
217 | # from m to cm
218 | print('Depth L1: ', errors.mean()*100)
219 |
220 |
221 | if __name__ == '__main__':
222 | #setup_seed(20)
223 |
224 | parser = argparse.ArgumentParser(
225 | description='Arguments to evaluate the reconstruction.'
226 | )
227 | parser.add_argument('--rec_mesh', type=str,
228 | help='reconstructed mesh file path')
229 | parser.add_argument('--gt_mesh', type=str,
230 | help='ground truth mesh file path')
231 | parser.add_argument('-2d', '--metric_2d',
232 | action='store_true', help='enable 2D metric')
233 | parser.add_argument('-3d', '--metric_3d',
234 | action='store_true', help='enable 3D metric')
235 | args = parser.parse_args()
236 | if args.metric_3d:
237 | calc_3d_metric(args.rec_mesh, args.gt_mesh)
238 |
239 | if args.metric_2d:
240 | calc_2d_metric(args.rec_mesh, args.gt_mesh, n_imgs=1000)
241 |
--------------------------------------------------------------------------------
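
A toy check of the point-set metrics above (run with `accuracy`, `completion`, and `completion_ratio` from this file in scope; the point sets are made up):

    import numpy as np

    gt_points = np.array([[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0]])
    rec_points = np.array([[0.0, 0.0, 0.02],   # 2 cm away from the first gt point
                           [1.0, 0.0, 0.0],    # exactly on the second gt point
                           [5.0, 5.0, 5.0]])   # spurious outlier

    print(accuracy(gt_points, rec_points))          # mean rec->gt distance, dominated by the outlier
    print(completion(gt_points, rec_points))        # mean gt->rec distance = (0.02 + 0.0) / 2 = 0.01
    print(completion_ratio(gt_points, rec_points))  # fraction of gt within 5 cm of rec = 1.0
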
/src/tools/evaluate_scannet.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/zju3dv/manhattan_sdf
2 | import numpy as np
3 | import open3d as o3d
4 | from sklearn.neighbors import KDTree
5 | import trimesh
6 | import torch
7 | import glob
8 | import os
9 | import pyrender
10 | 
11 | from tqdm import tqdm
12 | from pathlib import Path
13 | import sys
14 | sys.path.append(".")
15 | from src import config
16 | from src.utils.datasets import get_dataset
17 | import argparse
18 |
19 | # os.environ['PYOPENGL_PLATFORM'] = 'egl'
20 |
21 | def nn_correspondance(verts1, verts2):
22 | indices = []
23 | distances = []
24 | if len(verts1) == 0 or len(verts2) == 0:
25 | return indices, distances
26 |
27 | kdtree = KDTree(verts1)
28 | distances, indices = kdtree.query(verts2)
29 | distances = distances.reshape(-1)
30 |
31 | return distances
32 |
33 |
34 | def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):
35 | pcd_trgt = o3d.geometry.PointCloud()
36 | pcd_pred = o3d.geometry.PointCloud()
37 |
38 | pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3])
39 | pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3])
40 |
41 | if down_sample:
42 | pcd_pred = pcd_pred.voxel_down_sample(down_sample)
43 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)
44 |
45 | verts_pred = np.asarray(pcd_pred.points)
46 | verts_trgt = np.asarray(pcd_trgt.points)
47 |
48 | dist1 = nn_correspondance(verts_pred, verts_trgt)
49 | dist2 = nn_correspondance(verts_trgt, verts_pred)
50 |
51 | precision = np.mean((dist2 < threshold).astype('float'))
52 | recal = np.mean((dist1 < threshold).astype('float'))
53 | fscore = 2 * precision * recal / (precision + recal)
54 | metrics = {
55 | 'Acc': np.mean(dist2),
56 | 'Comp': np.mean(dist1),
57 | 'Chamfer': (np.mean(dist1) + np.mean(dist2))/2,
58 | 'Prec': precision,
59 | 'Recal': recal,
60 | 'F-score': fscore,
61 | }
62 | return metrics
63 |
64 |
65 |
66 | def update_cam(cfg):
67 | """
68 | Update the camera intrinsics according to pre-processing config,
69 | such as resize or edge crop.
70 | """
71 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][
72 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy']
73 | # resize the input images to crop_size (variable name used in lietorch)
74 | if 'crop_size' in cfg['cam']:
75 | crop_size = cfg['cam']['crop_size']
76 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][
77 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy']
78 | sx = crop_size[1] / W
79 | sy = crop_size[0] / H
80 | fx = sx*fx
81 | fy = sy*fy
82 | cx = sx*cx
83 | cy = sy*cy
84 | W = crop_size[1]
85 | H = crop_size[0]
86 |
87 |     # cropping changes H, W, cx and cy, so update them here
88 | if cfg['cam']['crop_edge'] > 0:
89 | H -= cfg['cam']['crop_edge']*2
90 | W -= cfg['cam']['crop_edge']*2
91 | cx -= cfg['cam']['crop_edge']
92 | cy -= cfg['cam']['crop_edge']
93 |
94 | return H, W, fx, fy, cx, cy
95 |
96 | # load pose
97 | def get_pose(cfg, args):
98 | scale = cfg['scale']
99 | H, W, fx, fy, cx, cy = update_cam(cfg)
100 | K = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy).intrinsic_matrix # (3, 3)
101 |
102 | frame_reader = get_dataset(cfg, args, scale)
103 |
104 | pose_ls = []
105 | for idx in range(len(frame_reader)):
106 |         if idx % args.space != 0: continue
107 | _, gt_color, gt_depth, gt_c2w = frame_reader[idx]
108 |
109 | c2w = gt_c2w.cpu().numpy()
110 |
111 |         if np.isfinite(c2w).all():  # use only frames with a fully finite gt pose
112 |
113 | c2w[:3, 1] *= -1.0
114 | c2w[:3, 2] *= -1.0
115 | pose_ls.append(c2w)
116 |
117 | return pose_ls, K, H, W
118 |
119 |
120 | class Renderer():
121 | def __init__(self, height=480, width=640):
122 | self.renderer = pyrender.OffscreenRenderer(width, height)
123 | self.scene = pyrender.Scene()
124 |
125 | def __call__(self, height, width, intrinsics, pose, mesh):
126 | self.renderer.viewport_height = height
127 | self.renderer.viewport_width = width
128 | self.scene.clear()
129 | self.scene.add(mesh)
130 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
131 | fx=intrinsics[0, 0], fy=intrinsics[1, 1])
132 |
133 | self.scene.add(cam, pose=self.fix_pose(pose))
134 | return self.renderer.render(self.scene)
135 |
136 | def fix_pose(self, pose):
137 | # 3D Rotation about the x-axis.
138 | t = np.pi
139 | c = np.cos(t)
140 | s = np.sin(t)
141 | R = np.array([[1, 0, 0],
142 | [0, c, -s],
143 | [0, s, c]]) # [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
144 | axis_transform = np.eye(4)
145 | axis_transform[:3, :3] = R
146 | return pose @ axis_transform
147 |
148 | def mesh_opengl(self, mesh):
149 | return pyrender.Mesh.from_trimesh(mesh)
150 |
151 | def delete(self):
152 | self.renderer.delete()
153 |
154 |
155 | def refuse(mesh, poses, K, H, W, cfg):
156 | renderer = Renderer()
157 | mesh_opengl = renderer.mesh_opengl(mesh)
158 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
159 | voxel_length=0.01,
160 | sdf_trunc=3*0.01,
161 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
162 | )
163 |
164 | idx = 0
165 | H, W, fx, fy, cx, cy = update_cam(cfg)
166 |
167 | for pose in tqdm(poses):
168 |
169 | intrinsic = np.eye(4)
170 | intrinsic[:3, :3] = K
171 |
172 | rgb = np.ones((H, W, 3))
173 | rgb = (rgb * 255).astype(np.uint8)
174 | rgb = o3d.geometry.Image(rgb)
175 | _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl)
176 |
177 | depth_pred = o3d.geometry.Image(depth_pred)
178 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
179 | rgb, depth_pred, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False
180 | )
181 |
182 | intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy)
183 | extrinsic = np.linalg.inv(pose)
184 | volume.integrate(rgbd, intrinsic, extrinsic)
185 |
186 | return volume.extract_triangle_mesh()
187 |
188 | def evaluate_mesh():
189 | """
190 | Evaluate the scannet mesh.
191 |
192 | """
193 | parser = argparse.ArgumentParser(
194 | description='Arguments for running the code.'
195 | )
196 | parser.add_argument('config', type=str, help='Path to config file.')
197 | parser.add_argument('--input_folder', type=str,
198 |                         help='input folder; this has higher priority and can overwrite the one in the config file')
199 | parser.add_argument('--output', type=str,
200 |                         help='output folder; this has higher priority and can overwrite the one in the config file')
201 | parser.add_argument('--space', type=int, default=10, help='the space between frames to integrate into the TSDF volume.')
202 |
203 | args = parser.parse_args()
204 | cfg = config.load_config(args.config, 'configs/df_prior.yaml')
205 |
206 | scene_id = cfg['data']['id']
207 |
208 | input_file = f"output/scannet/scans/scene{scene_id:04d}_00/mesh/final_mesh.ply"
209 | mesh = trimesh.load_mesh(input_file)
210 |     mesh.invert()  # flip the mesh normals
211 |
212 | poses, K, H, W = get_pose(cfg, args)
213 |
214 |     # re-fuse the mesh: render depth from the estimated poses and integrate it into a TSDF volume
215 | mesh = refuse(mesh, poses, K, H, W, cfg)
216 |
217 | # save mesh
218 | out_mesh_path = f"output/scannet/scans/scene{scene_id:04d}_00/mesh/final_mesh_refused.ply"
219 | o3d.io.write_triangle_mesh(out_mesh_path, mesh)
220 |
221 | mesh = trimesh.load(out_mesh_path)
222 | gt_mesh = os.path.join("./Datasets/scannet/GTmesh_lowres", f"{scene_id:04d}_00.obj")
223 | gt_mesh = trimesh.load(gt_mesh)
224 |
225 | metrics = evaluate(mesh, gt_mesh)
226 | print(metrics)
227 |
228 |
229 | if __name__ == "__main__":
230 | evaluate_mesh()
--------------------------------------------------------------------------------
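
For reference, the precision/recall/F-score arithmetic used in `evaluate` above, applied to toy distance arrays (the numbers are illustrative only):

    import numpy as np

    dist1 = np.array([0.01, 0.03, 0.20])  # gt -> predicted distances (drive Comp and Recal)
    dist2 = np.array([0.02, 0.04])        # predicted -> gt distances (drive Acc and Prec)
    threshold = 0.05

    precision = np.mean((dist2 < threshold).astype('float'))  # 1.0
    recall = np.mean((dist1 < threshold).astype('float'))     # 2/3
    fscore = 2 * precision * recall / (precision + recall)    # 0.8
    print({'Acc': np.mean(dist2), 'Comp': np.mean(dist1),
           'Chamfer': (np.mean(dist1) + np.mean(dist2)) / 2,
           'Prec': precision, 'Recal': recall, 'F-score': fscore})
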
/src/utils/Logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 |
6 | class Logger(object):
7 | """
8 | Save checkpoints to file.
9 |
10 | """
11 |
12 | def __init__(self, cfg, args, slam
13 | ):
14 | self.verbose = slam.verbose
15 | self.ckptsdir = slam.ckptsdir
16 | self.shared_c = slam.shared_c
17 | self.gt_c2w_list = slam.gt_c2w_list
18 | self.shared_decoders = slam.shared_decoders
19 | self.estimate_c2w_list = slam.estimate_c2w_list
20 | self.tsdf_volume = slam.tsdf_volume_shared
21 |
22 | def log(self, idx, keyframe_dict, keyframe_list, selected_keyframes=None):
23 | path = os.path.join(self.ckptsdir, '{:05d}.tar'.format(idx))
24 | torch.save({
25 | 'c': self.shared_c,
26 | 'decoder_state_dict': self.shared_decoders.state_dict(),
27 | 'gt_c2w_list': self.gt_c2w_list,
28 | 'estimate_c2w_list': self.estimate_c2w_list,
29 | 'keyframe_list': keyframe_list,
30 |             'keyframe_dict': keyframe_dict,  # keyframe_dict is saved into the ckpt; comment this line out to skip it
31 | 'selected_keyframes': selected_keyframes,
32 | 'idx': idx,
33 | 'tsdf_volume': self.tsdf_volume,
34 | }, path, _use_new_zipfile_serialization=False)
35 |
36 | if self.verbose:
37 | print('Saved checkpoints at', path)
38 |
--------------------------------------------------------------------------------
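
A short sketch of reading one of these checkpoints back (the path is a placeholder; the keys are exactly the ones written by `log` above):

    import torch

    ckpt_path = 'output/Replica/room0/ckpts/00500.tar'  # placeholder path
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))

    decoder_state = ckpt['decoder_state_dict']  # weights of the shared decoders
    est_c2w_list = ckpt['estimate_c2w_list']    # estimated camera-to-world poses
    tsdf_volume = ckpt['tsdf_volume']           # shared TSDF volume
    print('frame idx:', ckpt['idx'], 'keyframes:', len(ckpt['keyframe_list']))
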
/src/utils/Mesher.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import skimage
4 | import torch
5 | import torch.nn.functional as F
6 | import trimesh
7 | from packaging import version
8 | from src.utils.datasets import get_dataset
9 | from src.common import normalize_3d_coordinate
10 | import matplotlib.pyplot as plt
11 |
12 | class Mesher(object):
13 |
14 | def __init__(self, cfg, args, slam, points_batch_size=500000, ray_batch_size=100000):
15 | """
16 | Mesher class, given a scene representation, the mesher extracts the mesh from it.
17 |
18 | Args:
19 | cfg (dict): parsed config dict.
20 | args (class 'argparse.Namespace'): argparse arguments.
21 | slam (class DF_Prior): DF_Prior main class.
22 | points_batch_size (int): maximum points size for query in one batch.
23 |                                      Used to alleviate GPU memory usage. Defaults to 500000.
24 | ray_batch_size (int): maximum ray size for query in one batch.
25 |                                   Used to alleviate GPU memory usage. Defaults to 100000.
26 | """
27 | self.points_batch_size = points_batch_size
28 | self.ray_batch_size = ray_batch_size
29 | self.renderer = slam.renderer
30 | self.scale = cfg['scale']
31 | self.occupancy = cfg['occupancy']
32 |
33 | self.resolution = cfg['meshing']['resolution']
34 | self.level_set = cfg['meshing']['level_set']
35 | self.clean_mesh_bound_scale = cfg['meshing']['clean_mesh_bound_scale']
36 | self.remove_small_geometry_threshold = cfg['meshing']['remove_small_geometry_threshold']
37 | self.color_mesh_extraction_method = cfg['meshing']['color_mesh_extraction_method']
38 | self.get_largest_components = cfg['meshing']['get_largest_components']
39 | self.depth_test = cfg['meshing']['depth_test']
40 |
41 | self.bound = slam.bound
42 | self.verbose = slam.verbose
43 |
44 |
45 | self.marching_cubes_bound = torch.from_numpy(
46 | np.array(cfg['mapping']['marching_cubes_bound']) * self.scale)
47 |
48 | self.frame_reader = get_dataset(cfg, args, self.scale, device='cpu')
49 | self.n_img = len(self.frame_reader)
50 |
51 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = slam.H, slam.W, slam.fx, slam.fy, slam.cx, slam.cy
52 |
53 | self.sample_mode = 'bilinear'
54 | self.tsdf_bnds = slam.tsdf_bnds
55 |
56 |
57 |
58 | def point_masks(self, input_points, keyframe_dict, estimate_c2w_list,
59 | idx, device, get_mask_use_all_frames=False):
60 | """
61 |         Split the input points into seen, unseen, and forecast,
62 | according to the estimated camera pose and depth image.
63 |
64 | Args:
65 | input_points (tensor): input points.
66 | keyframe_dict (list): list of keyframe info dictionary.
67 | estimate_c2w_list (tensor): estimated camera pose.
68 | idx (int): current frame index.
69 | device (str): device name to compute on.
70 |
71 | Returns:
72 | seen_mask (tensor): the mask for seen area.
73 | forecast_mask (tensor): the mask for forecast area.
74 | unseen_mask (tensor): the mask for unseen area.
75 | """
76 | H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy
77 | if not isinstance(input_points, torch.Tensor):
78 | input_points = torch.from_numpy(input_points)
79 | input_points = input_points.clone().detach()
80 | seen_mask_list = []
81 | forecast_mask_list = []
82 | unseen_mask_list = []
83 | for i, pnts in enumerate(
84 | torch.split(input_points, self.points_batch_size, dim=0)):
85 | points = pnts.to(device).float()
86 |             # divide the points into three parts: seen, forecast, and unseen
87 | # seen: union of all the points in the viewing frustum of keyframes
88 | # forecast: union of all the points in the extended edge of the viewing frustum of keyframes
89 | # unseen: all the other points
90 |
91 | seen_mask = torch.zeros((points.shape[0])).bool().to(device)
92 | forecast_mask = torch.zeros((points.shape[0])).bool().to(device)
93 | if get_mask_use_all_frames:
94 | for i in range(0, idx + 1, 1):
95 | c2w = estimate_c2w_list[i].cpu().numpy()
96 | w2c = np.linalg.inv(c2w)
97 | w2c = torch.from_numpy(w2c).to(device).float()
98 | ones = torch.ones_like(
99 | points[:, 0]).reshape(-1, 1).to(device)
100 | homo_points = torch.cat([points, ones], dim=1).reshape(
101 | -1, 4, 1).to(device).float() # (N, 4)
102 | # (N, 4, 1)=(4,4)*(N, 4, 1)
103 | cam_cord_homo = w2c @ homo_points
104 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1)
105 |
106 | K = torch.from_numpy(
107 | np.array([[fx, .0, cx], [.0, fy, cy],
108 | [.0, .0, 1.0]]).reshape(3, 3)).to(device)
109 | cam_cord[:, 0] *= -1
110 | uv = K.float() @ cam_cord.float()
111 | z = uv[:, -1:] + 1e-8
112 | uv = uv[:, :2] / z
113 | uv = uv.float()
114 | edge = 0
115 | cur_mask_seen = (uv[:, 0] < W - edge) & (
116 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge)
117 | cur_mask_seen = cur_mask_seen & (z[:, :, 0] < 0)
118 |
119 | edge = -1000
120 | cur_mask_forecast = (uv[:, 0] < W - edge) & (
121 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge)
122 | cur_mask_forecast = cur_mask_forecast & (z[:, :, 0] < 0)
123 |
124 | # forecast
125 | cur_mask_forecast = cur_mask_forecast.reshape(-1)
126 | # seen
127 | cur_mask_seen = cur_mask_seen.reshape(-1)
128 |
129 | seen_mask |= cur_mask_seen
130 | forecast_mask |= cur_mask_forecast
131 | else:
132 | for keyframe in keyframe_dict:
133 | c2w = keyframe['est_c2w'].cpu().numpy()
134 | w2c = np.linalg.inv(c2w)
135 | w2c = torch.from_numpy(w2c).to(device).float()
136 | ones = torch.ones_like(
137 | points[:, 0]).reshape(-1, 1).to(device)
138 | homo_points = torch.cat([points, ones], dim=1).reshape(
139 | -1, 4, 1).to(device).float()
140 | cam_cord_homo = w2c @ homo_points
141 | cam_cord = cam_cord_homo[:, :3]
142 |
143 | K = torch.from_numpy(
144 | np.array([[fx, .0, cx], [.0, fy, cy],
145 | [.0, .0, 1.0]]).reshape(3, 3)).to(device)
146 | cam_cord[:, 0] *= -1
147 | uv = K.float() @ cam_cord.float()
148 | z = uv[:, -1:] + 1e-8
149 | uv = uv[:, :2] / z
150 | uv = uv.float()
151 | edge = 0
152 | cur_mask_seen = (uv[:, 0] < W - edge) & (
153 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge)
154 | cur_mask_seen = cur_mask_seen & (z[:, :, 0] < 0)
155 |
156 | edge = -1000
157 | cur_mask_forecast = (uv[:, 0] < W - edge) & (
158 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge)
159 | cur_mask_forecast = cur_mask_forecast & (z[:, :, 0] < 0)
160 |
161 | if self.depth_test:
162 | gt_depth = keyframe['depth'].to(
163 | device).reshape(1, 1, H, W)
164 | vgrid = uv.reshape(1, 1, -1, 2)
165 | # normalized to [-1, 1]
166 | vgrid[..., 0] = (vgrid[..., 0] / (W-1) * 2.0 - 1.0)
167 | vgrid[..., 1] = (vgrid[..., 1] / (H-1) * 2.0 - 1.0)
168 | depth_sample = F.grid_sample(
169 | gt_depth, vgrid, padding_mode='zeros', align_corners=True)
170 | depth_sample = depth_sample.reshape(-1)
171 | max_depth = torch.max(depth_sample)
172 | # forecast
173 | cur_mask_forecast = cur_mask_forecast.reshape(-1)
174 | proj_depth_forecast = -cam_cord[cur_mask_forecast,
175 | 2].reshape(-1)
176 | cur_mask_forecast[cur_mask_forecast.clone()] &= proj_depth_forecast < max_depth
177 | # seen
178 | cur_mask_seen = cur_mask_seen.reshape(-1)
179 | proj_depth_seen = - cam_cord[cur_mask_seen, 2].reshape(-1)
180 | cur_mask_seen[cur_mask_seen.clone()] &= \
181 | (proj_depth_seen < depth_sample[cur_mask_seen]+2.4) \
182 | & (depth_sample[cur_mask_seen]-2.4 < proj_depth_seen)
183 | else:
184 | max_depth = torch.max(keyframe['depth'])*1.1
185 |
186 | # forecast
187 | cur_mask_forecast = cur_mask_forecast.reshape(-1)
188 | proj_depth_forecast = -cam_cord[cur_mask_forecast,
189 | 2].reshape(-1)
190 | cur_mask_forecast[
191 | cur_mask_forecast.clone()] &= proj_depth_forecast < max_depth
192 |
193 | # seen
194 | cur_mask_seen = cur_mask_seen.reshape(-1)
195 | proj_depth_seen = - \
196 | cam_cord[cur_mask_seen, 2].reshape(-1)
197 | cur_mask_seen[cur_mask_seen.clone(
198 | )] &= proj_depth_seen < max_depth
199 |
200 | seen_mask |= cur_mask_seen
201 | forecast_mask |= cur_mask_forecast
202 |
203 | forecast_mask &= ~seen_mask
204 | unseen_mask = ~(seen_mask | forecast_mask)
205 |
206 | seen_mask = seen_mask.cpu().numpy()
207 | forecast_mask = forecast_mask.cpu().numpy()
208 | unseen_mask = unseen_mask.cpu().numpy()
209 |
210 | seen_mask_list.append(seen_mask)
211 | forecast_mask_list.append(forecast_mask)
212 | unseen_mask_list.append(unseen_mask)
213 |
214 | seen_mask = np.concatenate(seen_mask_list, axis=0)
215 | forecast_mask = np.concatenate(forecast_mask_list, axis=0)
216 | unseen_mask = np.concatenate(unseen_mask_list, axis=0)
217 | return seen_mask, forecast_mask, unseen_mask
218 |
219 | def get_bound_from_frames(self, keyframe_dict, scale=1):
220 | """
221 | Get the scene bound (convex hull),
222 | using sparse estimated camera poses and corresponding depth images.
223 |
224 | Args:
225 | keyframe_dict (list): list of keyframe info dictionary.
226 | scale (float): scene scale.
227 |
228 | Returns:
229 | return_mesh (trimesh.Trimesh): the convex hull.
230 | """
231 |
232 | H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy
233 |
234 | if version.parse(o3d.__version__) >= version.parse('0.13.0'):
235 | # for new version as provided in environment.yaml
236 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
237 | voxel_length=4.0 * scale / 512.0,
238 | sdf_trunc=0.04 * scale,
239 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8)
240 | else:
241 | # for lower version
242 | volume = o3d.integration.ScalableTSDFVolume(
243 | voxel_length=4.0 * scale / 512.0,
244 | sdf_trunc=0.04 * scale,
245 | color_type=o3d.integration.TSDFVolumeColorType.RGB8)
246 | cam_points = []
247 | for keyframe in keyframe_dict:
248 | c2w = keyframe['est_c2w'].cpu().numpy()
249 | # convert to open3d camera pose
250 | c2w[:3, 1] *= -1.0
251 | c2w[:3, 2] *= -1.0
252 | w2c = np.linalg.inv(c2w)
253 | cam_points.append(c2w[:3, 3])
254 | depth = keyframe['depth'].cpu().numpy()
255 | color = keyframe['color'].cpu().numpy()
256 |
257 | depth = o3d.geometry.Image(depth.astype(np.float32))
258 | color = o3d.geometry.Image(np.array(
259 | (color * 255).astype(np.uint8)))
260 |
261 | intrinsic = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy)
262 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
263 | color,
264 | depth,
265 | depth_scale=1,
266 | depth_trunc=1000,
267 | convert_rgb_to_intensity=False)
268 | volume.integrate(rgbd, intrinsic, w2c)
269 |
270 | cam_points = np.stack(cam_points, axis=0)
271 | mesh = volume.extract_triangle_mesh()
272 | mesh_points = np.array(mesh.vertices)
273 | points = np.concatenate([cam_points, mesh_points], axis=0)
274 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
275 | mesh, _ = o3d_pc.compute_convex_hull()
276 | mesh.compute_vertex_normals()
277 | if version.parse(o3d.__version__) >= version.parse('0.13.0'):
278 | mesh = mesh.scale(self.clean_mesh_bound_scale, mesh.get_center())
279 | else:
280 | mesh = mesh.scale(self.clean_mesh_bound_scale, center=True)
281 | points = np.array(mesh.vertices)
282 | faces = np.array(mesh.triangles)
283 | return_mesh = trimesh.Trimesh(vertices=points, faces=faces)
284 | return return_mesh
285 |
286 | def eval_points(self, p, decoders, tsdf_volume, tsdf_bnds, c=None, stage='color', device='cuda:0'):
287 | """
288 | Evaluates the occupancy and/or color value for the points.
289 |
290 | Args:
291 | p (tensor, N*3): point coordinates.
292 | decoders (nn.module decoders): decoders.
293 | tsdf_volume (tensor): tsdf volume.
294 | tsdf_bnds (tensor): tsdf volume bounds.
295 | c (dicts, optional): feature grids. Defaults to None.
296 | stage (str, optional): query stage, corresponds to different levels. Defaults to 'color'.
297 | device (str, optional): device name to compute on. Defaults to 'cuda:0'.
298 |
299 | Returns:
300 | ret (tensor): occupancy (and color) value of input points.
301 | """
302 |
303 | p_split = torch.split(p, self.points_batch_size)
304 | bound = self.bound
305 | rets = []
306 |
307 | for pi in p_split:
308 | # mask for points out of bound
309 | mask_x = (pi[:, 0] < bound[0][1]) & (pi[:, 0] > bound[0][0])
310 | mask_y = (pi[:, 1] < bound[1][1]) & (pi[:, 1] > bound[1][0])
311 | mask_z = (pi[:, 2] < bound[2][1]) & (pi[:, 2] > bound[2][0])
312 | mask = mask_x & mask_y & mask_z
313 |
314 | pi = pi.unsqueeze(0)
315 | ret, _ = decoders(pi, c_grid=c, tsdf_volume=tsdf_volume, tsdf_bnds=tsdf_bnds, stage=stage)
316 |
317 | ret = ret.squeeze(0)
318 | if len(ret.shape) == 1 and ret.shape[0] == 4:
319 | ret = ret.unsqueeze(0)
320 |
321 | ret[~mask, 3] = 100
322 | rets.append(ret)
323 |
324 | ret = torch.cat(rets, dim=0)
325 |
326 | return ret
327 |
328 | def sample_grid_tsdf(self, p, tsdf_volume, device='cuda:0'):
329 |
330 | p_nor = normalize_3d_coordinate(p.clone(), self.tsdf_bnds)
331 | p_nor = p_nor.unsqueeze(0)
332 | vgrid = p_nor[:, :, None, None].float()
333 |         # actually trilinear interpolation when mode='bilinear' is applied to a 5D input
334 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True,
335 | mode='bilinear').squeeze(-1).squeeze(-1)
336 | return tsdf_value
337 |
338 |
339 | def eval_points_tsdf(self, p, tsdf_volume, device='cuda:0'):
340 | """
341 | Evaluates the occupancy and/or color value for the points.
342 |
343 | Args:
344 | p (tensor, N*3): Point coordinates.
345 | tsdf_volume (tensor): tsdf volume.
346 |
347 | Returns:
348 | ret (tensor): tsdf value of input points.
349 | """
350 |
351 | p_split = torch.split(p, self.points_batch_size)
352 | tsdf_vals = []
353 | for pi in p_split:
354 | pi = pi.unsqueeze(0)
355 | tsdf_volume_tensor = tsdf_volume
356 |
357 | tsdf_val = self.sample_grid_tsdf(pi, tsdf_volume_tensor, device)
358 | tsdf_val = tsdf_val.squeeze(0)
359 | tsdf_vals.append(tsdf_val)
360 |
361 | tsdf_values = torch.cat(tsdf_vals, dim=1)
362 | return tsdf_values
363 |
364 |
365 | def get_grid_uniform(self, resolution):
366 | """
367 | Get query point coordinates for marching cubes.
368 |
369 | Args:
370 | resolution (int): marching cubes resolution.
371 |
372 | Returns:
373 | (dict): points coordinates and sampled coordinates for each axis.
374 | """
375 | bound = self.marching_cubes_bound
376 |
377 | padding = 0.05
378 | x = np.linspace(bound[0][0] - padding, bound[0][1] + padding,
379 | resolution)
380 | y = np.linspace(bound[1][0] - padding, bound[1][1] + padding,
381 | resolution)
382 | z = np.linspace(bound[2][0] - padding, bound[2][1] + padding,
383 | resolution)
384 |
385 | xx, yy, zz = np.meshgrid(x, y, z)
386 | grid_points = np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T
387 | grid_points = torch.tensor(np.vstack(
388 | [xx.ravel(), yy.ravel(), zz.ravel()]).T,
389 | dtype=torch.float)
390 |
391 |
392 |
393 | return {"grid_points": grid_points, "xyz": [x, y, z]}
394 |
395 | def get_mesh(self,
396 | mesh_out_file,
397 | c,
398 | decoders,
399 | keyframe_dict,
400 | estimate_c2w_list,
401 | idx,
402 | tsdf_volume,
403 | device='cuda:0',
404 | color=True,
405 | clean_mesh=True,
406 | get_mask_use_all_frames=False):
407 | """
408 | Extract mesh from scene representation and save mesh to file.
409 |
410 | Args:
411 | mesh_out_file (str): output mesh filename.
412 | c (dicts): feature grids.
413 | decoders (nn.module): decoders.
414 | keyframe_dict (list): list of keyframe info.
415 | estimate_c2w_list (tensor): estimated camera pose.
416 | idx (int): current processed camera ID.
417 |             tsdf_volume (tensor): tsdf volume.
418 | device (str, optional): device name to compute on. Defaults to 'cuda:0'.
419 | color (bool, optional): whether to extract colored mesh. Defaults to True.
420 | clean_mesh (bool, optional): whether to clean the output mesh
421 | (remove outliers outside the convexhull and small geometry noise).
422 | Defaults to True.
423 | get_mask_use_all_frames (bool, optional):
424 | whether to use all frames or just keyframes when getting the seen/unseen mask. Defaults to False.
425 | """
426 | with torch.no_grad():
427 |
428 | grid = self.get_grid_uniform(self.resolution)
429 | points = grid['grid_points']
430 | points = points.to(device)
431 | eval_tsdf_volume = tsdf_volume
432 |
433 | mesh_bound = self.get_bound_from_frames(
434 | keyframe_dict, self.scale)
435 | z = []
436 | mask = []
437 | for i, pnts in enumerate(torch.split(points, self.points_batch_size, dim=0)):
438 | mask.append(mesh_bound.contains(pnts.cpu().numpy()))
439 | mask = np.concatenate(mask, axis=0)
440 | for i, pnts in enumerate(torch.split(points, self.points_batch_size, dim=0)):
441 | eval_tsdf = self.eval_points_tsdf(pnts, eval_tsdf_volume, device)
442 | eval_tsdf_mask = ((eval_tsdf > -1.0+1e-4) & (eval_tsdf < 1.0-1e-4)).cpu().numpy()
443 | ret = self.eval_points(pnts, decoders, tsdf_volume, self.tsdf_bnds, c, 'high', device)
444 | ret = ret.cpu().numpy()[:, -1]
445 |
446 | eval_tsdf_mask = eval_tsdf_mask.reshape(ret.shape)
447 | z.append(ret)
448 |
449 | z = np.concatenate(z, axis=0)
450 | z[~mask] = 100
451 | z = z.astype(np.float32)
452 |
453 | z_uni_m = z.reshape(
454 | grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
455 | grid['xyz'][2].shape[0]).transpose([1, 0, 2])
456 |
457 |             print('begin marching cubes...')
458 | combine_occ_tsdf = z_uni_m
459 |
460 | try:
461 | if version.parse(
462 | skimage.__version__) > version.parse('0.15.0'):
463 | # for new version as provided in environment.yaml
464 | verts, faces, normals, values = skimage.measure.marching_cubes(
465 | volume=combine_occ_tsdf,
466 | level=self.level_set,
467 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
468 | grid['xyz'][1][2] - grid['xyz'][1][1],
469 | grid['xyz'][2][2] - grid['xyz'][2][1]))
470 | else:
471 | # for lower version
472 | verts, faces, normals, values = skimage.measure.marching_cubes_lewiner(
473 | volume=combine_occ_tsdf,
474 | level=self.level_set,
475 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
476 | grid['xyz'][1][2] - grid['xyz'][1][1],
477 | grid['xyz'][2][2] - grid['xyz'][2][1]))
478 | except:
479 | print(
480 | 'marching_cubes error. Possibly no surface extracted from the level set.'
481 | )
482 | return
483 |
484 | # convert back to world coordinates
485 | vertices = verts + np.array(
486 | [grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
487 |
488 | if clean_mesh:
489 | points = vertices
490 | mesh = trimesh.Trimesh(vertices=vertices,
491 | faces=faces,
492 | process=False)
493 | seen_mask, _, unseen_mask = self.point_masks(
494 | points, keyframe_dict, estimate_c2w_list, idx, device=device,
495 | get_mask_use_all_frames=get_mask_use_all_frames)
496 | unseen_mask = ~seen_mask
497 | face_mask = unseen_mask[mesh.faces].all(axis=1)
498 | mesh.update_faces(~face_mask)
499 |
500 | # get connected components
501 | components = mesh.split(only_watertight=False)
502 | if self.get_largest_components:
503 |                     areas = np.array([c.area for c in components], dtype=float)
504 | mesh = components[areas.argmax()]
505 | else:
506 | new_components = []
507 | for comp in components:
508 | if comp.area > self.remove_small_geometry_threshold * self.scale * self.scale:
509 | new_components.append(comp)
510 | mesh = trimesh.util.concatenate(new_components)
511 | vertices = mesh.vertices
512 | faces = mesh.faces
513 |
514 | if color:
515 | if self.color_mesh_extraction_method == 'direct_point_query':
516 | # color is extracted by passing the coordinates of mesh vertices through the network
517 | points = torch.from_numpy(vertices)
518 | z = []
519 | for i, pnts in enumerate(
520 | torch.split(points, self.points_batch_size, dim=0)):
521 | ret = self.eval_points(
522 | pnts.to(device).float(), decoders, tsdf_volume, self.tsdf_bnds, c, 'color',
523 | device)
524 | z_color = ret.cpu()[..., :3]
525 | z.append(z_color)
526 | z = torch.cat(z, axis=0)
527 | vertex_colors = z.numpy()
528 |
529 | vertex_colors = np.clip(vertex_colors, 0, 1) * 255
530 | vertex_colors = vertex_colors.astype(np.uint8)
531 |
532 |
533 | else:
534 | vertex_colors = None
535 |
536 | vertices /= self.scale
537 | mesh = trimesh.Trimesh(vertices, faces, vertex_colors=vertex_colors)
538 | mesh.export(mesh_out_file)
539 | if self.verbose:
540 | print('Saved mesh at', mesh_out_file)
541 |
542 | return z_uni_m
543 |
--------------------------------------------------------------------------------
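The version branch above exists because newer scikit-image releases expose the Lewiner marching cubes implementation as skimage.measure.marching_cubes, while older releases only provide marching_cubes_lewiner. A minimal, self-contained sketch of the same extraction pattern, run on a toy sphere SDF rather than the fused occupancy/TSDF grid built in get_mesh (the grid layout and level set here are illustrative assumptions):

import numpy as np
import skimage.measure
from packaging import version

# Toy signed-distance volume: the zero level set is a sphere of radius 0.5.
res = 64
xs = np.linspace(-1.0, 1.0, res, dtype=np.float32)
coords = np.stack(np.meshgrid(xs, xs, xs, indexing='ij'), axis=-1)
volume = np.linalg.norm(coords, axis=-1) - 0.5
spacing = (xs[2] - xs[1],) * 3  # voxel size, mirroring the grid['xyz'] spacing above

if version.parse(skimage.__version__) > version.parse('0.15.0'):
    verts, faces, normals, values = skimage.measure.marching_cubes(
        volume=volume, level=0.0, spacing=spacing)
else:
    verts, faces, normals, values = skimage.measure.marching_cubes_lewiner(
        volume=volume, level=0.0, spacing=spacing)

# Shift vertices back to world coordinates (mirrors adding grid['xyz'][i][0] above).
verts_world = verts + np.array([xs[0], xs[0], xs[0]])
print(verts_world.shape, faces.shape)

The spacing argument converts voxel indices into metric units, so only the grid origin offset is needed afterwards to place the mesh in world coordinates.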
/src/utils/Renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.common import get_rays, raw2outputs_nerf_color, sample_pdf, normalize_3d_coordinate
3 | import torch.nn.functional as F
4 |
5 |
6 | class Renderer(object):
7 | def __init__(self, cfg, args, slam, points_batch_size=500000, ray_batch_size=100000):
8 | self.ray_batch_size = ray_batch_size
9 | self.points_batch_size = points_batch_size
10 |
11 | self.lindisp = cfg['rendering']['lindisp']
12 | self.perturb = cfg['rendering']['perturb']
13 | self.N_samples = cfg['rendering']['N_samples']
14 | self.N_surface = cfg['rendering']['N_surface']
15 | self.N_importance = cfg['rendering']['N_importance']
16 |
17 | self.scale = cfg['scale']
18 | self.occupancy = cfg['occupancy']
19 | self.bound = slam.bound
20 | self.sample_mode = 'bilinear'
21 | self.tsdf_bnds = slam.vol_bnds
22 |
23 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = slam.H, slam.W, slam.fx, slam.fy, slam.cx, slam.cy
24 |
25 | self.resolution = cfg['meshing']['resolution']
26 |
27 | def eval_points(self, p, decoders, tsdf_volume, tsdf_bnds, c=None, stage='color', device='cuda:0'):
28 | """
29 | Evaluates the occupancy and/or color value for the points.
30 |
31 | Args:
32 | p (tensor, N*3): Point coordinates.
33 | decoders (nn.module decoders): Decoders.
34 | tsdf_volume (tensor): tsdf volume.
35 | tsdf_bnds (tensor): tsdf volume bounds.
36 | c (dicts, optional): Feature grids. Defaults to None.
37 | stage (str, optional): Query stage, corresponds to different levels. Defaults to 'color'.
38 | device (str, optional): CUDA device. Defaults to 'cuda:0'.
39 |
40 | Returns:
41 | ret (tensor): occupancy (and color) value of input points.
42 | """
43 |
44 | p_split = torch.split(p, self.points_batch_size)
45 | bound = self.bound
46 | rets = []
47 | weights = []
48 |
49 | for pi in p_split:
50 | # mask for points out of bound
51 | mask_x = (pi[:, 0] < bound[0][1]) & (pi[:, 0] > bound[0][0])
52 | mask_y = (pi[:, 1] < bound[1][1]) & (pi[:, 1] > bound[1][0])
53 | mask_z = (pi[:, 2] < bound[2][1]) & (pi[:, 2] > bound[2][0])
54 | mask = mask_x & mask_y & mask_z
55 |
56 | pi = pi.unsqueeze(0)
57 | ret, w = decoders(pi, c_grid=c, tsdf_volume=tsdf_volume, tsdf_bnds=tsdf_bnds, stage=stage)
58 | ret = ret.squeeze(0)
59 |
60 |
61 | if len(ret.shape) == 1 and ret.shape[0] == 4:
62 | ret = ret.unsqueeze(0)
63 |
64 | ret[~mask, 3] = 100
65 | rets.append(ret)
66 | weights.append(w)
67 |
68 | ret = torch.cat(rets, dim=0)
69 | weight = torch.cat(weights, dim=0)
70 |
71 | return ret, weight
72 |
73 | def sample_grid_tsdf(self, p, tsdf_volume, device='cuda:0'):
74 |
75 | p_nor = normalize_3d_coordinate(p.clone(), self.tsdf_bnds)
76 | p_nor = p_nor.unsqueeze(0)
77 | vgrid = p_nor[:, :, None, None].float()
78 |         # actually performs trilinear interpolation when mode='bilinear' is used on a 5-D input
79 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True,
80 | mode='bilinear').squeeze(-1).squeeze(-1)
81 | return tsdf_value
82 |
83 |
84 | def eval_points_tsdf(self, p, tsdf_volume, device='cuda:0'):
85 | """
86 |         Evaluates the TSDF value for the points.
87 |
88 |         Args:
89 |             p (tensor, N*3): Point coordinates.
90 |             tsdf_volume (tensor): TSDF volume to sample from.
91 |
92 |         Returns:
93 |             tsdf_values (tensor): TSDF values of the input points.
94 |         """
95 |
96 | p_split = torch.split(p, self.points_batch_size)
97 | tsdf_vals = []
98 | for pi in p_split:
99 | pi = pi.unsqueeze(0)
100 | tsdf_volume_tensor = tsdf_volume
101 |
102 | tsdf_val = self.sample_grid_tsdf(pi, tsdf_volume_tensor, device)
103 | tsdf_val = tsdf_val.squeeze(0)
104 | tsdf_vals.append(tsdf_val)
105 |
106 | tsdf_values = torch.cat(tsdf_vals, dim=1)
107 | return tsdf_values
108 |
109 |
110 | def render_batch_ray(self, c, decoders, rays_d, rays_o, device, tsdf_volume, tsdf_bnds, stage, gt_depth=None):
111 | """
112 | Render color, depth and uncertainty of a batch of rays.
113 |
114 | Args:
115 | c (dict): feature grids.
116 | decoders (nn.module): decoders.
117 | rays_d (tensor, N*3): rays direction.
118 | rays_o (tensor, N*3): rays origin.
119 | device (str): device name to compute on.
120 | tsdf_volume (tensor): tsdf volume.
121 | tsdf_bnds (tensor): tsdf volume bounds.
122 | stage (str): query stage.
123 | gt_depth (tensor, optional): sensor depth image. Defaults to None.
124 |
125 | Returns:
126 | depth (tensor): rendered depth.
127 | uncertainty (tensor): rendered uncertainty.
128 | color (tensor): rendered color.
129 | weight (tensor): attention weight.
130 | """
131 | eval_tsdf_volume = tsdf_volume
132 |
133 |
134 | N_samples = self.N_samples
135 | N_surface = self.N_surface
136 | N_importance = self.N_importance
137 |
138 | N_rays = rays_o.shape[0]
139 |
140 | if gt_depth is None:
141 | N_surface = 0
142 | near = 0.01
143 | else:
144 | gt_depth = gt_depth.reshape(-1, 1)
145 | gt_depth_samples = gt_depth.repeat(1, N_samples)
146 | near = gt_depth_samples*0.01
147 |
148 | with torch.no_grad():
149 | det_rays_o = rays_o.clone().detach().unsqueeze(-1) # (N, 3, 1)
150 | det_rays_d = rays_d.clone().detach().unsqueeze(-1) # (N, 3, 1)
151 | t = (self.bound.unsqueeze(0).to(device) -
152 | det_rays_o)/det_rays_d # (N, 3, 2)
153 | far_bb, _ = torch.min(torch.max(t, dim=2)[0], dim=1)
154 | far_bb = far_bb.unsqueeze(-1)
155 | far_bb += 0.01
156 |
157 | if gt_depth is not None:
158 | # in case the bound is too large
159 | far = torch.clamp(far_bb, 0, torch.max(gt_depth*1.2))
160 |
161 | else:
162 | far = far_bb
163 | if N_surface > 0:
164 | if False:
165 | # this naive implementation downgrades performance
166 | gt_depth_surface = gt_depth.repeat(1, N_surface)
167 | t_vals_surface = torch.linspace(
168 | 0., 1., steps=N_surface).to(device)
169 | z_vals_surface = 0.95*gt_depth_surface * \
170 | (1.-t_vals_surface) + 1.05 * \
171 | gt_depth_surface * (t_vals_surface)
172 | else:
173 |                 # Since we want to colorize even regions without depth sensor readings
174 |                 # (i.e., on interpolated geometry), we sample all pixels for the color loss
175 |                 # instead of applying a depth mask. For pixels with a non-zero depth value
176 |                 # we therefore sample near the observed surface, since it is wasteful to place
177 |                 # 16 samples right at (or even behind) the camera; for pixels with zero depth
178 |                 # we sample uniformly from the camera out to the maximum depth.
179 | gt_none_zero_mask = gt_depth > 0
180 | gt_none_zero = gt_depth[gt_none_zero_mask]
181 | gt_none_zero = gt_none_zero.unsqueeze(-1)
182 | gt_depth_surface = gt_none_zero.repeat(1, N_surface)
183 | t_vals_surface = torch.linspace(
184 | 0., 1., steps=N_surface).double().to(device)
185 |                 # empirical range: 0.95*depth to 1.05*depth around the observed surface
186 | z_vals_surface_depth_none_zero = 0.95*gt_depth_surface * \
187 | (1.-t_vals_surface) + 1.05 * \
188 | gt_depth_surface * (t_vals_surface)
189 | z_vals_surface = torch.zeros(
190 | gt_depth.shape[0], N_surface).to(device).double()
191 | gt_none_zero_mask = gt_none_zero_mask.squeeze(-1)
192 | z_vals_surface[gt_none_zero_mask,
193 | :] = z_vals_surface_depth_none_zero
194 | near_surface = 0.001
195 | far_surface = torch.max(gt_depth)
196 | z_vals_surface_depth_zero = near_surface * \
197 | (1.-t_vals_surface) + far_surface * (t_vals_surface)
198 |                 z_vals_surface_depth_zero = z_vals_surface_depth_zero.unsqueeze(
199 |                     0).repeat((~gt_none_zero_mask).sum(), 1)
200 | z_vals_surface[~gt_none_zero_mask,
201 | :] = z_vals_surface_depth_zero
202 |
203 | t_vals = torch.linspace(0., 1., steps=N_samples, device=device)
204 |
205 | if not self.lindisp:
206 | z_vals = near * (1.-t_vals) + far * (t_vals)
207 | else:
208 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
209 |
210 | if self.perturb > 0.:
211 | # get intervals between samples
212 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
213 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
214 | lower = torch.cat([z_vals[..., :1], mids], -1)
215 | # stratified samples in those intervals
216 | t_rand = torch.rand(z_vals.shape).to(device)
217 | z_vals = lower + (upper - lower) * t_rand
218 |
219 | if N_surface > 0:
220 | z_vals, _ = torch.sort(
221 | torch.cat([z_vals, z_vals_surface.double()], -1), -1)
222 |
223 | pts = rays_o[..., None, :] + rays_d[..., None, :] * \
224 | z_vals[..., :, None] # [N_rays, N_samples+N_surface, 3]
225 | pointsf = pts.reshape(-1, 3)
226 |
227 | raw, weight = self.eval_points(pointsf, decoders, tsdf_volume, tsdf_bnds, c, stage, device)
228 | raw = raw.reshape(N_rays, N_samples+N_surface, -1)
229 | weight = weight.reshape(N_rays, N_samples+N_surface, -1)
230 |
231 |
232 | depth, uncertainty, color, weights = raw2outputs_nerf_color(
233 | raw, z_vals, rays_d, occupancy=self.occupancy, device=device)
234 |
235 | if N_importance > 0:
236 | z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
237 | z_samples = sample_pdf(
238 | z_vals_mid, weights[..., 1:-1], N_importance, det=(self.perturb == 0.), device=device)
239 | z_samples = z_samples.detach()
240 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
241 |
242 | pts = rays_o[..., None, :] + \
243 | rays_d[..., None, :] * z_vals[..., :, None]
244 | pts = pts.reshape(-1, 3)
245 |
246 |             raw, weight = self.eval_points(pts, decoders, tsdf_volume, tsdf_bnds, c, stage, device)
247 |             raw = raw.reshape(N_rays, N_samples+N_surface+N_importance, -1)
248 |             weight = weight.reshape(N_rays, N_samples+N_surface+N_importance, -1)
249 |
250 | depth, uncertainty, color, weights = raw2outputs_nerf_color(
251 | raw, z_vals, rays_d, occupancy=self.occupancy, device=device)
252 | return depth, uncertainty, color, weight
253 |
254 |
255 | return depth, uncertainty, color, weight
256 |
257 |
258 | def render_img(self, c, decoders, c2w, device, tsdf_volume, tsdf_bnds, stage, gt_depth=None):
259 | """
260 | Renders out depth, uncertainty, and color images.
261 |
262 | Args:
263 | c (dict): feature grids.
264 | decoders (nn.module): decoders.
265 | c2w (tensor): camera to world matrix of current frame.
266 | device (str): device name to compute on.
267 | tsdf_volume (tensor): tsdf volume.
268 | tsdf_bnds (tensor): tsdf volume bounds.
269 | stage (str): query stage.
270 | gt_depth (tensor, optional): sensor depth image. Defaults to None.
271 |
272 | Returns:
273 | depth (tensor, H*W): rendered depth image.
274 | uncertainty (tensor, H*W): rendered uncertainty image.
275 | color (tensor, H*W*3): rendered color image.
276 | """
277 |
278 | with torch.no_grad():
279 | H = self.H
280 | W = self.W
281 | rays_o, rays_d = get_rays(
282 | H, W, self.fx, self.fy, self.cx, self.cy, c2w, device)
283 | rays_o = rays_o.reshape(-1, 3)
284 | rays_d = rays_d.reshape(-1, 3)
285 |
286 | depth_list = []
287 | uncertainty_list = []
288 | color_list = []
289 |
290 |
291 | ray_batch_size = self.ray_batch_size
292 |             gt_depth = gt_depth.reshape(-1) if gt_depth is not None else None
293 |
294 | for i in range(0, rays_d.shape[0], ray_batch_size):
295 | rays_d_batch = rays_d[i:i+ray_batch_size]
296 | rays_o_batch = rays_o[i:i+ray_batch_size]
297 |
298 |
299 |
300 | if gt_depth is None:
301 | ret = self.render_batch_ray(
302 | c, decoders, rays_d_batch, rays_o_batch, device, tsdf_volume, tsdf_bnds, stage, gt_depth=None)
303 | else:
304 | gt_depth_batch = gt_depth[i:i+ray_batch_size]
305 | ret = self.render_batch_ray(
306 | c, decoders, rays_d_batch, rays_o_batch, device, tsdf_volume, tsdf_bnds, stage, gt_depth=gt_depth_batch)
307 |
308 | depth, uncertainty, color, _= ret
309 |
310 |
311 | depth_list.append(depth.double())
312 | uncertainty_list.append(uncertainty.double())
313 | color_list.append(color)
314 |
315 |
316 |
317 |
318 |
319 | depth = torch.cat(depth_list, dim=0)
320 | uncertainty = torch.cat(uncertainty_list, dim=0)
321 | color = torch.cat(color_list, dim=0)
322 |
323 | depth = depth.reshape(H, W)
324 | uncertainty = uncertainty.reshape(H, W)
325 | color = color.reshape(H, W, 3)
326 |
327 | return depth, uncertainty, color
328 |
329 |
330 |
--------------------------------------------------------------------------------
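sample_grid_tsdf above evaluates the TSDF prior by passing normalized point coordinates to F.grid_sample, which performs trilinear interpolation for 5-D inputs despite the mode='bilinear' name. A minimal sketch of that lookup on a toy volume; the (1, 1, D, H, W) layout and the [-1, 1] normalized coordinates are assumptions inferred from how the tensors are used above:

import torch
import torch.nn.functional as F

# Toy TSDF grid and query points (stand-ins for the fused volume and the
# output of normalize_3d_coordinate used in the Renderer).
tsdf_volume = torch.rand(1, 1, 32, 32, 32) * 2 - 1   # values in [-1, 1]
p_nor = torch.rand(1, 1000, 3) * 2 - 1               # N=1000 points, normalized to [-1, 1]

vgrid = p_nor[:, :, None, None].float()              # (1, N, 1, 1, 3) sampling grid
tsdf_value = F.grid_sample(tsdf_volume, vgrid,
                           mode='bilinear',           # trilinear for a 5-D input
                           padding_mode='border',
                           align_corners=True).squeeze(-1).squeeze(-1)
print(tsdf_value.shape)                               # torch.Size([1, 1, 1000])

padding_mode='border' clamps queries that fall outside the volume to the nearest voxel instead of returning zeros; points outside the scene bound are still handled separately in eval_points, which sets their occupancy channel to 100.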
/src/utils/Visualizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from src.common import get_camera_from_tensor
6 | import open3d as o3d
7 |
8 | class Visualizer(object):
9 | """
10 |     Visualizes intermediate results: renders out depth, color and depth-uncertainty images.
11 |     It can be called every iteration, which is useful for debugging (to see how each tracking/mapping iteration performs).
12 |
13 | """
14 |
15 | def __init__(self, freq, inside_freq, vis_dir, renderer, verbose, device='cuda:0'):
16 | self.freq = freq
17 | self.device = device
18 | self.vis_dir = vis_dir
19 | self.verbose = verbose
20 | self.renderer = renderer
21 | self.inside_freq = inside_freq
22 | os.makedirs(f'{vis_dir}', exist_ok=True)
23 |
24 | def vis(self, idx, iter, gt_depth, gt_color, c2w_or_camera_tensor, c,
25 | decoders, tsdf_volume, tsdf_bnds):
26 | """
27 |         Renders depth and color images for the current frame and saves the visualization to file.
28 |
29 | Args:
30 | idx (int): current frame index.
31 | iter (int): the iteration number.
32 | gt_depth (tensor): ground truth depth image of the current frame.
33 | gt_color (tensor): ground truth color image of the current frame.
34 | c2w_or_camera_tensor (tensor): camera pose, represented in
35 | camera to world matrix or quaternion and translation tensor.
36 | c (dicts): feature grids.
37 | decoders (nn.module): decoders.
38 | tsdf_volume (tensor): tsdf volume.
39 | tsdf_bnds (tensor): tsdf volume bounds.
40 | """
41 | with torch.no_grad():
42 | if (idx % self.freq == 0) and (iter % self.inside_freq == 0):
43 | gt_depth_np = gt_depth.cpu().numpy()
44 | gt_color_np = gt_color.cpu().numpy()
45 | if len(c2w_or_camera_tensor.shape) == 1:
46 | bottom = torch.from_numpy(
47 | np.array([0, 0, 0, 1.]).reshape([1, 4])).type(
48 | torch.float32).to(self.device)
49 | c2w = get_camera_from_tensor(
50 | c2w_or_camera_tensor.clone().detach())
51 | c2w = torch.cat([c2w, bottom], dim=0)
52 | else:
53 | c2w = c2w_or_camera_tensor
54 |
55 | depth, _, color = self.renderer.render_img(
56 | c,
57 | decoders,
58 | c2w,
59 | self.device,
60 | tsdf_volume,
61 | tsdf_bnds,
62 | stage='color',
63 | gt_depth=gt_depth)
64 |
65 | # convert to open3d camera pose
66 | c2w = c2w.cpu().numpy()
67 | c2w[:3, 1] *= -1.0
68 | c2w[:3, 2] *= -1.0
69 |
70 |
71 | depth_np = depth.detach().cpu().numpy()
72 | color_np = color.detach().cpu().numpy()
73 | depth = depth_np.astype(np.float32)
74 | color = np.array((color_np * 255).astype(np.uint8))
75 |
76 | depth_residual = np.abs(gt_depth_np - depth_np)
77 | depth_residual[gt_depth_np == 0.0] = 0.0
78 | color_residual = np.abs(gt_color_np - color_np)
79 | color_residual[gt_depth_np == 0.0] = 0.0
80 |
81 |
82 | fig, axs = plt.subplots(2, 3)
83 | fig.tight_layout()
84 | max_depth = np.max(gt_depth_np)
85 | axs[0, 0].imshow(gt_depth_np, cmap="plasma",
86 | vmin=0, vmax=max_depth)
87 | axs[0, 0].set_title('Input Depth')
88 | axs[0, 0].set_xticks([])
89 | axs[0, 0].set_yticks([])
90 | axs[0, 1].imshow(depth_np, cmap="plasma",
91 | vmin=0, vmax=max_depth)
92 | axs[0, 1].set_title('Generated Depth')
93 | axs[0, 1].set_xticks([])
94 | axs[0, 1].set_yticks([])
95 | axs[0, 2].imshow(depth_residual, cmap="plasma",
96 | vmin=0, vmax=max_depth)
97 | axs[0, 2].set_title('Depth Residual')
98 | axs[0, 2].set_xticks([])
99 | axs[0, 2].set_yticks([])
100 | gt_color_np = np.clip(gt_color_np, 0, 1)
101 | color_np = np.clip(color_np, 0, 1)
102 | color_residual = np.clip(color_residual, 0, 1)
103 | axs[1, 0].imshow(gt_color_np, cmap="plasma")
104 | axs[1, 0].set_title('Input RGB')
105 | axs[1, 0].set_xticks([])
106 | axs[1, 0].set_yticks([])
107 | axs[1, 1].imshow(color_np, cmap="plasma")
108 | axs[1, 1].set_title('Generated RGB')
109 | axs[1, 1].set_xticks([])
110 | axs[1, 1].set_yticks([])
111 | axs[1, 2].imshow(color_residual, cmap="plasma")
112 | axs[1, 2].set_title('RGB Residual')
113 | axs[1, 2].set_xticks([])
114 | axs[1, 2].set_yticks([])
115 | plt.subplots_adjust(wspace=0, hspace=0)
116 | plt.savefig(
117 | f'{self.vis_dir}/{idx:05d}_{iter:04d}.jpg', bbox_inches='tight', pad_inches=0.2)
118 | plt.clf()
119 |
120 | if self.verbose:
121 | print(
122 | f'Saved rendering visualization of color/depth image at {self.vis_dir}/{idx:05d}_{iter:04d}.jpg')
123 |
124 |
--------------------------------------------------------------------------------
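The visualizer above masks out pixels with no depth reading before computing the residual images it plots. A minimal sketch of that masking and the plasma-colormapped comparison row, using random arrays in place of the renderer output (the sizes and output file name are made up for the example):

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

H, W = 120, 160
gt_depth = np.random.rand(H, W).astype(np.float32) * 5.0
gt_depth[:10] = 0.0                                   # pretend the sensor has no reading here
pred_depth = gt_depth + np.random.randn(H, W).astype(np.float32) * 0.05

depth_residual = np.abs(gt_depth - pred_depth)
depth_residual[gt_depth == 0.0] = 0.0                 # ignore pixels without a depth reading

fig, axs = plt.subplots(1, 3)
for ax, img, title in zip(axs,
                          [gt_depth, pred_depth, depth_residual],
                          ['Input Depth', 'Generated Depth', 'Depth Residual']):
    ax.imshow(img, cmap='plasma', vmin=0, vmax=gt_depth.max())
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig('toy_vis.jpg', bbox_inches='tight', pad_inches=0.2)
plt.clf()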
/src/utils/__pycache__/Logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Logger.cpython-37.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Logger.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Mesher.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Mesher.cpython-37.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Mesher.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Mesher.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Renderer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Renderer.cpython-37.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Renderer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Renderer.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Visualizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Visualizer.cpython-37.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/Visualizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Visualizer.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/datasets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/datasets.cpython-37.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/src/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from src.common import as_intrinsics_matrix
9 | from torch.utils.data import Dataset
10 |
11 |
12 | def readEXR_onlydepth(filename):
13 | """
14 | Read depth data from EXR image file.
15 |
16 | Args:
17 | filename (str): File path.
18 |
19 | Returns:
20 | Y (numpy.array): Depth buffer in float32 format.
21 | """
22 |     # Import these here since only the CoFusion dataset needs them;
23 |     # installing OpenEXR can be difficult, and all other datasets
24 |     # run without it.
25 | import Imath
26 | import OpenEXR as exr
27 |
28 | exrfile = exr.InputFile(filename)
29 | header = exrfile.header()
30 | dw = header['dataWindow']
31 | isize = (dw.max.y - dw.min.y + 1, dw.max.x - dw.min.x + 1)
32 |
33 | channelData = dict()
34 |
35 | for c in header['channels']:
36 | C = exrfile.channel(c, Imath.PixelType(Imath.PixelType.FLOAT))
37 |         C = np.frombuffer(C, dtype=np.float32)
38 | C = np.reshape(C, isize)
39 |
40 | channelData[c] = C
41 |
42 | Y = None if 'Y' not in header['channels'] else channelData['Y']
43 |
44 | return Y
45 |
46 |
47 | def get_dataset(cfg, args, scale, device='cuda:0'):
48 | return dataset_dict[cfg['dataset']](cfg, args, scale, device=device)
49 |
50 |
51 | class BaseDataset(Dataset):
52 | def __init__(self, cfg, args, scale, device='cuda:0'
53 | ):
54 | super(BaseDataset, self).__init__()
55 | self.name = cfg['dataset']
56 | self.device = device
57 | self.scale = scale
58 | self.png_depth_scale = cfg['cam']['png_depth_scale']
59 |
60 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = cfg['cam']['H'], cfg['cam'][
61 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy']
62 |
63 | self.distortion = np.array(
64 | cfg['cam']['distortion']) if 'distortion' in cfg['cam'] else None
65 | self.crop_size = cfg['cam']['crop_size'] if 'crop_size' in cfg['cam'] else None
66 |
67 | if args.input_folder is None:
68 | self.input_folder = cfg['data']['input_folder']
69 | else:
70 | self.input_folder = args.input_folder
71 |
72 | self.crop_edge = cfg['cam']['crop_edge']
73 |
74 | def __len__(self):
75 | return self.n_img
76 |
77 | def __getitem__(self, index):
78 | color_path = self.color_paths[index]
79 | depth_path = self.depth_paths[index]
80 | color_data = cv2.imread(color_path)
81 | if '.png' in depth_path:
82 | depth_data = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
83 | elif '.exr' in depth_path:
84 | depth_data = readEXR_onlydepth(depth_path)
85 | if self.distortion is not None:
86 | K = as_intrinsics_matrix([self.fx, self.fy, self.cx, self.cy])
87 | # undistortion is only applied on color image, not depth!
88 | color_data = cv2.undistort(color_data, K, self.distortion)
89 |
90 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
91 | color_data = color_data / 255.
92 | depth_data = depth_data.astype(np.float32) / self.png_depth_scale
93 | H, W = depth_data.shape
94 | color_data = cv2.resize(color_data, (W, H))
95 | color_data = torch.from_numpy(color_data)
96 | depth_data = torch.from_numpy(depth_data)*self.scale
97 | if self.crop_size is not None:
98 |             # follow the pre-processing step in lietorch; despite the name, crop_size actually resizes the images
99 | color_data = color_data.permute(2, 0, 1)
100 | color_data = F.interpolate(
101 | color_data[None], self.crop_size, mode='bilinear', align_corners=True)[0]
102 | depth_data = F.interpolate(
103 | depth_data[None, None], self.crop_size, mode='nearest')[0, 0]
104 | color_data = color_data.permute(1, 2, 0).contiguous()
105 |
106 | edge = self.crop_edge
107 | if edge > 0:
108 |             # crop the image edges; there are invalid values at the borders of the color image
109 | color_data = color_data[edge:-edge, edge:-edge]
110 | depth_data = depth_data[edge:-edge, edge:-edge]
111 |         pose = self.poses[index].clone()  # clone so the stored pose is not rescaled on repeated access
112 |         pose[:3, 3] *= self.scale
113 | return index, color_data.to(self.device), depth_data.to(self.device), pose.to(self.device)
114 |
115 |
116 | class Replica(BaseDataset):
117 | def __init__(self, cfg, args, scale, device='cuda:0'
118 | ):
119 | super(Replica, self).__init__(cfg, args, scale, device)
120 | self.color_paths = sorted(
121 | glob.glob(f'{self.input_folder}/results/frame*.jpg'))
122 | self.depth_paths = sorted(
123 | glob.glob(f'{self.input_folder}/results/depth*.png'))
124 | self.n_img = len(self.color_paths)
125 | self.load_poses(f'{self.input_folder}/traj.txt')
126 |
127 | def load_poses(self, path):
128 | self.poses = []
129 | with open(path, "r") as f:
130 | lines = f.readlines()
131 | for i in range(self.n_img):
132 | line = lines[i]
133 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
134 | c2w[:3, 1] *= -1
135 | c2w[:3, 2] *= -1
136 | c2w = torch.from_numpy(c2w).float()
137 | self.poses.append(c2w)
138 |
139 |
140 | class Azure(BaseDataset):
141 | def __init__(self, cfg, args, scale, device='cuda:0'
142 | ):
143 | super(Azure, self).__init__(cfg, args, scale, device)
144 | self.color_paths = sorted(
145 | glob.glob(os.path.join(self.input_folder, 'color', '*.jpg')))
146 | self.depth_paths = sorted(
147 | glob.glob(os.path.join(self.input_folder, 'depth', '*.png')))
148 | self.n_img = len(self.color_paths)
149 | self.load_poses(os.path.join(
150 | self.input_folder, 'scene', 'trajectory.log'))
151 |
152 | def load_poses(self, path):
153 | self.poses = []
154 | if os.path.exists(path):
155 | with open(path) as f:
156 | content = f.readlines()
157 |
158 | # Load .log file.
159 | for i in range(0, len(content), 5):
160 | # format %d (src) %d (tgt) %f (fitness)
161 | data = list(map(float, content[i].strip().split(' ')))
162 | ids = (int(data[0]), int(data[1]))
163 | fitness = data[2]
164 |
165 | # format %f x 16
166 | c2w = np.array(
167 | list(map(float, (''.join(
168 | content[i + 1:i + 5])).strip().split()))).reshape((4, 4))
169 |
170 | c2w[:3, 1] *= -1
171 | c2w[:3, 2] *= -1
172 | c2w = torch.from_numpy(c2w).float()
173 | self.poses.append(c2w)
174 | else:
175 | for i in range(self.n_img):
176 | c2w = np.eye(4)
177 | c2w = torch.from_numpy(c2w).float()
178 | self.poses.append(c2w)
179 |
180 |
181 | class ScanNet(BaseDataset):
182 | def __init__(self, cfg, args, scale, device='cuda:0'
183 | ):
184 | super(ScanNet, self).__init__(cfg, args, scale, device)
185 | self.input_folder = os.path.join(self.input_folder, 'frames')
186 | self.color_paths = sorted(glob.glob(os.path.join(
187 | self.input_folder, 'color', '*.jpg')), key=lambda x: int(os.path.basename(x)[:-4]))
188 | self.depth_paths = sorted(glob.glob(os.path.join(
189 | self.input_folder, 'depth', '*.png')), key=lambda x: int(os.path.basename(x)[:-4]))
190 | self.load_poses(os.path.join(self.input_folder, 'pose'))
191 | self.n_img = len(self.color_paths)
192 |
193 | def load_poses(self, path):
194 | self.poses = []
195 | pose_paths = sorted(glob.glob(os.path.join(path, '*.txt')),
196 | key=lambda x: int(os.path.basename(x)[:-4]))
197 | for pose_path in pose_paths:
198 | with open(pose_path, "r") as f:
199 | lines = f.readlines()
200 | ls = []
201 | for line in lines:
202 | l = list(map(float, line.split(' ')))
203 | ls.append(l)
204 | c2w = np.array(ls).reshape(4, 4)
205 | c2w[:3, 1] *= -1
206 | c2w[:3, 2] *= -1
207 | c2w = torch.from_numpy(c2w).float()
208 | self.poses.append(c2w)
209 |
210 |
211 | class CoFusion(BaseDataset):
212 | def __init__(self, cfg, args, scale, device='cuda:0'
213 | ):
214 | super(CoFusion, self).__init__(cfg, args, scale, device)
215 | self.input_folder = os.path.join(self.input_folder)
216 | self.color_paths = sorted(
217 | glob.glob(os.path.join(self.input_folder, 'colour', '*.png')))
218 | self.depth_paths = sorted(glob.glob(os.path.join(
219 | self.input_folder, 'depth_noise', '*.exr')))
220 | self.n_img = len(self.color_paths)
221 | self.load_poses(os.path.join(self.input_folder, 'trajectories'))
222 |
223 | def load_poses(self, path):
224 |         # We tried, but could not align the coordinate frame of CoFusion to ours,
225 |         # so we provide identity matrices as a proxy.
226 |         # This does not affect the ATE computation, since the camera trajectories can still be aligned.
227 | self.poses = []
228 | for i in range(self.n_img):
229 | c2w = np.eye(4)
230 | c2w = torch.from_numpy(c2w).float()
231 | self.poses.append(c2w)
232 |
233 |
234 | class TUM_RGBD(BaseDataset):
235 | def __init__(self, cfg, args, scale, device='cuda:0'
236 | ):
237 | super(TUM_RGBD, self).__init__(cfg, args, scale, device)
238 | self.color_paths, self.depth_paths, self.poses = self.loadtum(
239 | self.input_folder, frame_rate=32)
240 | self.n_img = len(self.color_paths)
241 |
242 | def parse_list(self, filepath, skiprows=0):
243 | """ read list data """
244 | data = np.loadtxt(filepath, delimiter=' ',
245 | dtype=np.unicode_, skiprows=skiprows)
246 | return data
247 |
248 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
249 | """ pair images, depths, and poses """
250 | associations = []
251 | for i, t in enumerate(tstamp_image):
252 | if tstamp_pose is None:
253 | j = np.argmin(np.abs(tstamp_depth - t))
254 | if (np.abs(tstamp_depth[j] - t) < max_dt):
255 | associations.append((i, j))
256 |
257 | else:
258 | j = np.argmin(np.abs(tstamp_depth - t))
259 | k = np.argmin(np.abs(tstamp_pose - t))
260 |
261 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \
262 | (np.abs(tstamp_pose[k] - t) < max_dt):
263 | associations.append((i, j, k))
264 |
265 | return associations
266 |
267 | def loadtum(self, datapath, frame_rate=-1):
268 | """ read video data in tum-rgbd format """
269 | if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')):
270 | pose_list = os.path.join(datapath, 'groundtruth.txt')
271 | elif os.path.isfile(os.path.join(datapath, 'pose.txt')):
272 | pose_list = os.path.join(datapath, 'pose.txt')
273 |
274 | image_list = os.path.join(datapath, 'rgb.txt')
275 | depth_list = os.path.join(datapath, 'depth.txt')
276 |
277 | image_data = self.parse_list(image_list)
278 | depth_data = self.parse_list(depth_list)
279 | pose_data = self.parse_list(pose_list, skiprows=1)
280 | pose_vecs = pose_data[:, 1:].astype(np.float64)
281 |
282 | tstamp_image = image_data[:, 0].astype(np.float64)
283 | tstamp_depth = depth_data[:, 0].astype(np.float64)
284 | tstamp_pose = pose_data[:, 0].astype(np.float64)
285 | associations = self.associate_frames(
286 | tstamp_image, tstamp_depth, tstamp_pose)
287 |
288 |         indices = [0]
289 |         for i in range(1, len(associations)):
290 |             t0 = tstamp_image[associations[indices[-1]][0]]
291 |             t1 = tstamp_image[associations[i][0]]
292 |             if t1 - t0 > 1.0 / frame_rate:
293 |                 indices += [i]
294 |
295 |         images, poses, depths, intrinsics = [], [], [], []
296 |         inv_pose = None
297 |         for ix in indices:
298 | (i, j, k) = associations[ix]
299 | images += [os.path.join(datapath, image_data[i, 1])]
300 | depths += [os.path.join(datapath, depth_data[j, 1])]
301 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k])
302 | if inv_pose is None:
303 | inv_pose = np.linalg.inv(c2w)
304 | c2w = np.eye(4)
305 | else:
306 | c2w = inv_pose@c2w
307 | c2w[:3, 1] *= -1
308 | c2w[:3, 2] *= -1
309 | c2w = torch.from_numpy(c2w).float()
310 | poses += [c2w]
311 |
312 | return images, depths, poses
313 |
314 | def pose_matrix_from_quaternion(self, pvec):
315 |         """ convert a (t, q) pose vector into a 4x4 pose matrix """
316 | from scipy.spatial.transform import Rotation
317 |
318 | pose = np.eye(4)
319 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix()
320 | pose[:3, 3] = pvec[:3]
321 | return pose
322 |
323 |
324 | dataset_dict = {
325 | "replica": Replica,
326 | "scannet": ScanNet,
327 | "cofusion": CoFusion,
328 | "azure": Azure,
329 | "tumrgbd": TUM_RGBD
330 | }
331 |
--------------------------------------------------------------------------------
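TUM_RGBD.associate_frames above matches each RGB timestamp to the nearest depth (and pose) timestamp and keeps the triple only if both gaps are below max_dt. A small self-contained example of the same rule on made-up timestamps:

import numpy as np

tstamp_image = np.array([0.00, 0.05, 0.10, 0.15])
tstamp_depth = np.array([0.01, 0.06, 0.11, 0.30])
tstamp_pose = np.array([0.00, 0.05, 0.10, 0.15])
max_dt = 0.08

associations = []
for i, t in enumerate(tstamp_image):
    j = np.argmin(np.abs(tstamp_depth - t))   # nearest depth frame
    k = np.argmin(np.abs(tstamp_pose - t))    # nearest pose sample
    if abs(tstamp_depth[j] - t) < max_dt and abs(tstamp_pose[k] - t) < max_dt:
        associations.append((i, j, k))

print(associations)  # [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 2, 3)]

loadtum then thins these associations according to frame_rate and rebases every pose onto the first camera, flipping the Y and Z axes to match the convention used by the other dataset loaders.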