├── LICENSE ├── README.md ├── configs ├── Replica │ ├── office0.yaml │ ├── office1.yaml │ ├── office2.yaml │ ├── office3.yaml │ ├── office4.yaml │ ├── replica.yaml │ ├── room0.yaml │ ├── room1.yaml │ └── room2.yaml ├── ScanNet │ ├── scannet.yaml │ ├── scene0000.yaml │ ├── scene0002.yaml │ ├── scene0005.yaml │ ├── scene0012.yaml │ ├── scene0050.yaml │ ├── scene0059.yaml │ ├── scene0084.yaml │ ├── scene0106.yaml │ ├── scene0169.yaml │ ├── scene0181.yaml │ ├── scene0207.yaml │ ├── scene0472.yaml │ ├── scene0580.yaml │ └── scene0616.yaml └── df_prior.yaml ├── environment.yaml ├── get_tsdf.py ├── pretrained └── low_high.pt ├── run.py ├── scripts ├── download_cull_replica_mesh.sh └── download_replica.sh └── src ├── DF_Prior.py ├── Mapper.py ├── Tracker.py ├── __init__.py ├── __pycache__ ├── DF_Prior.cpython-37.pyc ├── Mapper.cpython-37.pyc ├── Mapper.cpython-38.pyc ├── NICE_SLAM.cpython-37.pyc ├── NICE_SLAM.cpython-38.pyc ├── Tracker.cpython-37.pyc ├── Tracker.cpython-38.pyc ├── __init__.cpython-310.pyc ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── common.cpython-310.pyc ├── common.cpython-37.pyc ├── common.cpython-38.pyc ├── config.cpython-310.pyc ├── config.cpython-37.pyc ├── config.cpython-38.pyc ├── depth2pointcloud.cpython-37.pyc ├── fusion.cpython-310.pyc ├── fusion.cpython-37.pyc └── fusion.cpython-38.pyc ├── common.py ├── config.py ├── conv_onet ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── config.cpython-310.pyc │ ├── config.cpython-37.pyc │ └── config.cpython-38.pyc ├── config.py └── models │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── decoder.cpython-310.pyc │ ├── decoder.cpython-37.pyc │ └── decoder.cpython-38.pyc │ └── decoder.py ├── fusion.py ├── tools ├── cull_mesh.py ├── eval_ate.py ├── eval_recon.py └── evaluate_scannet.py └── utils ├── Logger.py ├── Mesher.py ├── Renderer.py ├── Visualizer.py ├── __pycache__ ├── Logger.cpython-37.pyc ├── Logger.cpython-38.pyc ├── Mesher.cpython-37.pyc ├── Mesher.cpython-38.pyc ├── Renderer.cpython-37.pyc ├── Renderer.cpython-38.pyc ├── Visualizer.cpython-37.pyc ├── Visualizer.cpython-38.pyc ├── datasets.cpython-37.pyc └── datasets.cpython-38.pyc └── datasets.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MachinePerceptionLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Learning Neural Implicit through Volume Rendering with Attentive Depth Fusion Priors

Pengchong Hu · Zhizhong Han

NeurIPS 2023

Paper | Project Page

Table of Contents
1. Installation
2. Dataset
3. Run
4. Evaluation
5. Acknowledgement
6. Citation
39 | 40 | 41 | ## Installation 42 | Please install all dependencies by following the instructions here. You can use [anaconda](https://www.anaconda.com/) to finish the installation easily. 43 | 44 | You can build a conda environment called `df-prior`. Note that for Linux users, you need to install `libopenexr-dev` first before building the environment. 45 | 46 | ```bash 47 | git clone https://github.com/MachinePerceptionLab/Attentive_DFPrior.git 48 | cd Attentive_DFPrior 49 | 50 | sudo apt-get install libopenexr-dev 51 | 52 | conda env create -f environment.yaml 53 | conda activate df-prior 54 | ``` 55 | 56 | ## Dataset 57 | 58 | ### Replica 59 | 60 | Please download the Replica dataset generated by the authors of iMAP into the `./Datasets/Replica` folder. 61 | 62 | ```bash 63 | bash scripts/download_replica.sh # Released by authors of NICE-SLAM 64 | ``` 65 | 66 | ### ScanNet 67 | 68 | Please follow the data downloading procedure on the [ScanNet](http://www.scan-net.org/) website, and extract color/depth frames from the `.sens` file using this [code](https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/reader.py). 69 | 70 |
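For example, a single downloaded sequence can be exported with the SensReader script so that it matches the layout expected by the config files. The paths below are illustrative, and the flag names follow the ScanNet SensReader documentation; double-check them against your copy of `reader.py`.

```bash
# Export color/depth frames, poses and intrinsics for one sequence (paths are examples)
python reader.py --filename Datasets/scannet/scans/scene0000_00/scene0000_00.sens \
    --output_path Datasets/scannet/scans/scene0000_00/frames \
    --export_depth_images --export_color_images --export_poses --export_intrinsics
```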
71 | [Directory structure of ScanNet (click to expand)] 72 | 73 | DATAROOT is `./Datasets` by default. If a sequence (`sceneXXXX_XX`) is stored in other places, please change the `input_folder` path in the config file or in the command line. 74 | 75 | ``` 76 | DATAROOT 77 | └── scannet 78 | └── scans 79 | └── scene0000_00 80 | └── frames 81 | ├── color 82 | │ ├── 0.jpg 83 | │ ├── 1.jpg 84 | │ ├── ... 85 | │ └── ... 86 | ├── depth 87 | │ ├── 0.png 88 | │ ├── 1.png 89 | │ ├── ... 90 | │ └── ... 91 | ├── intrinsic 92 | └── pose 93 | ├── 0.txt 94 | ├── 1.txt 95 | ├── ... 96 | └── ... 97 | 98 | ``` 99 |
100 | 101 | ## Run 102 | 103 | To run our code, you first need to generate the TSDF volume and corresponding bounds. We provide the generated TSDF volume and bounds for Replica and ScanNet: replica_tsdf_volume.tar, scannet_tsdf_volume.tar. 104 | 105 | You can also generate the TSDF volume and corresponding bounds with the following commands: 106 | 107 | ```bash 108 | CUDA_VISIBLE_DEVICES=0 python get_tsdf.py configs/Replica/room0.yaml --space 1 # For Replica 109 | CUDA_VISIBLE_DEVICES=0 python get_tsdf.py configs/ScanNet/scene0050.yaml --space 10 # For ScanNet 110 | ``` 111 | 112 | You can run DF-Prior with the following commands: 113 | 114 | ```bash 115 | CUDA_VISIBLE_DEVICES=0 python -W ignore run.py configs/Replica/room0.yaml # For Replica 116 | CUDA_VISIBLE_DEVICES=0 python -W ignore run.py configs/ScanNet/scene0050.yaml # For ScanNet 117 | ``` 118 | 119 | The mesh for evaluation is saved as `$OUTPUT_FOLDER/mesh/final_mesh_eval_rec.ply`, where the unseen regions are culled using all frames. 120 | 121 | 122 | 123 | 124 | ## Evaluation 125 | ### Average Trajectory Error 126 | To evaluate the average trajectory error, run the command below with the corresponding config file: 127 | ```bash 128 | python src/tools/eval_ate.py configs/Replica/room0.yaml 129 | ``` 130 | 131 | ### Reconstruction Error 132 | #### Replica 133 | To evaluate the reconstruction error on Replica, first download the ground truth Replica meshes where unseen regions have been culled. 134 | ```bash 135 | bash scripts/download_cull_replica_mesh.sh # Released by authors of NICE-SLAM 136 | ``` 137 | Then run the command below. The 2D metric requires rendering 1000 depth images, which takes some time (~9 minutes). Use `-2d` to enable the 2D metric and `-3d` to enable the 3D metric. 138 | ```bash 139 | # assign any output_folder and gt mesh you like, here is just an example 140 | OUTPUT_FOLDER=output/Replica/room0 141 | GT_MESH=cull_replica_mesh/room0.ply 142 | python src/tools/eval_recon.py --rec_mesh $OUTPUT_FOLDER/mesh/final_mesh_eval_rec.ply --gt_mesh $GT_MESH -2d -3d 143 | ``` 144 | 145 | We also provide code to cull the mesh given camera poses. Here we take culling the ground truth mesh of Replica room0 as an example. 146 | ```bash 147 | python src/tools/cull_mesh.py --input_mesh Datasets/Replica/room0_mesh.ply --traj Datasets/Replica/room0/traj.txt --output_mesh cull_replica_mesh/room0.ply 148 | ``` 149 | #### ScanNet 150 | To evaluate the reconstruction error on ScanNet, first download the ground truth ScanNet meshes.zip into the `./Datasets/scannet` folder. Then run the command below. 151 | ```bash 152 | python src/tools/evaluate_scannet.py configs/ScanNet/scene0050.yaml 153 | ``` 154 | We also provide our reconstructed meshes for Replica and ScanNet for evaluation purposes: meshes.zip. 155 | 156 | ## Acknowledgement 157 | We adapt code from some awesome repositories, including [NICE-SLAM](https://github.com/cvg/nice-slam), [NeuralRGBD](https://github.com/dazinovic/neural-rgbd-surface-reconstruction), [tsdf-fusion](https://github.com/andyzeng/tsdf-fusion-python), [manhattan-sdf](https://github.com/zju3dv/manhattan_sdf), and [MonoSDF](https://github.com/autonomousvision/monosdf). Thanks for making the code available. We also thank [Zihan Zhu](https://zzh2000.github.io/) of [NICE-SLAM](https://github.com/cvg/nice-slam) for prompt responses to our inquiries regarding the details of their methods.
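As a quick sanity check for the Run section above, you can confirm that the TSDF volume and bounds files are in place before launching `run.py`. The file names follow `get_tsdf.py` and `src/DF_Prior.py`; Replica `room0` is used purely as an example.

```bash
# Inspect the generated (or downloaded) TSDF volume and bounds for Replica room0
python -c "import torch; print(torch.load('replica_tsdf_volume/room0_tsdf_volume.pt', map_location='cpu').shape); print(torch.load('replica_tsdf_volume/room0_bounds.pt', map_location='cpu'))"
```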
158 | 159 | ## Citation 160 | If you find our code or paper useful, please cite 161 | ```bibtex 162 | @inproceedings{Hu2023LNI-ADFP, 163 | title = {Learning Neural Implicit through Volume Rendering with Attentive Depth Fusion Priors}, 164 | author = {Hu, Pengchong and Han, Zhizhong}, 165 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 166 | year = {2023} 167 | } 168 | ``` 169 | -------------------------------------------------------------------------------- /configs/Replica/office0.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-5.5,5.9],[-6.7,5.4],[-4.7,5.3]] 4 | marching_cubes_bound: [[-5.5,5.9],[-6.7,5.4],[-4.7,5.3]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/office0 8 | output: output/Replica/office0 9 | id: office0 -------------------------------------------------------------------------------- /configs/Replica/office1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-5.3,6.5],[-5.1,6.0],[-4.5,5.2]] 4 | marching_cubes_bound: [[-5.3,6.5],[-5.1,6.0],[-4.5,5.2]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/office1 8 | output: output/Replica/office1 9 | id: office1 -------------------------------------------------------------------------------- /configs/Replica/office2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-5.0,4.6],[-4.4,6.9],[-2.8,3.1]] 4 | marching_cubes_bound: [[-5.0,4.6],[-4.4,6.9],[-2.8,3.1]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/office2 8 | output: output/Replica/office2 9 | id: office2 -------------------------------------------------------------------------------- /configs/Replica/office3.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-6.7,5.1],[-7.5,4.9],[-2.8,3.5]] 4 | marching_cubes_bound: [[-6.7,5.1],[-7.5,4.9],[-2.8,3.5]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/office3 8 | output: output/Replica/office3 9 | id: office3 10 | -------------------------------------------------------------------------------- /configs/Replica/office4.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-3.7,7.8],[-4.8,6.7],[-3.7,4.1]] 4 | marching_cubes_bound: [[-3.7,7.8],[-4.8,6.7],[-3.7,4.1]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/office4 8 | output: output/Replica/office4 9 | id: office4 -------------------------------------------------------------------------------- /configs/Replica/replica.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'replica' 2 | meshing: 3 | eval_rec: True 4 | tracking: 5 | vis_freq: 50 6 | vis_inside_freq: 25 7 | ignore_edge_W: 100 8 | ignore_edge_H: 100 9 | seperate_LR: False 10 | const_speed_assumption: True 11 | lr: 0.001 12 | pixels: 200 13 | iters: 10 14 | gt_camera: False 15 | mapping: 16 | every_frame: 5 17 | vis_freq: 50 18 | vis_inside_freq: 30 19 | mesh_freq: 50 20 | ckpt_freq: 500 21 | keyframe_every: 50 22 | mapping_window_size: 5 23 | pixels: 1000 24 | iters_first: 1500 25 | iters: 60 26 | stage: 27 
| low: 28 | mlp_lr: 0.0 29 | decoders_lr: 0.0 30 | low_lr: 0.1 31 | high_lr: 0.0 32 | color_lr: 0.0 33 | high: 34 | mlp_lr: 0.0 35 | decoders_lr: 0.0 36 | low_lr: 0.005 37 | high_lr: 0.005 38 | color_lr: 0.0 39 | color: 40 | mlp_lr: 0.005 41 | decoders_lr: 0.005 42 | low_lr: 0.005 43 | high_lr: 0.005 44 | color_lr: 0.005 45 | cam: 46 | H: 680 47 | W: 1200 48 | fx: 600.0 49 | fy: 600.0 50 | cx: 599.5 51 | cy: 339.5 52 | png_depth_scale: 6553.5 #for depth image in png format 53 | crop_edge: 0 -------------------------------------------------------------------------------- /configs/Replica/room0.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-2.9,8.9],[-3.2,5.5],[-3.5,3.3]] 4 | marching_cubes_bound: [[-2.9,8.9],[-3.2,5.5],[-3.5,3.3]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/room0 8 | output: output/Replica/room0 9 | id: room0 -------------------------------------------------------------------------------- /configs/Replica/room1.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-7.0,2.8],[-4.6,4.3],[-3.0,2.9]] 4 | marching_cubes_bound: [[-7.0,2.8],[-4.6,4.3],[-3.0,2.9]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/room1 8 | output: output/Replica/room1 9 | id: room1 -------------------------------------------------------------------------------- /configs/Replica/room2.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/Replica/replica.yaml 2 | mapping: 3 | bound: [[-4.3,9.5],[-6.7,5.2],[-6.4,4.2]] 4 | marching_cubes_bound: [[-4.3,9.5],[-6.7,5.2],[-6.4,4.2]] 5 | data: 6 | dataset: replica 7 | input_folder: Datasets/Replica/room2 8 | output: output/Replica/room2 9 | id: room2 -------------------------------------------------------------------------------- /configs/ScanNet/scannet.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'scannet' 2 | tracking: 3 | vis_freq: 50 4 | vis_inside_freq: 25 5 | ignore_edge_W: 20 6 | ignore_edge_H: 20 7 | seperate_LR: False 8 | const_speed_assumption: True 9 | lr: 0.0005 10 | pixels: 1000 11 | iters: 50 12 | gt_camera: True #False 13 | mapping: 14 | every_frame: 5 15 | vis_freq: 50 16 | vis_inside_freq: 30 17 | mesh_freq: 50 18 | ckpt_freq: 500 19 | keyframe_every: 5 20 | mapping_window_size: 10 21 | pixels: 5000 22 | iters_first: 1500 23 | iters: 60 24 | cam: 25 | H: 480 26 | W: 640 27 | fx: 577.590698 28 | fy: 578.729797 29 | cx: 318.905426 30 | cy: 242.683609 31 | png_depth_scale: 1000. 
#for depth image in png format 32 | crop_edge: 10 -------------------------------------------------------------------------------- /configs/ScanNet/scene0000.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-2.0,11.0],[-2.0,11.5],[-2.0,5.5]] 4 | marching_cubes_bound: [[-2.0,11.0],[-2.0,11.5],[-2.0,5.5]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0000_00 8 | output: output/scannet/scans/scene0000_00 9 | id: 00 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0002.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[0.6,5.2],[0.0,5.6],[0.1,3.4]] #[[0.0,5.5],[0.0,6.0],[-0.5,3.5]] 4 | marching_cubes_bound: [[0.6,5.2],[0.0,5.6],[0.1,3.4]] #[[0.0,5.5],[0.0,6.0],[-0.5,3.5]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0002_00 8 | output: output/scannet/scans/scene0002_00 9 | id: 02 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0005.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.1,5.5],[0.4,5.4],[-0.1,2.5]] # [[-0.5,5.5],[0.0,5.5],[-0.5,3.0]] 4 | marching_cubes_bound: [[-0.1,5.5],[0.4,5.4],[-0.1,2.5]] # [[-0.5,5.5],[0.0,5.5],[-0.5,3.0]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0005_00 8 | output: output/scannet/scans/scene0005_00 9 | id: 05 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0012.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.5,5.5],[-0.5,5.5],[-0.5,3.0]] 4 | marching_cubes_bound: [[-0.5,5.5],[-0.5,5.5],[-0.5,3.0]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0012_00 8 | output: output/scannet/scans/scene0012_00 9 | id: 12 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0050.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[0.5,7.0],[0.0,4.5],[-0.5,3.0]] 4 | marching_cubes_bound: [[0.5,7.0],[0.0,4.5],[-0.5,3.0]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0050_00 8 | output: output/scannet/scans/scene0050_00 9 | id: 50 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0059.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.9,7.3],[-1.0,9.6],[-1.0,3.7]] 4 | marching_cubes_bound: [[-0.9,7.3],[-1.0,9.6],[-1.0,3.7]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0059_00 8 | output: output/scannet/scans/scene0059_00 9 | id: 59 -------------------------------------------------------------------------------- /configs/ScanNet/scene0084.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.5,3.0],[-0.5,7.5],[0.0,2.5]] 4 | 
marching_cubes_bound: [[-0.5,3.0],[-0.5,7.5],[0.0,2.5]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0084_00 8 | output: output/scannet/scans/scene0084_00 9 | id: 84 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0106.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-1.1,9.8],[-1.0,10.0],[-1.0,4.3]] 4 | marching_cubes_bound: [[-1.1,9.8],[-1.0,10.0],[-1.0,4.3]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0106_00 8 | output: output/scannet/scans/scene0106_00 9 | id: 106 -------------------------------------------------------------------------------- /configs/ScanNet/scene0169.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.2,9.8],[-1.0,8.5],[-1.0,3.4]] 4 | marching_cubes_bound: [[-0.2,9.8],[-1.0,8.5],[-1.0,3.4]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0169_00 8 | output: output/scannet/scans/scene0169_00 9 | id: 169 -------------------------------------------------------------------------------- /configs/ScanNet/scene0181.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-1.0,8.9],[-0.9,8.0],[-1.0,3.6]] 4 | marching_cubes_bound: [[-1.0,8.9],[-0.9,8.0],[-1.0,3.6]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0181_00 8 | output: output/scannet/scans/scene0181_00 9 | id: 181 -------------------------------------------------------------------------------- /configs/ScanNet/scene0207.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[0.3,9.9],[-1.0,8.0],[-1.0,3.8]] 4 | marching_cubes_bound: [[0.3,9.9],[-1.0,8.0],[-1.0,3.8]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0207_00 8 | output: output/scannet/scans/scene0207_00 9 | id: 207 -------------------------------------------------------------------------------- /configs/ScanNet/scene0472.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.6,9.5],[-1.5,9.5],[-1.5,3.5]] 4 | marching_cubes_bound: [[-0.6,9.5],[-1.5,9.5],[-1.5,3.5]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0472_00 8 | output: output/scannet/scans/scene0472_00 9 | id: 472 -------------------------------------------------------------------------------- /configs/ScanNet/scene0580.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[0.0,5.5],[-0.5,3.5],[0.0,4.0]] 4 | marching_cubes_bound: [[0.0,5.5],[-0.5,3.5],[0.0,4.0]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0580_00 8 | output: output/scannet/scans/scene0580_00 9 | id: 580 10 | -------------------------------------------------------------------------------- /configs/ScanNet/scene0616.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: configs/ScanNet/scannet.yaml 2 | mapping: 3 | bound: [[-0.5,6.0],[-0.5,5.0],[-0.5,3.0]] 4 | 
marching_cubes_bound: [[-0.5,6.0],[-0.5,5.0],[-0.5,3.0]] 5 | data: 6 | dataset: scannet 7 | input_folder: Datasets/scannet/scans/scene0616_00 8 | output: output/scannet/scans/scene0616_00 9 | id: 616 10 | -------------------------------------------------------------------------------- /configs/df_prior.yaml: -------------------------------------------------------------------------------- 1 | sync_method: strict 2 | scale: 1 3 | verbose: True 4 | occupancy: True 5 | low_gpu_mem: True 6 | grid_len: 7 | low: 0.32 8 | high: 0.16 9 | color: 0.16 10 | bound_divisible: 0.32 11 | pretrained_decoders: 12 | low_high: pretrained/low_high.pt # one ckpt contain both low and high 13 | meshing: 14 | level_set: 0 15 | resolution: 256 # change to 512 for higher resolution geometry 16 | eval_rec: False 17 | clean_mesh: True 18 | depth_test: False 19 | clean_mesh_bound_scale: 1.02 20 | get_largest_components: False 21 | color_mesh_extraction_method: direct_point_query 22 | remove_small_geometry_threshold: 0.2 23 | tracking: 24 | ignore_edge_W: 20 25 | ignore_edge_H: 20 26 | use_color_in_tracking: True 27 | device: "cuda:0" 28 | handle_dynamic: True 29 | vis_freq: 50 30 | vis_inside_freq: 25 31 | w_color_loss: 0.5 32 | seperate_LR: False 33 | const_speed_assumption: True 34 | no_vis_on_first_frame: True 35 | gt_camera: True #False 36 | lr: 0.001 37 | pixels: 200 38 | iters: 10 39 | mapping: 40 | device: "cuda:0" 41 | color_refine: True 42 | low_iter_ratio: 0.4 43 | high_iter_ratio: 0.6 44 | every_frame: 5 45 | fix_high: True 46 | fix_color: False 47 | no_vis_on_first_frame: True 48 | no_mesh_on_first_frame: True 49 | no_log_on_first_frame: True 50 | vis_freq: 50 51 | vis_inside_freq: 25 #each iteration 52 | mesh_freq: 50 53 | ckpt_freq: 500 54 | keyframe_every: 50 55 | mapping_window_size: 5 56 | w_color_loss: 0.2 57 | frustum_feature_selection: True 58 | keyframe_selection_method: 'overlap' 59 | save_selected_keyframes_info: False 60 | lr_first_factor: 5 61 | lr_factor: 1 62 | pixels: 1000 63 | iters_first: 1500 64 | iters: 60 65 | stage: 66 | low: 67 | mlp_lr: 0.0 68 | decoders_lr: 0.0 69 | low_lr: 0.1 70 | high_lr: 0.0 71 | color_lr: 0.0 72 | high: 73 | mlp_lr: 0.005 74 | decoders_lr: 0.0 75 | low_lr: 0.005 76 | high_lr: 0.005 77 | color_lr: 0.0 78 | color: 79 | mlp_lr: 0.005 80 | decoders_lr: 0.005 81 | low_lr: 0.005 82 | high_lr: 0.005 83 | color_lr: 0.005 84 | cam: 85 | H: 680 86 | W: 1200 87 | fx: 600.0 88 | fy: 600.0 89 | cx: 599.5 90 | cy: 339.5 91 | png_depth_scale: 6553.5 #for depth image in png format 92 | crop_edge: 0 93 | rendering: 94 | N_samples: 32 95 | N_surface: 16 96 | N_importance: 0 97 | lindisp: False 98 | perturb: 0.0 99 | data: 100 | dim: 3 101 | model: 102 | c_dim: 32 103 | pos_embedding_method: fourier 104 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: df-prior 2 | channels: 3 | - pytorch-nightly 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=1_llvm 11 | - blas=1.0=mkl 12 | - brotli=1.0.9=he6710b0_2 13 | - brotlipy=0.7.0=py37h27cfd23_1003 14 | - bzip2=1.0.8=h7f98852_4 15 | - ca-certificates=2023.05.30=h06a4308_0 16 | - certifi=2022.12.7=py37h06a4308_0 17 | - cffi=1.15.0=py37hd667e15_1 18 | - charset-normalizer=2.0.12=pyhd8ed1ab_0 19 | - cryptography=36.0.0=py37h9ce1e76_0 20 | - cuda-cudart=12.1.105=0 21 | - cuda-cupti=12.1.105=0 22 | - 
cuda-libraries=12.1.0=0 23 | - cuda-nvrtc=12.1.105=0 24 | - cuda-nvtx=12.1.105=0 25 | - cuda-opencl=12.1.105=0 26 | - cuda-runtime=12.1.0=0 27 | - cudatoolkit=11.3.1=h2bc3f7f_2 28 | - cycler=0.11.0=pyhd3eb1b0_0 29 | - dbus=1.13.18=hb2f20db_0 30 | - embree=2.17.7=ha770c72_1 31 | - expat=2.4.4=h295c915_0 32 | - ffmpeg=4.3=hf484d3e_0 33 | - fontconfig=2.13.1=h6c09931_0 34 | - fonttools=4.25.0=pyhd3eb1b0_0 35 | - freetype=2.10.4=h0708190_1 36 | - giflib=5.2.1=h7b6447c_0 37 | - glib=2.69.1=h4ff587b_1 38 | - gmp=6.2.1=h58526e2_0 39 | - gnutls=3.6.13=h85f3911_1 40 | - gst-plugins-base=1.14.0=h8213a91_2 41 | - gstreamer=1.14.0=h28cd5cc_2 42 | - icu=58.2=he6710b0_3 43 | - idna=3.3=pyhd3eb1b0_0 44 | - jpeg=9b=h024ee3a_2 45 | - kiwisolver=1.3.2=py37h295c915_0 46 | - lame=3.100=h7f98852_1001 47 | - lcms2=2.12=h3be6417_0 48 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 49 | - libcublas=12.1.0.26=0 50 | - libcufft=11.0.2.4=0 51 | - libcufile=1.6.1.9=0 52 | - libcurand=10.3.2.106=0 53 | - libcusolver=11.4.4.55=0 54 | - libcusparse=12.0.2.55=0 55 | - libffi=3.3=he6710b0_2 56 | - libgcc-ng=11.2.0=h1d223b6_14 57 | - libiconv=1.16=h516909a_0 58 | - libnpp=12.0.2.50=0 59 | - libnsl=2.0.0=h7f98852_0 60 | - libnvjitlink=12.1.105=0 61 | - libnvjpeg=12.1.0.39=0 62 | - libpng=1.6.37=h21135ba_2 63 | - libstdcxx-ng=11.2.0=he4da1e4_14 64 | - libtiff=4.2.0=h85742a9_0 65 | - libuuid=1.0.3=h7f8727e_2 66 | - libuv=1.43.0=h7f98852_0 67 | - libwebp=1.2.0=h89dd481_0 68 | - libwebp-base=1.2.0=h27cfd23_0 69 | - libxcb=1.14=h7b6447c_0 70 | - libxml2=2.9.12=h03d6c58_0 71 | - libzlib=1.2.11=h166bdaf_1014 72 | - llvm-openmp=13.0.1=he0ac6c6_1 73 | - lz4-c=1.9.3=h295c915_1 74 | - matplotlib=3.4.3=py37h06a4308_0 75 | - matplotlib-base=3.4.3=py37hbbc1b5f_0 76 | - mkl=2021.4.0=h8d4b97c_729 77 | - mkl-service=2.4.0=py37h402132d_0 78 | - mkl_fft=1.3.1=py37h3e078e5_1 79 | - mkl_random=1.2.2=py37h219a48f_0 80 | - munkres=1.1.4=py_0 81 | - ncurses=6.3=h9c3ff4c_0 82 | - nettle=3.6=he412f7d_0 83 | - ninja=1.10.2=h4bd325d_1 84 | - numpy-base=1.21.2=py37h79a1101_0 85 | - olefile=0.46=pyh9f0ad1d_1 86 | - openh264=2.1.1=h780b84a_0 87 | - openssl=1.1.1n=h166bdaf_0 88 | - pcre=8.45=h295c915_0 89 | - pip=22.0.4=pyhd8ed1ab_0 90 | - pycparser=2.21=pyhd3eb1b0_0 91 | - pyembree=0.1.6=py37h0da4684_1 92 | - pyopenssl=22.0.0=pyhd3eb1b0_0 93 | - pyparsing=3.0.4=pyhd3eb1b0_0 94 | - pyqt=5.9.2=py37h05f1152_2 95 | - pysocks=1.7.1=py37_1 96 | - python=3.7.11=h12debd9_0 97 | - python-dateutil=2.8.2=pyhd3eb1b0_0 98 | - python_abi=3.7=2_cp37m 99 | - pytorch-cuda=12.1=ha16c6d3_5 100 | - pytorch-mutex=1.0=cuda 101 | - qt=5.9.7=h5867ecd_1 102 | - readline=8.1=h46c0cb4_0 103 | - requests=2.27.1=pyhd3eb1b0_0 104 | - setuptools=61.2.0=py37h89c1867_3 105 | - sip=4.19.8=py37hf484d3e_0 106 | - six=1.16.0=pyh6c4a22f_0 107 | - sqlite=3.37.1=h4ff8645_0 108 | - tbb=2021.5.0=h4bd325d_0 109 | - tk=8.6.12=h27826a3_0 110 | - torchaudio=0.11.0=py37_cu113 111 | - tornado=6.1=py37h27cfd23_0 112 | - typing_extensions=4.3.0=py37h06a4308_0 113 | - wheel=0.37.1=pyhd8ed1ab_0 114 | - xz=5.2.5=h516909a_1 115 | - zlib=1.2.11=h166bdaf_1014 116 | - zstd=1.4.9=haebb681_0 117 | - pip: 118 | - addict==2.4.0 119 | - ansi2html==1.8.0 120 | - anyio==3.5.0 121 | - appdirs==1.4.4 122 | - argon2-cffi==21.3.0 123 | - argon2-cffi-bindings==21.2.0 124 | - argparse==1.4.0 125 | - arrow==1.2.3 126 | - attrs==21.4.0 127 | - babel==2.9.1 128 | - backcall==0.2.0 129 | - beautifulsoup4==4.10.0 130 | - bleach==4.1.0 131 | - click==8.1.7 132 | - cloudpickle==2.0.0 133 | - colorama==0.4.4 134 | - comm==0.1.4 135 | - 
configargparse==1.7 136 | - dash==2.14.1 137 | - dash-core-components==2.0.0 138 | - dash-html-components==2.0.0 139 | - dash-table==5.0.0 140 | - dask==2022.2.0 141 | - debugpy==1.6.0 142 | - decorator==5.1.1 143 | - defusedxml==0.7.1 144 | - deprecation==2.1.0 145 | - docopt==0.6.2 146 | - entrypoints==0.4 147 | - fastjsonschema==2.19.0 148 | - filelock==3.6.0 149 | - flask==2.2.5 150 | - freetype-py==2.3.0 151 | - fsspec==2022.2.0 152 | - gdown==4.4.0 153 | - imageio==2.16.1 154 | - imath==0.0.1 155 | - importlib-metadata==4.11.3 156 | - importlib-resources==5.6.0 157 | - inform==1.28 158 | - ipykernel==6.10.0 159 | - ipython==7.32.0 160 | - ipython-genutils==0.2.0 161 | - ipywidgets==8.1.1 162 | - itsdangerous==2.1.2 163 | - jedi==0.18.1 164 | - jinja2==3.1.1 165 | - joblib==1.1.0 166 | - json5==0.9.6 167 | - jsonschema==4.4.0 168 | - jupyter-client==7.2.0 169 | - jupyter-core==4.9.2 170 | - jupyter-packaging==0.12.0 171 | - jupyter-server==1.16.0 172 | - jupyterlab==3.3.2 173 | - jupyterlab-pygments==0.1.2 174 | - jupyterlab-server==2.12.0 175 | - jupyterlab-widgets==3.0.9 176 | - llvmlite==0.39.1 177 | - locket==0.2.1 178 | - mako==1.2.4 179 | - markupsafe==2.1.1 180 | - mathutils==2.81.2 181 | - matplotlib-inline==0.1.3 182 | - mistune==0.8.4 183 | - nbclassic==0.3.7 184 | - nbclient==0.5.13 185 | - nbconvert==6.4.5 186 | - nbformat==5.7.0 187 | - nest-asyncio==1.5.4 188 | - networkx==2.6.3 189 | - notebook==6.4.10 190 | - notebook-shim==0.1.0 191 | - numba==0.56.4 192 | - numpy==1.21.5 193 | - open3d==0.17.0 194 | - opencv-python==4.5.5.64 195 | - openexr==1.3.7 196 | - packaging==21.3 197 | - pandas==1.3.5 198 | - pandocfilters==1.5.0 199 | - parso==0.8.3 200 | - partd==1.2.0 201 | - pexpect==4.8.0 202 | - pickleshare==0.7.5 203 | - pillow==9.5.0 204 | - platformdirs==3.2.0 205 | - plotly==5.18.0 206 | - plyfile==0.9 207 | - prometheus-client==0.13.1 208 | - prompt-toolkit==3.0.28 209 | - psutil==5.9.0 210 | - ptyprocess==0.7.0 211 | - pycuda==2022.1 212 | - pyglet==2.0.6 213 | - pygments==2.11.2 214 | - pyopengl==3.1.0 215 | - pypng==0.20220715.0 216 | - pyquaternion==0.9.9 217 | - pyrender==0.1.45 218 | - pyrsistent==0.18.1 219 | - pytools==2022.1.12 220 | - pytz==2022.1 221 | - pywavelets==1.3.0 222 | - pyyaml==6.0 223 | - pyzmq==22.3.0 224 | - quantiphy==2.19 225 | - retrying==1.3.4 226 | - rtree==0.9.7 227 | - scikit-image==0.19.2 228 | - scikit-learn==1.0.2 229 | - scipy==1.7.3 230 | - seaborn==0.11.2 231 | - send2trash==1.8.0 232 | - sniffio==1.2.0 233 | - soupsieve==2.3.1 234 | - tenacity==8.2.3 235 | - terminado==0.13.3 236 | - testpath==0.6.0 237 | - threadpoolctl==3.1.0 238 | - tifffile==2021.11.2 239 | - tomlkit==0.10.1 240 | - toolz==0.11.2 241 | - torch==1.11.0+cu113 242 | - torchvision==0.12.0+cu113 243 | - tqdm==4.63.1 244 | - traitlets==5.1.1 245 | - trimesh==3.10.7 246 | - tvm==1.0.0 247 | - typing-extensions==4.5.0 248 | - urllib3==1.26.9 249 | - wcwidth==0.2.5 250 | - webencodings==0.5.1 251 | - websocket-client==1.3.2 252 | - werkzeug==2.2.3 253 | - widgetsnbextension==4.0.9 254 | - zipp==3.7.0 255 | -------------------------------------------------------------------------------- /get_tsdf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | from src import config 7 | import src.fusion as fusion 8 | import open3d as o3d 9 | from src.utils.datasets import get_dataset 10 | 11 | 12 | def update_cam(cfg): 13 | """ 14 | Update the camera intrinsics 
according to pre-processing config, 15 | such as resize or edge crop. 16 | """ 17 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][ 18 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy'] 19 | # resize the input images to crop_size (variable name used in lietorch) 20 | if 'crop_size' in cfg['cam']: 21 | crop_size = cfg['cam']['crop_size'] 22 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][ 23 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy'] 24 | sx = crop_size[1] / W 25 | sy = crop_size[0] / H 26 | fx = sx*fx 27 | fy = sy*fy 28 | cx = sx*cx 29 | cy = sy*cy 30 | W = crop_size[1] 31 | H = crop_size[0] 32 | 33 | 34 | # croping will change H, W, cx, cy, so need to change here 35 | if cfg['cam']['crop_edge'] > 0: 36 | H -= cfg['cam']['crop_edge']*2 37 | W -= cfg['cam']['crop_edge']*2 38 | cx -= cfg['cam']['crop_edge'] 39 | cy -= cfg['cam']['crop_edge'] 40 | 41 | return H, W, fx, fy, cx, cy 42 | 43 | 44 | def init_tsdf_volume(cfg, args, space=10): 45 | """ 46 | Initialize the TSDF volume. 47 | Get the TSDF volume and bounds. 48 | 49 | space: the space between frames to integrate into the TSDF volume. 50 | 51 | """ 52 | # scale the bound if there is a global scaling factor 53 | scale = cfg['scale'] 54 | bound = torch.from_numpy( 55 | np.array(cfg['mapping']['bound'])*scale) 56 | bound_divisible = cfg['grid_len']['bound_divisible'] 57 | # enlarge the bound a bit to allow it divisible by bound_divisible 58 | bound[:, 1] = (((bound[:, 1]-bound[:, 0]) / 59 | bound_divisible).int()+1)*bound_divisible+bound[:, 0] 60 | 61 | # TSDF volume 62 | H, W, fx, fy, cx, cy = update_cam(cfg) 63 | intrinsic = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy).intrinsic_matrix # (3, 3) 64 | 65 | print("Initializing voxel volume...") 66 | vol_bnds = np.array(bound) 67 | tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=4.0/256) 68 | 69 | frame_reader = get_dataset(cfg, args, scale) 70 | 71 | for idx in range(len(frame_reader)): 72 | if idx % space != 0: continue 73 | print(f'frame: {idx}') 74 | _, gt_color, gt_depth, gt_c2w = frame_reader[idx] 75 | 76 | # convert to open3d camera pose 77 | c2w = gt_c2w.cpu().numpy() 78 | 79 | if np.isfinite(c2w).any(): 80 | c2w[:3, 1] *= -1.0 81 | c2w[:3, 2] *= -1.0 82 | 83 | depth = gt_depth.cpu().numpy() # (368, 496, 3) 84 | color = gt_color.cpu().numpy() 85 | depth = depth.astype(np.float32) 86 | color = np.array((color * 255).astype(np.uint8)) 87 | tsdf_vol.integrate(color, depth, intrinsic, c2w, obs_weight=1.) 88 | 89 | print('Getting TSDF volume') 90 | tsdf_volume, _, bounds = tsdf_vol.get_volume() 91 | 92 | print("Getting mesh") 93 | verts, faces, norms, colors = tsdf_vol.get_mesh() 94 | 95 | tsdf_volume = torch.tensor(tsdf_volume) 96 | tsdf_volume = tsdf_volume.reshape(1, 1, tsdf_volume.shape[0], tsdf_volume.shape[1], tsdf_volume.shape[2]) 97 | tsdf_volume = tsdf_volume.permute(0, 1, 4, 3, 2) 98 | 99 | return tsdf_volume, bounds, verts, faces, norms, colors 100 | 101 | def get_tsdf(): 102 | """ 103 | Save the TSDF volume and bounds to a file. 104 | 105 | """ 106 | parser = argparse.ArgumentParser( 107 | description='Arguments for running the code.' 
108 | ) 109 | parser.add_argument('config', type=str, help='Path to config file.') 110 | parser.add_argument('--input_folder', type=str, 111 | help='input folder, this has higher priority and can overwrite the one in the config file') 112 | parser.add_argument('--output', type=str, 113 | help='output folder, this has higher priority and can overwrite the one in the config file') 114 | parser.add_argument('--space', type=int, default=10, help='the space between frames to integrate into the TSDF volume.') 115 | 116 | args = parser.parse_args() 117 | cfg = config.load_config(args.config, 'configs/df_prior.yaml') 118 | 119 | dataset = cfg['data']['dataset'] 120 | scene_id = cfg['data']['id'] 121 | 122 | 123 | path = f'{dataset}_tsdf_volume' 124 | os.makedirs(path, exist_ok=True) 125 | 126 | tsdf_volume, bounds, verts, faces, norms, colors = init_tsdf_volume(cfg, args, space=args.space) 127 | 128 | if dataset == 'scannet': 129 | tsdf_volume_path = os.path.join(path, f'scene{scene_id}_tsdf_volume.pt') 130 | bounds_path = os.path.join(path, f'scene{scene_id}_bounds.pt') 131 | 132 | elif dataset == 'replica': 133 | tsdf_volume_path = os.path.join(path, f'{scene_id}_tsdf_volume.pt') 134 | bounds_path = os.path.join(path, f'{scene_id}_bounds.pt') 135 | 136 | 137 | torch.save(tsdf_volume, tsdf_volume_path) 138 | torch.save(bounds, bounds_path) 139 | 140 | 141 | 142 | if __name__ == '__main__': 143 | get_tsdf() 144 | -------------------------------------------------------------------------------- /pretrained/low_high.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/pretrained/low_high.pt -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from src import config 8 | from src.DF_Prior import DF_Prior 9 | 10 | 11 | def setup_seed(seed): 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | torch.backends.cudnn.deterministic = True 17 | 18 | 19 | def main(): 20 | 21 | parser = argparse.ArgumentParser( 22 | description='Arguments for running the code.'
23 | ) 24 | parser.add_argument('config', type=str, help='Path to config file.') 25 | parser.add_argument('--input_folder', type=str, 26 | help='input folder, this has higher priority and can overwrite the one in the config file') 27 | parser.add_argument('--output', type=str, 28 | help='output folder, this has higher priority and can overwrite the one in the config file') 29 | args = parser.parse_args() 30 | cfg = config.load_config(args.config, 'configs/df_prior.yaml') 31 | 32 | slam = DF_Prior(cfg, args) 33 | slam.run() 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /scripts/download_cull_replica_mesh.sh: -------------------------------------------------------------------------------- 1 | # you can also download the cull_replica_mesh.zip manually through 2 | # link: https://caiyun.139.com/m/i?1A5CvGNQmSDsI password: Dhhp 3 | wget https://cvg-data.inf.ethz.ch/nice-slam/cull_replica_mesh.zip 4 | unzip cull_replica_mesh.zip -------------------------------------------------------------------------------- /scripts/download_replica.sh: -------------------------------------------------------------------------------- 1 | mkdir -p Datasets 2 | cd Datasets 3 | # you can also download the Replica.zip manually through 4 | # link: https://caiyun.139.com/m/i?1A5Ch5C3abNiL password: v3fY (the zip is split into smaller zips because of the size limitation of caiyun) 5 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip 6 | unzip Replica.zip 7 | -------------------------------------------------------------------------------- /src/DF_Prior.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | import torch.multiprocessing 7 | import torch.multiprocessing as mp 8 | 9 | from src import config 10 | from src.Mapper import Mapper 11 | from src.Tracker import Tracker 12 | from src.utils.datasets import get_dataset 13 | from src.utils.Logger import Logger 14 | from src.utils.Mesher import Mesher 15 | from src.utils.Renderer import Renderer 16 | 17 | # import src.fusion as fusion 18 | # import open3d as o3d 19 | 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | 23 | class DF_Prior(): 24 | """ 25 | DF_Prior main class. 26 | Mainly allocates shared resources and dispatches the mapping and tracking processes.
27 | """ 28 | 29 | def __init__(self, cfg, args): 30 | 31 | self.cfg = cfg 32 | self.args = args 33 | 34 | self.occupancy = cfg['occupancy'] 35 | self.low_gpu_mem = cfg['low_gpu_mem'] 36 | self.verbose = cfg['verbose'] 37 | self.dataset = cfg['dataset'] 38 | if args.output is None: 39 | self.output = cfg['data']['output'] 40 | else: 41 | self.output = args.output 42 | self.ckptsdir = os.path.join(self.output, 'ckpts') 43 | os.makedirs(self.output, exist_ok=True) 44 | os.makedirs(self.ckptsdir, exist_ok=True) 45 | os.makedirs(f'{self.output}/mesh', exist_ok=True) 46 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = cfg['cam']['H'], cfg['cam'][ 47 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy'] 48 | self.update_cam() 49 | 50 | model = config.get_model(cfg) 51 | self.shared_decoders = model 52 | 53 | self.scale = cfg['scale'] 54 | 55 | self.load_bound(cfg) 56 | self.load_pretrain(cfg) 57 | self.grid_init(cfg) 58 | 59 | # need to use spawn 60 | try: 61 | mp.set_start_method('spawn', force=True) 62 | except RuntimeError: 63 | pass 64 | 65 | self.frame_reader = get_dataset(cfg, args, self.scale) 66 | self.n_img = len(self.frame_reader) 67 | self.estimate_c2w_list = torch.zeros((self.n_img, 4, 4)) 68 | self.estimate_c2w_list.share_memory_() 69 | 70 | dataset = self.cfg['data']['dataset'] 71 | scene_id = self.cfg['data']['id'] 72 | self.scene_id = scene_id 73 | print(scene_id) 74 | # load tsdf grid 75 | if dataset == 'scannet': 76 | self.tsdf_volume_shared = torch.load(f'scannet_tsdf_volume/scene{scene_id}_tsdf_volume.pt') 77 | elif dataset == 'replica': 78 | self.tsdf_volume_shared = torch.load(f'replica_tsdf_volume/{scene_id}_tsdf_volume.pt') 79 | self.tsdf_volume_shared = self.tsdf_volume_shared.to(self.cfg['mapping']['device']) 80 | self.tsdf_volume_shared.share_memory_() 81 | 82 | # load tsdf grid bound 83 | if dataset == 'scannet': 84 | self.tsdf_bnds = torch.load(f'scannet_tsdf_volume/scene{scene_id}_bounds.pt') 85 | elif dataset == 'replica': 86 | self.tsdf_bnds = torch.load(f'replica_tsdf_volume/{scene_id}_bounds.pt') 87 | self.tsdf_bnds = torch.tensor(self.tsdf_bnds).to(self.cfg['mapping']['device']) 88 | self.tsdf_bnds.share_memory_() 89 | 90 | self.vol_bnds = self.tsdf_bnds 91 | self.vol_bnds.share_memory_() 92 | 93 | self.gt_c2w_list = torch.zeros((self.n_img, 4, 4)) 94 | self.gt_c2w_list.share_memory_() 95 | self.idx = torch.zeros((1)).int() 96 | self.idx.share_memory_() 97 | self.mapping_first_frame = torch.zeros((1)).int() 98 | self.mapping_first_frame.share_memory_() 99 | # the id of the newest frame Mapper is processing 100 | self.mapping_idx = torch.zeros((1)).int() 101 | self.mapping_idx.share_memory_() 102 | self.mapping_cnt = torch.zeros((1)).int() # counter for mapping 103 | self.mapping_cnt.share_memory_() 104 | for key, val in self.shared_c.items(): 105 | val = val.to(self.cfg['mapping']['device']) 106 | val.share_memory_() 107 | self.shared_c[key] = val 108 | self.shared_decoders = self.shared_decoders.to( 109 | self.cfg['mapping']['device']) 110 | self.shared_decoders.share_memory() 111 | self.renderer = Renderer(cfg, args, self) 112 | self.mesher = Mesher(cfg, args, self) 113 | self.logger = Logger(cfg, args, self) 114 | self.mapper = Mapper(cfg, args, self) 115 | self.tracker = Tracker(cfg, args, self) 116 | self.print_output_desc() 117 | 118 | 119 | def print_output_desc(self): 120 | print(f"INFO: The output folder is {self.output}") 121 | if 'Demo' in self.output: 122 | print( 123 | f"INFO: The GT, generated and residual 
depth/color images can be found under " + 124 | f"{self.output}/vis/") 125 | else: 126 | print( 127 | f"INFO: The GT, generated and residual depth/color images can be found under " + 128 | f"{self.output}/tracking_vis/ and {self.output}/mapping_vis/") 129 | print(f"INFO: The mesh can be found under {self.output}/mesh/") 130 | print(f"INFO: The checkpoint can be found under {self.output}/ckpt/") 131 | 132 | 133 | def update_cam(self): 134 | """ 135 | Update the camera intrinsics according to pre-processing config, 136 | such as resize or edge crop. 137 | """ 138 | # resize the input images to crop_size (variable name used in lietorch) 139 | if 'crop_size' in self.cfg['cam']: 140 | crop_size = self.cfg['cam']['crop_size'] 141 | sx = crop_size[1] / self.W 142 | sy = crop_size[0] / self.H 143 | self.fx = sx*self.fx 144 | self.fy = sy*self.fy 145 | self.cx = sx*self.cx 146 | self.cy = sy*self.cy 147 | self.W = crop_size[1] 148 | self.H = crop_size[0] 149 | 150 | # croping will change H, W, cx, cy, so need to change here 151 | if self.cfg['cam']['crop_edge'] > 0: 152 | self.H -= self.cfg['cam']['crop_edge']*2 153 | self.W -= self.cfg['cam']['crop_edge']*2 154 | self.cx -= self.cfg['cam']['crop_edge'] 155 | self.cy -= self.cfg['cam']['crop_edge'] 156 | 157 | 158 | # def init_tsdf_volume(self, cfg, args): 159 | # # scale the bound if there is a global scaling factor 160 | # scale = cfg['scale'] 161 | # bound = torch.from_numpy( 162 | # np.array(cfg['mapping']['bound'])*scale) 163 | # bound_divisible = cfg['grid_len']['bound_divisible'] 164 | # # enlarge the bound a bit to allow it divisible by bound_divisible 165 | # bound[:, 1] = (((bound[:, 1]-bound[:, 0]) / 166 | # bound_divisible).int()+1)*bound_divisible+bound[:, 0] 167 | # intrinsic = o3d.camera.PinholeCameraIntrinsic(self.W, self.H, self.fx, self.fy, self.cx, self.cy).intrinsic_matrix # (3, 3) 168 | 169 | # print("Initializing voxel volume...") 170 | # vol_bnds = np.array(bound) 171 | # tsdf_vol = fusion.TSDFVolume(vol_bnds, voxel_size=4/256) 172 | 173 | 174 | # return tsdf_vol, intrinsic, vol_bnds 175 | 176 | 177 | def load_bound(self, cfg): 178 | """ 179 | Pass the scene bound parameters to different decoders and self. 180 | 181 | Args: 182 | cfg (dict): parsed config dict. 183 | """ 184 | # scale the bound if there is a global scaling factor 185 | self.bound = torch.from_numpy( 186 | np.array(cfg['mapping']['bound'])*self.scale) 187 | bound_divisible = cfg['grid_len']['bound_divisible'] 188 | # enlarge the bound a bit to allow it divisible by bound_divisible 189 | self.bound[:, 1] = (((self.bound[:, 1]-self.bound[:, 0]) / 190 | bound_divisible).int()+1)*bound_divisible+self.bound[:, 0] 191 | self.shared_decoders.bound = self.bound 192 | self.shared_decoders.low_decoder.bound = self.bound 193 | self.shared_decoders.high_decoder.bound = self.bound 194 | self.shared_decoders.color_decoder.bound = self.bound 195 | 196 | 197 | def load_pretrain(self, cfg): 198 | """ 199 | Load parameters of pretrained ConvOnet checkpoints to the decoders. 
200 | 201 | Args: 202 | cfg (dict): parsed config dict 203 | """ 204 | 205 | ckpt = torch.load(cfg['pretrained_decoders']['low_high'], 206 | map_location=cfg['mapping']['device']) 207 | low_dict = {} 208 | high_dict = {} 209 | for key, val in ckpt['model'].items(): 210 | if ('decoder' in key) and ('encoder' not in key): 211 | if 'coarse' in key: 212 | key = key[8+7:] 213 | low_dict[key] = val 214 | elif 'fine' in key: 215 | key = key[8+5:] 216 | high_dict[key] = val 217 | self.shared_decoders.low_decoder.load_state_dict(low_dict) 218 | self.shared_decoders.high_decoder.load_state_dict(high_dict) 219 | 220 | 221 | def grid_init(self, cfg): 222 | """ 223 | Initialize the hierarchical feature grids. 224 | 225 | Args: 226 | cfg (dict): parsed config dict. 227 | """ 228 | 229 | low_grid_len = cfg['grid_len']['low'] 230 | self.low_grid_len = low_grid_len 231 | high_grid_len = cfg['grid_len']['high'] 232 | self.high_grid_len = high_grid_len 233 | color_grid_len = cfg['grid_len']['color'] 234 | self.color_grid_len = color_grid_len 235 | 236 | c = {} 237 | c_dim = cfg['model']['c_dim'] 238 | xyz_len = self.bound[:, 1]-self.bound[:, 0] 239 | 240 | 241 | 242 | low_key = 'grid_low' 243 | low_val_shape = list(map(int, (xyz_len/low_grid_len).tolist())) 244 | low_val_shape[0], low_val_shape[2] = low_val_shape[2], low_val_shape[0] 245 | self.low_val_shape = low_val_shape 246 | val_shape = [1, c_dim, *low_val_shape] 247 | low_val = torch.zeros(val_shape).normal_(mean=0, std=0.01) 248 | c[low_key] = low_val 249 | 250 | high_key = 'grid_high' 251 | high_val_shape = list(map(int, (xyz_len/high_grid_len).tolist())) 252 | high_val_shape[0], high_val_shape[2] = high_val_shape[2], high_val_shape[0] 253 | self.high_val_shape = high_val_shape 254 | val_shape = [1, c_dim, *high_val_shape] 255 | high_val = torch.zeros(val_shape).normal_(mean=0, std=0.0001) 256 | c[high_key] = high_val 257 | 258 | color_key = 'grid_color' 259 | color_val_shape = list(map(int, (xyz_len/color_grid_len).tolist())) 260 | color_val_shape[0], color_val_shape[2] = color_val_shape[2], color_val_shape[0] 261 | self.color_val_shape = color_val_shape 262 | val_shape = [1, c_dim, *color_val_shape] 263 | color_val = torch.zeros(val_shape).normal_(mean=0, std=0.01) 264 | c[color_key] = color_val 265 | 266 | self.shared_c = c 267 | 268 | 269 | def tracking(self, rank): 270 | """ 271 | Tracking Thread. 272 | 273 | Args: 274 | rank (int): Thread ID. 275 | """ 276 | 277 | # should wait until the mapping of first frame is finished 278 | while (1): 279 | if self.mapping_first_frame[0] == 1: 280 | break 281 | time.sleep(1) 282 | 283 | self.tracker.run() 284 | 285 | 286 | def mapping(self, rank): 287 | """ 288 | Mapping Thread. (updates low, high, and color level) 289 | 290 | Args: 291 | rank (int): Thread ID. 292 | """ 293 | 294 | self.mapper.run() 295 | 296 | 297 | def run(self): 298 | """ 299 | Dispatch Threads. 
300 | """ 301 | 302 | processes = [] 303 | for rank in range(2): 304 | if rank == 0: 305 | p = mp.Process(target=self.tracking, args=(rank, )) 306 | elif rank == 1: 307 | p = mp.Process(target=self.mapping, args=(rank, )) 308 | p.start() 309 | processes.append(p) 310 | for p in processes: 311 | p.join() 312 | 313 | 314 | # This part is required by torch.multiprocessing 315 | if __name__ == '__main__': 316 | pass 317 | -------------------------------------------------------------------------------- /src/Tracker.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from colorama import Fore, Style 8 | from torch.autograd import Variable 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from src.common import (get_camera_from_tensor, get_samples, 13 | get_tensor_from_camera) 14 | from src.utils.datasets import get_dataset 15 | from src.utils.Visualizer import Visualizer 16 | 17 | 18 | 19 | 20 | class Tracker(object): 21 | def __init__(self, cfg, args, slam 22 | ): 23 | self.cfg = cfg 24 | self.args = args 25 | 26 | self.scale = cfg['scale'] 27 | self.occupancy = cfg['occupancy'] 28 | self.sync_method = cfg['sync_method'] 29 | 30 | self.idx = slam.idx 31 | self.bound = slam.bound 32 | self.mesher = slam.mesher 33 | self.output = slam.output 34 | self.verbose = slam.verbose 35 | self.shared_c = slam.shared_c 36 | self.renderer = slam.renderer 37 | self.gt_c2w_list = slam.gt_c2w_list 38 | self.low_gpu_mem = slam.low_gpu_mem 39 | self.mapping_idx = slam.mapping_idx 40 | self.mapping_cnt = slam.mapping_cnt 41 | self.shared_decoders = slam.shared_decoders 42 | self.estimate_c2w_list = slam.estimate_c2w_list 43 | with torch.no_grad(): 44 | self.tsdf_volume_shared = slam.tsdf_volume_shared 45 | self.tsdf_bnds = slam.tsdf_bnds 46 | 47 | 48 | self.cam_lr = cfg['tracking']['lr'] 49 | self.device = cfg['tracking']['device'] 50 | self.num_cam_iters = cfg['tracking']['iters'] 51 | self.gt_camera = cfg['tracking']['gt_camera'] 52 | self.tracking_pixels = cfg['tracking']['pixels'] 53 | self.seperate_LR = cfg['tracking']['seperate_LR'] 54 | self.w_color_loss = cfg['tracking']['w_color_loss'] 55 | self.ignore_edge_W = cfg['tracking']['ignore_edge_W'] 56 | self.ignore_edge_H = cfg['tracking']['ignore_edge_H'] 57 | self.handle_dynamic = cfg['tracking']['handle_dynamic'] 58 | self.use_color_in_tracking = cfg['tracking']['use_color_in_tracking'] 59 | self.const_speed_assumption = cfg['tracking']['const_speed_assumption'] 60 | 61 | self.every_frame = cfg['mapping']['every_frame'] 62 | self.no_vis_on_first_frame = cfg['mapping']['no_vis_on_first_frame'] # ori mapping 63 | 64 | self.prev_mapping_idx = -1 65 | self.frame_reader = get_dataset( 66 | cfg, args, self.scale, device=self.device) 67 | self.n_img = len(self.frame_reader) 68 | self.frame_loader = DataLoader( 69 | self.frame_reader, batch_size=1, shuffle=False, num_workers=1) 70 | self.visualizer = Visualizer(freq=cfg['tracking']['vis_freq'], inside_freq=cfg['tracking']['vis_inside_freq'], 71 | vis_dir=os.path.join(self.output, 'vis' if 'Demo' in self.output else 'tracking_vis'), 72 | renderer=self.renderer, verbose=self.verbose, device=self.device) 73 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = slam.H, slam.W, slam.fx, slam.fy, slam.cx, slam.cy 74 | 75 | def optimize_cam_in_batch(self, camera_tensor, gt_color, gt_depth, batch_size, optimizer, tsdf_volume): 76 | """ 77 | Do one iteration of camera 
iteration. Sample pixels, render depth/color, calculate loss and backpropagation. 78 | 79 | Args: 80 | camera_tensor (tensor): camera tensor. 81 | gt_color (tensor): ground truth color image of the current frame. 82 | gt_depth (tensor): ground truth depth image of the current frame. 83 | batch_size (int): batch size, number of sampling rays. 84 | optimizer (torch.optim): camera optimizer. 85 | tsdf_volume (tensor): tsdf volume 86 | 87 | Returns: 88 | loss (float): The value of loss. 89 | """ 90 | device = self.device 91 | H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy 92 | optimizer.zero_grad() 93 | c2w = get_camera_from_tensor(camera_tensor) 94 | tsdf_bnds = self.tsdf_bnds.to(device) 95 | Wedge = self.ignore_edge_W 96 | Hedge = self.ignore_edge_H 97 | batch_rays_o, batch_rays_d, batch_gt_depth, batch_gt_color = get_samples( 98 | Hedge, H-Hedge, Wedge, W-Wedge, batch_size, H, W, fx, fy, cx, cy, c2w, gt_depth, gt_color, self.device) 99 | 100 | # should pre-filter those out of bounding box depth value 101 | with torch.no_grad(): 102 | det_rays_o = batch_rays_o.clone().detach().unsqueeze(-1) # (N, 3, 1) 103 | det_rays_d = batch_rays_d.clone().detach().unsqueeze(-1) # (N, 3, 1) 104 | t = (self.bound.unsqueeze(0).to(device)-det_rays_o)/det_rays_d 105 | t, _ = torch.min(torch.max(t, dim=2)[0], dim=1) 106 | inside_mask = t >= batch_gt_depth 107 | batch_rays_d = batch_rays_d[inside_mask] 108 | batch_rays_o = batch_rays_o[inside_mask] 109 | batch_gt_depth = batch_gt_depth[inside_mask] 110 | batch_gt_color = batch_gt_color[inside_mask] 111 | 112 | ret = self.renderer.render_batch_ray( 113 | self.c, self.decoders, batch_rays_d, batch_rays_o, self.device, tsdf_volume, tsdf_bnds, stage='color', gt_depth=batch_gt_depth) #color 114 | depth, uncertainty, color, _ = ret 115 | 116 | uncertainty = uncertainty.detach() 117 | if self.handle_dynamic: 118 | tmp = torch.abs(batch_gt_depth-depth)/torch.sqrt(uncertainty+1e-10) 119 | mask = (tmp < 10*tmp.median()) & (batch_gt_depth > 0) 120 | else: 121 | mask = batch_gt_depth > 0 122 | 123 | loss = (torch.abs(batch_gt_depth-depth) / 124 | torch.sqrt(uncertainty+1e-10))[mask].sum() 125 | 126 | if self.use_color_in_tracking: 127 | color_loss = torch.abs( 128 | batch_gt_color - color)[mask].sum() 129 | loss += self.w_color_loss*color_loss 130 | 131 | loss.backward(retain_graph=False) 132 | optimizer.step() 133 | optimizer.zero_grad() 134 | return loss.item() 135 | 136 | def update_para_from_mapping(self): 137 | """ 138 | Update the parameters of scene representation from the mapping thread. 
139 | 140 | """ 141 | if self.mapping_idx[0] != self.prev_mapping_idx: 142 | if self.verbose: 143 | print('Tracking: update the parameters from mapping') 144 | self.decoders = copy.deepcopy(self.shared_decoders).to(self.device) 145 | for key, val in self.shared_c.items(): 146 | val = val.clone().to(self.device) 147 | self.c[key] = val 148 | self.prev_mapping_idx = self.mapping_idx[0].clone() 149 | 150 | def run(self): 151 | device = self.device 152 | tsdf_volume = self.tsdf_volume_shared 153 | tsdf_bnds = self.tsdf_bnds.to(device) 154 | 155 | self.c = {} 156 | if self.verbose: 157 | pbar = self.frame_loader 158 | else: 159 | pbar = tqdm(self.frame_loader) 160 | 161 | for idx, gt_color, gt_depth, gt_c2w in pbar: 162 | if not self.verbose: 163 | pbar.set_description(f"Tracking Frame {idx[0]}") 164 | 165 | idx = idx[0] 166 | gt_depth = gt_depth[0] 167 | gt_color = gt_color[0] 168 | gt_c2w = gt_c2w[0] 169 | 170 | if self.sync_method == 'strict': 171 | # strictly mapping and then tracking 172 | # initiate mapping every self.every_frame frames 173 | if idx > 0 and (idx % self.every_frame == 1 or self.every_frame == 1): 174 | while self.mapping_idx[0] != idx-1: 175 | time.sleep(0.1) 176 | pre_c2w = self.estimate_c2w_list[idx-1].to(device) 177 | elif self.sync_method == 'loose': 178 | # mapping idx can be later than tracking idx is within the bound of 179 | # [-self.every_frame-self.every_frame//2, -self.every_frame+self.every_frame//2] 180 | while self.mapping_idx[0] < idx-self.every_frame-self.every_frame//2: 181 | time.sleep(0.1) 182 | elif self.sync_method == 'free': 183 | # pure parallel, if mesh/vis happens may cause inbalance 184 | pass 185 | 186 | self.update_para_from_mapping() 187 | 188 | if self.verbose: 189 | print(Fore.MAGENTA) 190 | print("Tracking Frame ", idx.item()) 191 | print(Style.RESET_ALL) 192 | 193 | 194 | 195 | if idx == 0 or self.gt_camera: 196 | c2w = gt_c2w 197 | if not self.no_vis_on_first_frame: 198 | self.visualizer.vis( 199 | idx, 0, gt_depth, gt_color, c2w, self.c, self.decoders, tsdf_volume, tsdf_bnds) 200 | 201 | else: 202 | gt_camera_tensor = get_tensor_from_camera(gt_c2w) 203 | if self.const_speed_assumption and idx-2 >= 0: 204 | pre_c2w = pre_c2w.float() 205 | delta = pre_c2w@self.estimate_c2w_list[idx-2].to( 206 | device).float().inverse() 207 | estimated_new_cam_c2w = delta@pre_c2w 208 | else: 209 | estimated_new_cam_c2w = pre_c2w 210 | 211 | camera_tensor = get_tensor_from_camera( 212 | estimated_new_cam_c2w.detach()) 213 | if self.seperate_LR: 214 | camera_tensor = camera_tensor.to(device).detach() 215 | T = camera_tensor[-3:] 216 | quad = camera_tensor[:4] 217 | cam_para_list_quad = [quad] 218 | quad = Variable(quad, requires_grad=True) 219 | T = Variable(T, requires_grad=True) 220 | camera_tensor = torch.cat([quad, T], 0) 221 | cam_para_list_T = [T] 222 | cam_para_list_quad = [quad] 223 | optimizer_camera = torch.optim.Adam([{'params': cam_para_list_T, 'lr': self.cam_lr}, 224 | {'params': cam_para_list_quad, 'lr': self.cam_lr*0.2}]) 225 | else: 226 | camera_tensor = Variable( 227 | camera_tensor.to(device), requires_grad=True) 228 | cam_para_list = [camera_tensor] 229 | optimizer_camera = torch.optim.Adam( 230 | cam_para_list, lr=self.cam_lr) 231 | 232 | initial_loss_camera_tensor = torch.abs( 233 | gt_camera_tensor.to(device)-camera_tensor).mean().item() 234 | candidate_cam_tensor = None 235 | current_min_loss = 10000000000. 
236 | 237 | 238 | 239 | for cam_iter in range(self.num_cam_iters): 240 | if self.seperate_LR: 241 | camera_tensor = torch.cat([quad, T], 0).to(self.device) 242 | 243 | self.visualizer.vis( 244 | idx, cam_iter, gt_depth, gt_color, camera_tensor, self.c, self.decoders, tsdf_volume, tsdf_bnds) 245 | 246 | loss = self.optimize_cam_in_batch( 247 | camera_tensor, gt_color, gt_depth, self.tracking_pixels, optimizer_camera, tsdf_volume) 248 | 249 | if cam_iter == 0: 250 | initial_loss = loss 251 | 252 | loss_camera_tensor = torch.abs( 253 | gt_camera_tensor.to(device)-camera_tensor).mean().item() 254 | if self.verbose: 255 | if cam_iter == self.num_cam_iters-1: 256 | print( 257 | f'Re-rendering loss: {initial_loss:.2f}->{loss:.2f} ' + 258 | f'camera tensor error: {initial_loss_camera_tensor:.4f}->{loss_camera_tensor:.4f}') 259 | if loss < current_min_loss: 260 | current_min_loss = loss 261 | candidate_cam_tensor = camera_tensor.clone().detach() 262 | bottom = torch.from_numpy(np.array([0, 0, 0, 1.]).reshape( 263 | [1, 4])).type(torch.float32).to(self.device) 264 | c2w = get_camera_from_tensor( 265 | candidate_cam_tensor.clone().detach()) 266 | c2w = torch.cat([c2w, bottom], dim=0) 267 | 268 | 269 | self.estimate_c2w_list[idx] = c2w.clone().cpu() 270 | self.gt_c2w_list[idx] = gt_c2w.clone().cpu() 271 | pre_c2w = c2w.clone() 272 | self.idx[0] = idx 273 | if self.low_gpu_mem: 274 | torch.cuda.empty_cache() 275 | 276 | 277 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/DF_Prior.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/DF_Prior.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/Mapper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Mapper.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/Mapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Mapper.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/NICE_SLAM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/NICE_SLAM.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/NICE_SLAM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/NICE_SLAM.cpython-38.pyc 
-------------------------------------------------------------------------------- /src/__pycache__/Tracker.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Tracker.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/Tracker.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/Tracker.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/depth2pointcloud.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/depth2pointcloud.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/fusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/fusion.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/fusion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/fusion.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/fusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/__pycache__/fusion.cpython-38.pyc -------------------------------------------------------------------------------- /src/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | def as_intrinsics_matrix(intrinsics): 9 | """ 10 | Get matrix representation of intrinsics. 11 | 12 | """ 13 | K = np.eye(3) 14 | K[0, 0] = intrinsics[0] 15 | K[1, 1] = intrinsics[1] 16 | K[0, 2] = intrinsics[2] 17 | K[1, 2] = intrinsics[3] 18 | return K 19 | 20 | 21 | def sample_pdf(bins, weights, N_samples, det=False, device='cuda:0'): 22 | """ 23 | Hierarchical sampling in NeRF paper (section 5.2). 
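A minimal usage sketch (illustrative shapes; bins needs one more entry than weights, as in NeRF's fine-sampling step):
    z_vals = torch.linspace(0., 4., steps=9).expand(2, 9)          # coarse depths for 2 rays
    bins = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])               # (2, 8) bin edges at the mid-points
    weights = torch.rand(2, 7)                                      # (2, 7) coarse weights between those edges
    z_fine = sample_pdf(bins, weights, N_samples=16, device='cpu')  # (2, 16) importance-sampled depths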
24 | 25 | """ 26 | # Get pdf - probability density function 27 | weights = weights + 1e-5 # prevent nans 28 | pdf = weights / torch.sum(weights, -1, keepdim=True) 29 | cdf = torch.cumsum(pdf, -1) 30 | # (batch, len(bins)) 31 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 32 | 33 | # Take uniform samples 34 | if det: 35 | u = torch.linspace(0., 1., steps=N_samples) 36 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 37 | else: 38 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 39 | 40 | u = u.to(device) 41 | # Invert CDF 42 | u = u.contiguous() 43 | try: 44 | # this should work fine with the provided environment.yaml 45 | inds = torch.searchsorted(cdf, u, right=True) 46 | except: 47 | # for lower version torch that does not have torch.searchsorted, 48 | # you need to manually install from 49 | # https://github.com/aliutkus/torchsearchsorted 50 | from torchsearchsorted import searchsorted 51 | inds = searchsorted(cdf, u, side='right') 52 | below = torch.max(torch.zeros_like(inds-1), inds-1) 53 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 54 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 55 | 56 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 57 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 58 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 59 | 60 | denom = (cdf_g[..., 1]-cdf_g[..., 0]) 61 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 62 | t = (u-cdf_g[..., 0])/denom 63 | samples = bins_g[..., 0] + t * (bins_g[..., 1]-bins_g[..., 0]) 64 | 65 | return samples 66 | 67 | 68 | def random_select(l, k): 69 | """ 70 | Random select k values from 0..l. 71 | 72 | """ 73 | return list(np.random.permutation(np.array(range(l)))[:min(l, k)]) 74 | 75 | 76 | def get_rays_from_uv(i, j, c2w, H, W, fx, fy, cx, cy, device): 77 | """ 78 | Get corresponding rays from input uv. 79 | 80 | """ 81 | if isinstance(c2w, np.ndarray): 82 | c2w = torch.from_numpy(c2w).to(device) 83 | 84 | dirs = torch.stack( 85 | [(i-cx)/fx, -(j-cy)/fy, -torch.ones_like(i)], -1).to(device) 86 | dirs = dirs.reshape(-1, 1, 3) 87 | # Rotate ray directions from camera frame to the world frame 88 | # dot product, equals to: [c2w.dot(dir) for dir in dirs] 89 | rays_d = torch.sum(dirs * c2w[:3, :3], -1) 90 | rays_o = c2w[:3, -1].expand(rays_d.shape) 91 | return rays_o, rays_d 92 | 93 | 94 | def select_uv(i, j, n, depth, color, device='cuda:0'): 95 | """ 96 | Select n uv from dense uv. 
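A small sketch of how this helper and get_sample_uv (defined just below) draw random pixels (illustrative image size and sample count):
    H, W = 480, 640
    depth = torch.rand(H, W)
    color = torch.rand(H, W, 3)
    i, j, d, c = get_sample_uv(20, H - 20, 20, W - 20, 1000, depth, color, device='cpu')
    # i, j: (1000,) pixel coordinates; d: (1000,) depths; c: (1000, 3) colors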
97 | 98 | """ 99 | i = i.reshape(-1) 100 | j = j.reshape(-1) 101 | indices = torch.randint(i.shape[0], (n,), device=device) 102 | indices = indices.clamp(0, i.shape[0]) 103 | i = i[indices] # (n) 104 | j = j[indices] # (n) 105 | depth = depth.reshape(-1) 106 | color = color.reshape(-1, 3) 107 | depth = depth[indices] # (n) 108 | color = color[indices] # (n,3) 109 | return i, j, depth, color 110 | 111 | 112 | def get_sample_uv(H0, H1, W0, W1, n, depth, color, device='cuda:0'): 113 | """ 114 | Sample n uv coordinates from an image region H0..H1, W0..W1 115 | 116 | """ 117 | depth = depth[H0:H1, W0:W1] 118 | color = color[H0:H1, W0:W1] 119 | i, j = torch.meshgrid(torch.linspace( 120 | W0, W1-1, W1-W0).to(device), torch.linspace(H0, H1-1, H1-H0).to(device)) 121 | i = i.t() # transpose 122 | j = j.t() 123 | i, j, depth, color = select_uv(i, j, n, depth, color, device=device) 124 | return i, j, depth, color 125 | 126 | 127 | def get_samples(H0, H1, W0, W1, n, H, W, fx, fy, cx, cy, c2w, depth, color, device): 128 | """ 129 | Get n rays from the image region H0..H1, W0..W1. 130 | c2w is its camera pose and depth/color is the corresponding image tensor. 131 | 132 | """ 133 | i, j, sample_depth, sample_color = get_sample_uv( 134 | H0, H1, W0, W1, n, depth, color, device=device) 135 | rays_o, rays_d = get_rays_from_uv(i, j, c2w, H, W, fx, fy, cx, cy, device) 136 | return rays_o, rays_d, sample_depth, sample_color 137 | 138 | 139 | def quad2rotation(quad): 140 | """ 141 | Convert quaternion to rotation in batch. Since all operation in pytorch, support gradient passing. 142 | 143 | Args: 144 | quad (tensor, batch_size*4): quaternion. 145 | 146 | Returns: 147 | rot_mat (tensor, batch_size*3*3): rotation. 148 | """ 149 | bs = quad.shape[0] 150 | qr, qi, qj, qk = quad[:, 0], quad[:, 1], quad[:, 2], quad[:, 3] 151 | two_s = 2.0 / (quad * quad).sum(-1) 152 | rot_mat = torch.zeros(bs, 3, 3).to(quad.get_device()) 153 | rot_mat[:, 0, 0] = 1 - two_s * (qj ** 2 + qk ** 2) 154 | rot_mat[:, 0, 1] = two_s * (qi * qj - qk * qr) 155 | rot_mat[:, 0, 2] = two_s * (qi * qk + qj * qr) 156 | rot_mat[:, 1, 0] = two_s * (qi * qj + qk * qr) 157 | rot_mat[:, 1, 1] = 1 - two_s * (qi ** 2 + qk ** 2) 158 | rot_mat[:, 1, 2] = two_s * (qj * qk - qi * qr) 159 | rot_mat[:, 2, 0] = two_s * (qi * qk - qj * qr) 160 | rot_mat[:, 2, 1] = two_s * (qj * qk + qi * qr) 161 | rot_mat[:, 2, 2] = 1 - two_s * (qi ** 2 + qj ** 2) 162 | return rot_mat 163 | 164 | 165 | def get_camera_from_tensor(inputs): 166 | """ 167 | Convert quaternion and translation to transformation matrix. 168 | 169 | """ 170 | N = len(inputs.shape) 171 | if N == 1: 172 | inputs = inputs.unsqueeze(0) 173 | quad, T = inputs[:, :4], inputs[:, 4:] 174 | R = quad2rotation(quad) 175 | RT = torch.cat([R, T[:, :, None]], 2) 176 | if N == 1: 177 | RT = RT[0] 178 | return RT 179 | 180 | 181 | def get_tensor_from_camera(RT, Tquad=False): 182 | """ 183 | Convert transformation matrix to quaternion and translation. 
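A round-trip sketch with get_camera_from_tensor (assumes a CUDA device, which is how the tracker uses these helpers; the returned tensor lives on the CPU, so move it back before the reverse call):
    c2w = torch.eye(4).cuda()                         # identity pose
    cam_tensor = get_tensor_from_camera(c2w)          # tensor([1., 0., 0., 0., 0., 0., 0.]) on the CPU
    RT = get_camera_from_tensor(cam_tensor.cuda())    # (3, 4): recovers the top rows of c2w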
184 | 185 | """ 186 | gpu_id = -1 187 | if type(RT) == torch.Tensor: 188 | if RT.get_device() != -1: 189 | RT = RT.detach().cpu() 190 | gpu_id = RT.get_device() 191 | RT = RT.numpy() 192 | from mathutils import Matrix 193 | R, T = RT[:3, :3], RT[:3, 3] 194 | rot = Matrix(R) 195 | quad = rot.to_quaternion() 196 | if Tquad: 197 | tensor = np.concatenate([T, quad], 0) 198 | else: 199 | tensor = np.concatenate([quad, T], 0) 200 | tensor = torch.from_numpy(tensor).float() 201 | if gpu_id != -1: 202 | tensor = tensor.to(gpu_id) 203 | return tensor 204 | 205 | 206 | def raw2outputs_nerf_color(raw, z_vals, rays_d, occupancy=False, device='cuda:0'): 207 | """ 208 | Transforms model's predictions to semantically meaningful values. 209 | 210 | Args: 211 | raw (tensor, N_rays*N_samples*4): prediction from model. 212 | z_vals (tensor, N_rays*N_samples): integration time. 213 | rays_d (tensor, N_rays*3): direction of each ray. 214 | occupancy (bool, optional): occupancy or volume density. Defaults to False. 215 | device (str, optional): device. Defaults to 'cuda:0'. 216 | 217 | Returns: 218 | depth_map (tensor, N_rays): estimated distance to object. 219 | depth_var (tensor, N_rays): depth variance/uncertainty. 220 | rgb_map (tensor, N_rays*3): estimated RGB color of a ray. 221 | weights (tensor, N_rays*N_samples): weights assigned to each sampled color. 222 | """ 223 | 224 | def raw2alpha(raw, dists, act_fn=F.relu): return 1. - \ 225 | torch.exp(-act_fn(raw)*dists) 226 | dists = z_vals[..., 1:] - z_vals[..., :-1] 227 | dists = dists.float() 228 | dists = torch.cat([dists, torch.Tensor([1e10]).float().to( 229 | device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 230 | 231 | # different ray angle corresponds to different unit length 232 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 233 | rgb = raw[..., :-1] 234 | if occupancy: 235 | 236 | raw[..., 3] = torch.sigmoid(10 * raw[..., -1]) #sigmoid tanh -1,1 # when occ do belong to 0 - 1 237 | alpha = raw[..., -1] 238 | 239 | alpha_theta = 0 240 | else: 241 | # original nerf, volume density 242 | alpha = raw2alpha(raw[..., -1], dists) # (N_rays, N_samples) 243 | 244 | weights = alpha.float() * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)).to( 245 | device).float(), (1.-alpha + 1e-10).float()], -1).float(), -1)[:, :-1] 246 | 247 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # (N_rays, 3) 248 | depth_map = torch.sum(weights * z_vals, -1) # (N_rays) 249 | tmp = (z_vals-depth_map.unsqueeze(-1)) # (N_rays, N_samples) 250 | depth_var = torch.sum(weights*tmp*tmp, dim=1) # (N_rays) 251 | return depth_map, depth_var, rgb_map, weights 252 | 253 | 254 | def get_rays(H, W, fx, fy, cx, cy, c2w, device): 255 | """ 256 | Get rays for a whole image. 257 | 258 | """ 259 | if isinstance(c2w, np.ndarray): 260 | c2w = torch.from_numpy(c2w) 261 | # pytorch's meshgrid has indexing='ij' 262 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) 263 | i = i.t() # transpose 264 | j = j.t() 265 | dirs = torch.stack( 266 | [(i-cx)/fx, -(j-cy)/fy, -torch.ones_like(i)], -1).to(device) 267 | dirs = dirs.reshape(H, W, 1, 3) 268 | # Rotate ray directions from camera frame to the world frame 269 | # dot product, equals to: [c2w.dot(dir) for dir in dirs] 270 | rays_d = torch.sum(dirs * c2w[:3, :3], -1) 271 | rays_o = c2w[:3, -1].expand(rays_d.shape) 272 | return rays_o, rays_d 273 | 274 | 275 | def normalize_3d_coordinate(p, bound): 276 | """ 277 | Normalize coordinate to [-1, 1], corresponds to the bounding box given. 
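A quick numeric sketch (the function writes into p, hence the clone() used by its callers):
    bound = torch.tensor([[-4., 4.], [-3., 3.], [-2., 2.]])
    p = torch.tensor([[-4., 0., 2.], [0., -3., -2.]])
    normalize_3d_coordinate(p.clone(), bound)
    # tensor([[-1.,  0.,  1.],
    #         [ 0., -1., -1.]])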
278 | 279 | Args: 280 | p (tensor, N*3): coordinate. 281 | bound (tensor, 3*2): the scene bound. 282 | 283 | Returns: 284 | p (tensor, N*3): normalized coordinate. 285 | """ 286 | p = p.reshape(-1, 3) 287 | p[:, 0] = ((p[:, 0]-bound[0, 0])/(bound[0, 1]-bound[0, 0]))*2-1.0 288 | p[:, 1] = ((p[:, 1]-bound[1, 0])/(bound[1, 1]-bound[1, 0]))*2-1.0 289 | p[:, 2] = ((p[:, 2]-bound[2, 0])/(bound[2, 1]-bound[2, 0]))*2-1.0 290 | return p 291 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from src import conv_onet 3 | 4 | 5 | method_dict = { 6 | 'conv_onet': conv_onet 7 | } 8 | 9 | 10 | def load_config(path, default_path=None): 11 | """ 12 | Loads config file. 13 | 14 | Args: 15 | path (str): path to config file. 16 | default_path (str, optional): whether to use default path. Defaults to None. 17 | 18 | Returns: 19 | cfg (dict): config dict. 20 | 21 | """ 22 | # load configuration from file itself 23 | with open(path, 'r') as f: 24 | cfg_special = yaml.full_load(f) 25 | 26 | # check if we should inherit from a config 27 | inherit_from = cfg_special.get('inherit_from') 28 | 29 | # if yes, load this config first as default 30 | # if no, use the default_path 31 | if inherit_from is not None: 32 | cfg = load_config(inherit_from, default_path) 33 | elif default_path is not None: 34 | with open(default_path, 'r') as f: 35 | cfg = yaml.full_load(f) 36 | else: 37 | cfg = dict() 38 | 39 | # include main configuration 40 | update_recursive(cfg, cfg_special) 41 | 42 | return cfg 43 | 44 | 45 | def update_recursive(dict1, dict2): 46 | """ 47 | Update two config dictionaries recursively. 48 | 49 | Args: 50 | dict1 (dict): first dictionary to be updated. 51 | dict2 (dict): second dictionary which entries should be used. 52 | """ 53 | for k, v in dict2.items(): 54 | if k not in dict1: 55 | dict1[k] = dict() 56 | if isinstance(v, dict): 57 | update_recursive(dict1[k], v) 58 | else: 59 | dict1[k] = v 60 | 61 | 62 | # Models 63 | def get_model(cfg): 64 | """ 65 | Returns the model instance. 66 | 67 | Args: 68 | cfg (dict): config dictionary. 69 | 70 | 71 | Returns: 72 | model (nn.module): network model. 
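A usage sketch with the config files shipped in this repository (whether room0.yaml inherits the shared Replica defaults is an assumption here):
    cfg = load_config('configs/Replica/room0.yaml', 'configs/df_prior.yaml')
    decoder = get_model(cfg)    # builds the DF decoder via conv_onet.config.get_model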
73 | """ 74 | 75 | method = 'conv_onet' 76 | model = method_dict[method].config.get_model(cfg) 77 | 78 | return model 79 | -------------------------------------------------------------------------------- /src/conv_onet/__init__.py: -------------------------------------------------------------------------------- 1 | from src.conv_onet import ( 2 | config, models 3 | ) 4 | 5 | __all__ = [ 6 | config, models 7 | ] 8 | -------------------------------------------------------------------------------- /src/conv_onet/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/conv_onet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/conv_onet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/conv_onet/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /src/conv_onet/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /src/conv_onet/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /src/conv_onet/config.py: -------------------------------------------------------------------------------- 1 | from src.conv_onet import models 2 | 3 | 4 | def get_model(cfg): 5 | """ 6 | Return the network model. 7 | 8 | Args: 9 | cfg (dict): imported yaml config. 10 | 11 | Returns: 12 | decoder (nn.module): the network model. 
13 | """ 14 | 15 | dim = cfg['data']['dim'] 16 | low_grid_len = cfg['grid_len']['low'] 17 | high_grid_len = cfg['grid_len']['high'] 18 | color_grid_len = cfg['grid_len']['color'] 19 | c_dim = cfg['model']['c_dim'] # feature dimensions 20 | pos_embedding_method = cfg['model']['pos_embedding_method'] 21 | 22 | decoder = models.decoder_dict['dfprior']( 23 | dim=dim, c_dim=c_dim, 24 | low_grid_len=low_grid_len, high_grid_len=high_grid_len, 25 | color_grid_len=color_grid_len, pos_embedding_method=pos_embedding_method) 26 | 27 | return decoder 28 | -------------------------------------------------------------------------------- /src/conv_onet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from src.conv_onet.models import decoder 2 | 3 | # Decoder dictionary 4 | decoder_dict = { 5 | 'dfprior': decoder.DF 6 | } -------------------------------------------------------------------------------- /src/conv_onet/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/conv_onet/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/conv_onet/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/conv_onet/models/__pycache__/decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/decoder.cpython-310.pyc -------------------------------------------------------------------------------- /src/conv_onet/models/__pycache__/decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/decoder.cpython-37.pyc -------------------------------------------------------------------------------- /src/conv_onet/models/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/conv_onet/models/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /src/conv_onet/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from src.common import normalize_3d_coordinate 5 | 6 
| 7 | class GaussianFourierFeatureTransform(torch.nn.Module): 8 | """ 9 | Modified based on the implementation of Gaussian Fourier feature mapping. 10 | 11 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": 12 | https://arxiv.org/abs/2006.10739 13 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 14 | 15 | """ 16 | 17 | def __init__(self, num_input_channels, mapping_size=93, scale=25, learnable=True): 18 | super().__init__() 19 | 20 | if learnable: 21 | self._B = nn.Parameter(torch.randn( 22 | (num_input_channels, mapping_size)) * scale) 23 | else: 24 | self._B = torch.randn((num_input_channels, mapping_size)) * scale 25 | 26 | def forward(self, x): 27 | x = x.squeeze(0) 28 | assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(x.dim()) 29 | x = x @ self._B.to(x.device) 30 | return torch.sin(x) 31 | 32 | 33 | class Nerf_positional_embedding(torch.nn.Module): 34 | """ 35 | Nerf positional embedding. 36 | 37 | """ 38 | 39 | def __init__(self, multires, log_sampling=True): 40 | super().__init__() 41 | self.log_sampling = log_sampling 42 | self.include_input = True 43 | self.periodic_fns = [torch.sin, torch.cos] 44 | self.max_freq_log2 = multires-1 45 | self.num_freqs = multires 46 | self.max_freq = self.max_freq_log2 47 | self.N_freqs = self.num_freqs 48 | 49 | def forward(self, x): 50 | x = x.squeeze(0) 51 | assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format( 52 | x.dim()) 53 | 54 | if self.log_sampling: 55 | freq_bands = 2.**torch.linspace(0., 56 | self.max_freq, steps=self.N_freqs) 57 | else: 58 | freq_bands = torch.linspace( 59 | 2.**0., 2.**self.max_freq, steps=self.N_freqs) 60 | output = [] 61 | if self.include_input: 62 | output.append(x) 63 | for freq in freq_bands: 64 | for p_fn in self.periodic_fns: 65 | output.append(p_fn(x * freq)) 66 | ret = torch.cat(output, dim=1) 67 | return ret 68 | 69 | 70 | class DenseLayer(nn.Linear): 71 | def __init__(self, in_dim: int, out_dim: int, activation: str = "relu", *args, **kwargs) -> None: 72 | self.activation = activation 73 | super().__init__(in_dim, out_dim, *args, **kwargs) 74 | 75 | def reset_parameters(self) -> None: 76 | torch.nn.init.xavier_uniform_( 77 | self.weight, gain=torch.nn.init.calculate_gain(self.activation)) 78 | if self.bias is not None: 79 | torch.nn.init.zeros_(self.bias) 80 | 81 | 82 | class Same(nn.Module): 83 | def __init__(self): 84 | super().__init__() 85 | 86 | def forward(self, x): 87 | x = x.squeeze(0) 88 | return x 89 | 90 | 91 | class MLP(nn.Module): 92 | """ 93 | Decoder. Point coordinates not only used in sampling the feature grids, but also as MLP input. 94 | 95 | Args: 96 | name (str): name of this decoder. 97 | dim (int): input dimension. 98 | c_dim (int): feature dimension. 99 | hidden_size (int): hidden size of Decoder network. 100 | n_blocks (int): number of layers. 101 | leaky (bool): whether to use leaky ReLUs. 102 | sample_mode (str): sampling feature strategy, bilinear|nearest. 103 | color (bool): whether or not to output color. 104 | skips (list): list of layers to have skip connections. 105 | grid_len (float): voxel length of its corresponding feature grid. 106 | pos_embedding_method (str): positional embedding method. 107 | concat_feature (bool): whether to get feature from low level and concat to the current feature. 
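A shape-level sketch of a standalone forward pass (illustrative bound and feature grid; in the full pipeline the bound is attached to the decoder and the grid comes from the mapper):
    low_decoder = MLP(name='low', dim=3, c_dim=32, hidden_size=32, grid_len=0.32)
    low_decoder.bound = torch.tensor([[-4., 4.], [-3., 3.], [-2., 2.]])  # normally attached by the SLAM pipeline
    c_grid = {'grid_low': torch.zeros(1, 32, 16, 16, 16)}                # (1, c_dim, D, H, W) feature grid
    p = torch.rand(1, 1024, 3) * 2. - 1.                                 # (1, N, 3) query points inside the bound
    occ = low_decoder(p, c_grid)                                         # (1024,) occupancy logits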
108 | """ 109 | 110 | def __init__(self, name='', dim=3, c_dim=128, 111 | hidden_size=256, n_blocks=5, leaky=False, sample_mode='bilinear', 112 | color=False, skips=[2], grid_len=0.16, pos_embedding_method='fourier', concat_feature=False): 113 | super().__init__() 114 | self.name = name 115 | self.color = color 116 | self.no_grad_feature = False 117 | self.c_dim = c_dim 118 | self.grid_len = grid_len 119 | self.concat_feature = concat_feature 120 | self.n_blocks = n_blocks 121 | self.skips = skips 122 | 123 | if c_dim != 0: 124 | self.fc_c = nn.ModuleList([ 125 | nn.Linear(c_dim, hidden_size) for i in range(n_blocks) 126 | ]) 127 | 128 | if pos_embedding_method == 'fourier': 129 | embedding_size = 93 130 | self.embedder = GaussianFourierFeatureTransform( 131 | dim, mapping_size=embedding_size, scale=25) 132 | elif pos_embedding_method == 'same': 133 | embedding_size = 3 134 | self.embedder = Same() 135 | elif pos_embedding_method == 'nerf': 136 | if 'color' in name: 137 | multires = 10 138 | self.embedder = Nerf_positional_embedding( 139 | multires, log_sampling=True) 140 | else: 141 | multires = 5 142 | self.embedder = Nerf_positional_embedding( 143 | multires, log_sampling=False) 144 | embedding_size = multires*6+3 145 | elif pos_embedding_method == 'fc_relu': 146 | embedding_size = 93 147 | self.embedder = DenseLayer(dim, embedding_size, activation='relu') 148 | 149 | self.pts_linears = nn.ModuleList( 150 | [DenseLayer(embedding_size, hidden_size, activation="relu")] + 151 | [DenseLayer(hidden_size, hidden_size, activation="relu") if i not in self.skips 152 | else DenseLayer(hidden_size + embedding_size, hidden_size, activation="relu") for i in range(n_blocks-1)]) 153 | 154 | if self.color: 155 | self.output_linear = DenseLayer( 156 | hidden_size, 4, activation="linear") 157 | else: 158 | self.output_linear = DenseLayer( 159 | hidden_size, 1, activation="linear") 160 | 161 | if not leaky: 162 | self.actvn = F.relu 163 | else: 164 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 165 | 166 | self.sample_mode = sample_mode 167 | 168 | def sample_grid_feature(self, p, c): 169 | p_nor = normalize_3d_coordinate(p.clone(), self.bound) 170 | p_nor = p_nor.unsqueeze(0) 171 | vgrid = p_nor[:, :, None, None].float() 172 | # acutally trilinear interpolation if mode = 'bilinear' 173 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, 174 | mode=self.sample_mode).squeeze(-1).squeeze(-1) 175 | return c 176 | 177 | def forward(self, p, c_grid=None): 178 | if self.c_dim != 0: 179 | c = self.sample_grid_feature( 180 | p, c_grid['grid_' + self.name]).transpose(1, 2).squeeze(0) 181 | 182 | if self.concat_feature: 183 | # only happen to high decoder, get feature from low level and concat to the current feature 184 | with torch.no_grad(): 185 | c_low = self.sample_grid_feature( 186 | p, c_grid['grid_low']).transpose(1, 2).squeeze(0) 187 | c = torch.cat([c, c_low], dim=1) 188 | 189 | p = p.float() 190 | 191 | embedded_pts = self.embedder(p) 192 | h = embedded_pts 193 | for i, l in enumerate(self.pts_linears): 194 | h = self.pts_linears[i](h) 195 | h = F.relu(h) 196 | if self.c_dim != 0: 197 | h = h + self.fc_c[i](c) 198 | if i in self.skips: 199 | h = torch.cat([embedded_pts, h], -1) 200 | out = self.output_linear(h) 201 | if not self.color: 202 | out = out.squeeze(-1) 203 | return out 204 | 205 | 206 | class mlp_tsdf(nn.Module): 207 | """ 208 | Attention-based MLP. 
209 | 210 | """ 211 | 212 | def __init__(self): 213 | super().__init__() 214 | 215 | self.no_grad_feature = False 216 | self.sample_mode = 'bilinear' 217 | 218 | self.pts_linears = nn.ModuleList( 219 | [DenseLayer(2, 64, activation="relu")] + 220 | [DenseLayer(64, 128, activation="relu")] + 221 | [DenseLayer(128, 128, activation="relu")] + 222 | [DenseLayer(128, 64, activation="relu")]) 223 | 224 | self.output_linear = DenseLayer( 225 | 64, 2, activation="linear") #linear 226 | 227 | self.softmax = nn.Softmax(dim=1) 228 | self.sigmoid = nn.Sigmoid() 229 | 230 | def sample_grid_tsdf(self, p, tsdf_volume, tsdf_bnds, device='cuda:0'): 231 | p_nor = normalize_3d_coordinate(p.clone(), tsdf_bnds) 232 | p_nor = p_nor.unsqueeze(0) 233 | vgrid = p_nor[:, :, None, None].float() 234 | # acutally trilinear interpolation if mode = 'bilinear' 235 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True, 236 | mode='bilinear').squeeze(-1).squeeze(-1) 237 | 238 | return tsdf_value 239 | 240 | def forward(self, p, occ, tsdf_volume, tsdf_bnds, **kwargs): 241 | tsdf_val = self.sample_grid_tsdf(p, tsdf_volume, tsdf_bnds, device='cuda:0') 242 | tsdf_val = tsdf_val.squeeze(0) 243 | 244 | tsdf_val = 1. - (tsdf_val + 1.) / 2. #0,1 245 | tsdf_val = torch.clamp(tsdf_val, 0.0, 1.0) 246 | occ = occ.reshape(tsdf_val.shape) 247 | inv_tsdf = -0.1 * torch.log((1 / (tsdf_val + 1e-8)) - 1 + 1e-7) #0.1 248 | inv_tsdf = torch.clamp(inv_tsdf, -100.0, 100.0) 249 | input = torch.cat([occ, inv_tsdf], dim=0) 250 | h = input.t() 251 | for i, l in enumerate(self.pts_linears): 252 | h = self.pts_linears[i](h) 253 | h = F.relu(h) 254 | weight = self.output_linear(h) 255 | weight = self.softmax(weight) 256 | out = weight.mul(input.t()).sum(dim=1) 257 | 258 | return out, weight[:, 1] 259 | 260 | 261 | 262 | class DF(nn.Module): 263 | """ 264 | Neural Implicit Scalable Encoding. 265 | 266 | Args: 267 | dim (int): input dimension. 268 | c_dim (int): feature dimension. 269 | low_grid_len (float): voxel length in low grid. 270 | high_grid_len (float): voxel length in high grid. 271 | color_grid_len (float): voxel length in color grid. 272 | hidden_size (int): hidden size of decoder network 273 | pos_embedding_method (str): positional embedding method. 
274 | """ 275 | 276 | def __init__(self, dim=3, c_dim=32, 277 | low_grid_len=0.16, high_grid_len=0.16, 278 | color_grid_len=0.16, hidden_size=32, pos_embedding_method='fourier'): 279 | super().__init__() 280 | 281 | 282 | self.low_decoder = MLP(name='low', dim=dim, c_dim=c_dim, color=False, 283 | skips=[2], n_blocks=5, hidden_size=hidden_size, 284 | grid_len=low_grid_len, pos_embedding_method=pos_embedding_method) 285 | self.high_decoder = MLP(name='high', dim=dim, c_dim=c_dim*2, color=False, 286 | skips=[2], n_blocks=5, hidden_size=hidden_size, 287 | grid_len=high_grid_len, concat_feature=True, pos_embedding_method=pos_embedding_method) 288 | self.color_decoder = MLP(name='color', dim=dim, c_dim=c_dim, color=True, 289 | skips=[2], n_blocks=5, hidden_size=hidden_size, 290 | grid_len=color_grid_len, pos_embedding_method=pos_embedding_method) 291 | 292 | self.mlp = mlp_tsdf() 293 | 294 | 295 | def sample_grid_tsdf(self, p, tsdf_volume, tsdf_bnds, device='cuda:0'): 296 | 297 | p_nor = normalize_3d_coordinate(p.clone(), tsdf_bnds) 298 | p_nor = p_nor.unsqueeze(0) 299 | vgrid = p_nor[:, :, None, None].float() 300 | # acutally trilinear interpolation if mode = 'bilinear' 301 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True, 302 | mode='bilinear').squeeze(-1).squeeze(-1) 303 | return tsdf_value 304 | 305 | 306 | 307 | def forward(self, p, c_grid, tsdf_volume, tsdf_bnds, stage='low', **kwargs): 308 | """ 309 | Output occupancy/color in different stage. 310 | """ 311 | 312 | device = f'cuda:{p.get_device()}' 313 | if stage == 'low': 314 | low_occ = self.low_decoder(p, c_grid) 315 | low_occ = low_occ.squeeze(0) 316 | 317 | w = torch.ones(low_occ.shape[0]).to(device) 318 | raw = torch.zeros(low_occ.shape[0], 4).to(device).float() 319 | raw[..., -1] = low_occ # new_occ 320 | return raw, w 321 | elif stage == 'high': 322 | high_occ = self.high_decoder(p, c_grid) 323 | raw = torch.zeros(high_occ.shape[0], 4).to(device).float() 324 | low_occ = self.low_decoder(p, c_grid) 325 | low_occ = low_occ.squeeze(0) 326 | f_add_m_occ = high_occ + low_occ 327 | 328 | eval_tsdf = self.sample_grid_tsdf(p, tsdf_volume, tsdf_bnds, device) 329 | eval_tsdf_mask = ((eval_tsdf > -1.0+1e-4) & (eval_tsdf < 1.0-1e-4)) 330 | eval_tsdf_mask = eval_tsdf_mask.squeeze() 331 | 332 | w = torch.ones(low_occ.shape[0]).to(device) 333 | low_occ[eval_tsdf_mask], w[eval_tsdf_mask] = self.mlp(p[:, eval_tsdf_mask, :], f_add_m_occ[eval_tsdf_mask], tsdf_volume, tsdf_bnds) 334 | new_occ = low_occ 335 | new_occ = new_occ.squeeze(-1) 336 | raw[..., -1] = new_occ 337 | return raw, w 338 | elif stage == 'color': 339 | high_occ = self.high_decoder(p, c_grid) 340 | raw = self.color_decoder(p, c_grid) 341 | low_occ = self.low_decoder(p, c_grid) 342 | low_occ = low_occ.squeeze(0) 343 | f_add_m_occ = high_occ + low_occ 344 | 345 | eval_tsdf = self.sample_grid_tsdf(p, tsdf_volume, tsdf_bnds, device) 346 | eval_tsdf_mask = ((eval_tsdf > -1.0+1e-4) & (eval_tsdf < 1.0-1e-4)) 347 | eval_tsdf_mask = eval_tsdf_mask.squeeze() 348 | 349 | w = torch.ones(low_occ.shape[0]).to(device) 350 | low_occ[eval_tsdf_mask], w[eval_tsdf_mask] = self.mlp(p[:, eval_tsdf_mask, :], f_add_m_occ[eval_tsdf_mask], tsdf_volume, tsdf_bnds) 351 | new_occ = low_occ 352 | raw[..., -1] = new_occ 353 | return raw, w 354 | -------------------------------------------------------------------------------- /src/fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Andy 
Zeng 2 | 3 | import numpy as np 4 | 5 | from numba import njit, prange 6 | from skimage import measure 7 | 8 | 9 | try: 10 | import pycuda.driver as cuda 11 | import pycuda.autoinit 12 | from pycuda.compiler import SourceModule 13 | FUSION_GPU_MODE = 1 14 | except Exception as err: 15 | print('Warning: {}'.format(err)) 16 | print('Failed to import PyCUDA. Running fusion in CPU mode.') 17 | FUSION_GPU_MODE = 0 18 | 19 | # FUSION_GPU_MODE = 0 20 | 21 | class TSDFVolume: 22 | """Volumetric TSDF Fusion of RGB-D Images. 23 | """ 24 | def __init__(self, vol_bnds, voxel_size, use_gpu=True): 25 | """Constructor. 26 | 27 | Args: 28 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the 29 | xyz bounds (min/max) in meters. 30 | voxel_size (float): The volume discretization in meters. 31 | """ 32 | vol_bnds = np.asarray(vol_bnds) 33 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)." 34 | 35 | # Define voxel volume parameters 36 | self._vol_bnds = vol_bnds 37 | self._voxel_size = float(voxel_size) 38 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF 5 * 39 | self._color_const = 256 * 256 40 | 41 | # Adjust volume bounds and ensure C-order contiguous 42 | self._vol_dim = np.ceil((self._vol_bnds[:,1]-self._vol_bnds[:,0])/self._voxel_size).copy(order='C').astype(int) 43 | self._vol_bnds[:,1] = self._vol_bnds[:,0]+self._vol_dim*self._voxel_size 44 | self._vol_origin = self._vol_bnds[:,0].copy(order='C').astype(np.float32) 45 | 46 | print("Voxel volume size: {} x {} x {} - # points: {:,}".format( 47 | self._vol_dim[0], self._vol_dim[1], self._vol_dim[2], 48 | self._vol_dim[0]*self._vol_dim[1]*self._vol_dim[2]) 49 | ) 50 | 51 | # Initialize pointers to voxel volume in CPU memory 52 | self._tsdf_vol_cpu = -1 * np.ones(self._vol_dim).astype(np.float32) #ones 53 | # for computing the cumulative moving average of observations per voxel 54 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 55 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 56 | 57 | self.gpu_mode = use_gpu and FUSION_GPU_MODE 58 | 59 | # Copy voxel volumes to GPU 60 | if self.gpu_mode: 61 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes) 62 | cuda.memcpy_htod(self._tsdf_vol_gpu,self._tsdf_vol_cpu) 63 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes) 64 | cuda.memcpy_htod(self._weight_vol_gpu,self._weight_vol_cpu) 65 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes) 66 | cuda.memcpy_htod(self._color_vol_gpu,self._color_vol_cpu) 67 | 68 | # Cuda kernel function (C++) 69 | self._cuda_src_mod = SourceModule(""" 70 | __global__ void integrate(float * tsdf_vol, 71 | float * weight_vol, 72 | float * color_vol, 73 | float * vol_dim, 74 | float * vol_origin, 75 | float * cam_intr, 76 | float * cam_pose, 77 | float * other_params, 78 | float * color_im, 79 | float * depth_im) { 80 | // Get voxel index 81 | int gpu_loop_idx = (int) other_params[0]; 82 | int max_threads_per_block = blockDim.x; 83 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x; 84 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x; 85 | int vol_dim_x = (int) vol_dim[0]; 86 | int vol_dim_y = (int) vol_dim[1]; 87 | int vol_dim_z = (int) vol_dim[2]; 88 | if (voxel_idx > vol_dim_x*vol_dim_y*vol_dim_z) 89 | return; 90 | // Get voxel grid coordinates (note: be careful when casting) 91 | float voxel_x = 
floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z))); 92 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z)); 93 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z); 94 | // Voxel grid coordinates to world coordinates 95 | float voxel_size = other_params[1]; 96 | float pt_x = vol_origin[0]+voxel_x*voxel_size; 97 | float pt_y = vol_origin[1]+voxel_y*voxel_size; 98 | float pt_z = vol_origin[2]+voxel_z*voxel_size; 99 | // World coordinates to camera coordinates 100 | float tmp_pt_x = pt_x-cam_pose[0*4+3]; 101 | float tmp_pt_y = pt_y-cam_pose[1*4+3]; 102 | float tmp_pt_z = pt_z-cam_pose[2*4+3]; 103 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z; 104 | float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z; 105 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z; 106 | // Camera coordinates to image pixels 107 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]); 108 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]); 109 | // Skip if outside view frustum 110 | int im_h = (int) other_params[2]; 111 | int im_w = (int) other_params[3]; 112 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0) 113 | return; 114 | // Skip invalid depth 115 | float depth_value = depth_im[pixel_y*im_w+pixel_x]; 116 | if (depth_value == 0) 117 | return; 118 | // Integrate TSDF 119 | float trunc_margin = other_params[4]; 120 | float depth_diff = depth_value-cam_pt_z; 121 | if (depth_diff < -trunc_margin) 122 | return; 123 | float dist = fmin(1.0f,depth_diff/trunc_margin); 124 | float w_old = weight_vol[voxel_idx]; 125 | float obs_weight = other_params[5]; 126 | float w_new = w_old + obs_weight; 127 | weight_vol[voxel_idx] = w_new; 128 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new; 129 | // Integrate color 130 | float old_color = color_vol[voxel_idx]; 131 | float old_b = floorf(old_color/(256*256)); 132 | float old_g = floorf((old_color-old_b*256*256)/256); 133 | float old_r = old_color-old_b*256*256-old_g*256; 134 | float new_color = color_im[pixel_y*im_w+pixel_x]; 135 | float new_b = floorf(new_color/(256*256)); 136 | float new_g = floorf((new_color-new_b*256*256)/256); 137 | float new_r = new_color-new_b*256*256-new_g*256; 138 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f); 139 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f); 140 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f); 141 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r; 142 | }""") 143 | 144 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate") 145 | 146 | # Determine block/grid size on GPU 147 | gpu_dev = cuda.Device(0) 148 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK 149 | n_blocks = int(np.ceil(float(np.prod(self._vol_dim))/float(self._max_gpu_threads_per_block))) 150 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X,int(np.floor(np.cbrt(n_blocks)))) 151 | grid_dim_y = min(gpu_dev.MAX_GRID_DIM_Y,int(np.floor(np.sqrt(n_blocks/grid_dim_x)))) 152 | grid_dim_z = min(gpu_dev.MAX_GRID_DIM_Z,int(np.ceil(float(n_blocks)/float(grid_dim_x*grid_dim_y)))) 153 | self._max_gpu_grid_dim = np.array([grid_dim_x,grid_dim_y,grid_dim_z]).astype(int) 154 | self._n_gpu_loops = 
int(np.ceil(float(np.prod(self._vol_dim))/float(np.prod(self._max_gpu_grid_dim)*self._max_gpu_threads_per_block))) 155 | 156 | else: 157 | # Get voxel grid coordinates 158 | xv, yv, zv = np.meshgrid( 159 | range(self._vol_dim[0]), 160 | range(self._vol_dim[1]), 161 | range(self._vol_dim[2]), 162 | indexing='ij' 163 | ) 164 | self.vox_coords = np.concatenate([ 165 | xv.reshape(1,-1), 166 | yv.reshape(1,-1), 167 | zv.reshape(1,-1) 168 | ], axis=0).astype(int).T 169 | 170 | @staticmethod 171 | @njit(parallel=True) 172 | def vox2world(vol_origin, vox_coords, vox_size): 173 | """Convert voxel grid coordinates to world coordinates. 174 | """ 175 | vol_origin = vol_origin.astype(np.float32) 176 | vox_coords = vox_coords.astype(np.float32) 177 | cam_pts = np.empty_like(vox_coords, dtype=np.float32) 178 | for i in prange(vox_coords.shape[0]): 179 | for j in range(3): 180 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j]) 181 | return cam_pts 182 | 183 | @staticmethod 184 | @njit(parallel=True) 185 | def cam2pix(cam_pts, intr): 186 | """Convert camera coordinates to pixel coordinates. 187 | """ 188 | intr = intr.astype(np.float32) 189 | fx, fy = intr[0, 0], intr[1, 1] 190 | cx, cy = intr[0, 2], intr[1, 2] 191 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64) 192 | for i in prange(cam_pts.shape[0]): 193 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx)) 194 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy)) 195 | return pix 196 | 197 | @staticmethod 198 | @njit(parallel=True) 199 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight): 200 | """Integrate the TSDF volume. 201 | """ 202 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32) 203 | w_new = np.empty_like(w_old, dtype=np.float32) 204 | for i in prange(len(tsdf_vol)): 205 | w_new[i] = w_old[i] + obs_weight 206 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i] 207 | return tsdf_vol_int, w_new 208 | 209 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.): 210 | """Integrate an RGB-D frame into the TSDF volume. 211 | 212 | Args: 213 | color_im (ndarray): An RGB image of shape (H, W, 3). 214 | depth_im (ndarray): A depth image of shape (H, W). 215 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3). 216 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4). 217 | obs_weight (float): The weight to assign for the current observation. 
A higher 218 | value 219 | """ 220 | im_h, im_w = depth_im.shape 221 | 222 | # Fold RGB color image into a single channel image 223 | color_im = color_im.astype(np.float32) 224 | color_im = np.floor(color_im[...,2]*self._color_const + color_im[...,1]*256 + color_im[...,0]) 225 | 226 | if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel) 227 | for gpu_loop_idx in range(self._n_gpu_loops): 228 | self._cuda_integrate(self._tsdf_vol_gpu, 229 | self._weight_vol_gpu, 230 | self._color_vol_gpu, 231 | cuda.InOut(self._vol_dim.astype(np.float32)), 232 | cuda.InOut(self._vol_origin.astype(np.float32)), 233 | cuda.InOut(cam_intr.reshape(-1).astype(np.float32)), 234 | cuda.InOut(cam_pose.reshape(-1).astype(np.float32)), 235 | cuda.InOut(np.asarray([ 236 | gpu_loop_idx, 237 | self._voxel_size, 238 | im_h, 239 | im_w, 240 | self._trunc_margin, 241 | obs_weight 242 | ], np.float32)), 243 | cuda.InOut(color_im.reshape(-1).astype(np.float32)), 244 | cuda.InOut(depth_im.reshape(-1).astype(np.float32)), 245 | block=(self._max_gpu_threads_per_block,1,1), 246 | grid=( 247 | int(self._max_gpu_grid_dim[0]), 248 | int(self._max_gpu_grid_dim[1]), 249 | int(self._max_gpu_grid_dim[2]), 250 | ) 251 | ) 252 | else: # CPU mode: integrate voxel volume (vectorized implementation) 253 | # Convert voxel grid coordinates to pixel coordinates 254 | cam_pts = self.vox2world(self._vol_origin, self.vox_coords, self._voxel_size) 255 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose)) 256 | pix_z = cam_pts[:, 2] 257 | pix = self.cam2pix(cam_pts, cam_intr) 258 | pix_x, pix_y = pix[:, 0], pix[:, 1] 259 | 260 | # Eliminate pixels outside view frustum 261 | valid_pix = np.logical_and(pix_x >= 0, 262 | np.logical_and(pix_x < im_w, 263 | np.logical_and(pix_y >= 0, 264 | np.logical_and(pix_y < im_h, 265 | pix_z > 0)))) 266 | depth_val = np.zeros(pix_x.shape) 267 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]] 268 | 269 | # Integrate TSDF 270 | depth_diff = depth_val - pix_z 271 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin) 272 | dist = np.minimum(1, depth_diff / self._trunc_margin) 273 | valid_vox_x = self.vox_coords[valid_pts, 0] 274 | valid_vox_y = self.vox_coords[valid_pts, 1] 275 | valid_vox_z = self.vox_coords[valid_pts, 2] 276 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 277 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 278 | valid_dist = dist[valid_pts] 279 | tsdf_vol_new, w_new = self.integrate_tsdf(tsdf_vals, valid_dist, w_old, obs_weight) 280 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new 281 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new 282 | 283 | # Integrate color 284 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 285 | old_b = np.floor(old_color / self._color_const) 286 | old_g = np.floor((old_color-old_b*self._color_const)/256) 287 | old_r = old_color - old_b*self._color_const - old_g*256 288 | new_color = color_im[pix_y[valid_pts],pix_x[valid_pts]] 289 | new_b = np.floor(new_color / self._color_const) 290 | new_g = np.floor((new_color - new_b*self._color_const) /256) 291 | new_r = new_color - new_b*self._color_const - new_g*256 292 | new_b = np.minimum(255., np.round((w_old*old_b + obs_weight*new_b) / w_new)) 293 | new_g = np.minimum(255., np.round((w_old*old_g + obs_weight*new_g) / w_new)) 294 | new_r = np.minimum(255., np.round((w_old*old_r + obs_weight*new_r) / w_new)) 295 | self._color_vol_cpu[valid_vox_x, 
valid_vox_y, valid_vox_z] = new_b*self._color_const + new_g*256 + new_r 296 | 297 | def get_volume(self): 298 | if self.gpu_mode: 299 | cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu) 300 | cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu) 301 | return self._tsdf_vol_cpu, self._color_vol_cpu, self._vol_bnds 302 | 303 | def get_point_cloud(self): 304 | """Extract a point cloud from the voxel volume. 305 | """ 306 | tsdf_vol, color_vol = self.get_volume() 307 | 308 | # Marching cubes 309 | verts = measure.marching_cubes(tsdf_vol, level=0)[0] 310 | verts_ind = np.round(verts).astype(int) 311 | verts = verts*self._voxel_size + self._vol_origin 312 | 313 | # Get vertex colors 314 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 315 | colors_b = np.floor(rgb_vals / self._color_const) 316 | colors_g = np.floor((rgb_vals - colors_b*self._color_const) / 256) 317 | colors_r = rgb_vals - colors_b*self._color_const - colors_g*256 318 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 319 | colors = colors.astype(np.uint8) 320 | 321 | pc = np.hstack([verts, colors]) 322 | return pc 323 | 324 | def get_mesh(self): 325 | """Compute a mesh from the voxel volume using marching cubes. 326 | """ 327 | tsdf_vol, color_vol, bnds = self.get_volume() 328 | 329 | # Marching cubes 330 | verts, faces, norms, vals = measure.marching_cubes(tsdf_vol, level=0) 331 | verts_ind = np.round(verts).astype(int) 332 | verts = verts*self._voxel_size+self._vol_origin # voxel grid coordinates to world coordinates 333 | 334 | # Get vertex colors 335 | rgb_vals = color_vol[verts_ind[:,0], verts_ind[:,1], verts_ind[:,2]] 336 | colors_b = np.floor(rgb_vals/self._color_const) 337 | colors_g = np.floor((rgb_vals-colors_b*self._color_const)/256) 338 | colors_r = rgb_vals-colors_b*self._color_const-colors_g*256 339 | colors = np.floor(np.asarray([colors_r,colors_g,colors_b])).T 340 | colors = colors.astype(np.uint8) 341 | return verts, faces, norms, colors 342 | 343 | 344 | def rigid_transform(xyz, transform): 345 | """Applies a rigid transform to an (N, 3) pointcloud. 346 | """ 347 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)]) 348 | xyz_t_h = np.dot(transform, xyz_h.T).T 349 | return xyz_t_h[:, :3] 350 | 351 | 352 | def get_view_frustum(depth_im, cam_intr, cam_pose): 353 | """Get corners of 3D camera view frustum of depth image 354 | """ 355 | im_h = depth_im.shape[0] 356 | im_w = depth_im.shape[1] 357 | max_depth = np.max(depth_im) 358 | view_frust_pts = np.array([ 359 | (np.array([0,0,0,im_w,im_w])-cam_intr[0,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[0,0], 360 | (np.array([0,0,im_h,0,im_h])-cam_intr[1,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[1,1], 361 | np.array([0,max_depth,max_depth,max_depth,max_depth]) 362 | ]) 363 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T 364 | return view_frust_pts 365 | 366 | 367 | def meshwrite(filename, verts, faces, norms, colors): 368 | """Save a 3D mesh to a polygon .ply file. 
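An end-to-end CPU-mode sketch with a synthetic flat-wall frame (illustrative intrinsics, pose, and bounds; CPU mode avoids the PyCUDA requirement):
    H, W = 240, 320
    cam_intr = np.array([[300., 0., 159.5], [0., 300., 119.5], [0., 0., 1.]])
    cam_pose = np.eye(4)                                    # camera-to-world
    depth_im = np.full((H, W), 2.0, dtype=np.float32)       # a flat wall 2 m in front of the camera
    color_im = np.full((H, W, 3), 128, dtype=np.uint8)
    vol_bnds = np.array([[-1.2, 1.2], [-1.0, 1.0], [0.5, 2.2]])  # xyz bounds in meters (could also come from get_view_frustum)
    tsdf_vol = TSDFVolume(vol_bnds, voxel_size=0.05, use_gpu=False)
    tsdf_vol.integrate(color_im, depth_im, cam_intr, cam_pose, obs_weight=1.)
    verts, faces, norms, colors = tsdf_vol.get_mesh()       # marching cubes at the zero crossing
    meshwrite('wall.ply', verts, faces, norms, colors)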
369 | """ 370 | # Write header 371 | ply_file = open(filename,'w') 372 | ply_file.write("ply\n") 373 | ply_file.write("format ascii 1.0\n") 374 | ply_file.write("element vertex %d\n"%(verts.shape[0])) 375 | ply_file.write("property float x\n") 376 | ply_file.write("property float y\n") 377 | ply_file.write("property float z\n") 378 | ply_file.write("property float nx\n") 379 | ply_file.write("property float ny\n") 380 | ply_file.write("property float nz\n") 381 | ply_file.write("property uchar red\n") 382 | ply_file.write("property uchar green\n") 383 | ply_file.write("property uchar blue\n") 384 | ply_file.write("element face %d\n"%(faces.shape[0])) 385 | ply_file.write("property list uchar int vertex_index\n") 386 | ply_file.write("end_header\n") 387 | 388 | # Write vertex list 389 | for i in range(verts.shape[0]): 390 | ply_file.write("%f %f %f %f %f %f %d %d %d\n"%( 391 | verts[i,0], verts[i,1], verts[i,2], 392 | norms[i,0], norms[i,1], norms[i,2], 393 | colors[i,0], colors[i,1], colors[i,2], 394 | )) 395 | 396 | # Write face list 397 | for i in range(faces.shape[0]): 398 | ply_file.write("3 %d %d %d\n"%(faces[i,0], faces[i,1], faces[i,2])) 399 | 400 | ply_file.close() 401 | 402 | 403 | def pcwrite(filename, xyzrgb): 404 | """Save a point cloud to a polygon .ply file. 405 | """ 406 | xyz = xyzrgb[:, :3] 407 | rgb = xyzrgb[:, 3:].astype(np.uint8) 408 | 409 | # Write header 410 | ply_file = open(filename,'w') 411 | ply_file.write("ply\n") 412 | ply_file.write("format ascii 1.0\n") 413 | ply_file.write("element vertex %d\n"%(xyz.shape[0])) 414 | ply_file.write("property float x\n") 415 | ply_file.write("property float y\n") 416 | ply_file.write("property float z\n") 417 | ply_file.write("property uchar red\n") 418 | ply_file.write("property uchar green\n") 419 | ply_file.write("property uchar blue\n") 420 | ply_file.write("end_header\n") 421 | 422 | # Write vertex list 423 | for i in range(xyz.shape[0]): 424 | ply_file.write("%f %f %f %d %d %d\n"%( 425 | xyz[i, 0], xyz[i, 1], xyz[i, 2], 426 | rgb[i, 0], rgb[i, 1], rgb[i, 2], 427 | )) 428 | -------------------------------------------------------------------------------- /src/tools/cull_mesh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | import trimesh 6 | from tqdm import tqdm 7 | 8 | 9 | def load_poses(path): 10 | poses = [] 11 | with open(path, "r") as f: 12 | lines = f.readlines() 13 | for line in lines: 14 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 15 | c2w[:3, 1] *= -1 16 | c2w[:3, 2] *= -1 17 | c2w = torch.from_numpy(c2w).float() 18 | poses.append(c2w) 19 | return poses 20 | 21 | 22 | parser = argparse.ArgumentParser( 23 | description='Arguments to cull the mesh.' 
24 | ) 25 | 26 | parser.add_argument('--input_mesh', type=str, 27 | help='path to the mesh to be culled') 28 | parser.add_argument('--traj', type=str, help='path to the trajectory') 29 | parser.add_argument('--output_mesh', type=str, help='path to the output mesh') 30 | args = parser.parse_args() 31 | 32 | H = 680 33 | W = 1200 34 | fx = 600.0 35 | fy = 600.0 36 | fx = 600.0 37 | cx = 599.5 38 | cy = 339.5 39 | scale = 6553.5 40 | 41 | poses = load_poses(args.traj) 42 | n_imgs = len(poses) 43 | mesh = trimesh.load(args.input_mesh, process=False) 44 | pc = mesh.vertices 45 | faces = mesh.faces 46 | 47 | # delete mesh vertices that are not inside any camera's viewing frustum 48 | whole_mask = np.ones(pc.shape[0]).astype(np.bool) 49 | for i in tqdm(range(0, n_imgs, 1)): 50 | c2w = poses[i] 51 | points = pc.copy() 52 | points = torch.from_numpy(points).cuda() 53 | w2c = np.linalg.inv(c2w) 54 | w2c = torch.from_numpy(w2c).cuda().float() 55 | K = torch.from_numpy( 56 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda() 57 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda() 58 | homo_points = torch.cat( 59 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float() 60 | cam_cord_homo = w2c@homo_points 61 | cam_cord = cam_cord_homo[:, :3] 62 | 63 | cam_cord[:, 0] *= -1 64 | uv = K.float()@cam_cord.float() 65 | z = uv[:, -1:]+1e-5 66 | uv = uv[:, :2]/z 67 | uv = uv.float().squeeze(-1).cpu().numpy() 68 | edge = 0 69 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W - 70 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge) 71 | whole_mask &= ~mask 72 | pc = mesh.vertices 73 | faces = mesh.faces 74 | face_mask = whole_mask[mesh.faces].all(axis=1) 75 | mesh.update_faces(~face_mask) 76 | mesh.export(args.output_mesh) 77 | -------------------------------------------------------------------------------- /src/tools/eval_ate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy 4 | import torch 5 | import sys 6 | import numpy as np 7 | sys.path.append('.') 8 | from src import config 9 | from src.common import get_tensor_from_camera 10 | 11 | def associate(first_list, second_list, offset=0.0, max_difference=0.02): 12 | """ 13 | Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim 14 | to find the closest match for every input tuple. 15 | 16 | Input: 17 | first_list -- first dictionary of (stamp,data) tuples 18 | second_list -- second dictionary of (stamp,data) tuples 19 | offset -- time offset between both dictionaries (e.g., to model the delay between the sensors) 20 | max_difference -- search radius for candidate generation 21 | 22 | Output: 23 | matches -- list of matched tuples ((stamp1,data1),(stamp2,data2)) 24 | 25 | """ 26 | first_keys = list(first_list.keys()) 27 | second_keys = list(second_list.keys()) 28 | potential_matches = [(abs(a - (b + offset)), a, b) 29 | for a in first_keys 30 | for b in second_keys 31 | if abs(a - (b + offset)) < max_difference] 32 | potential_matches.sort() 33 | matches = [] 34 | for diff, a, b in potential_matches: 35 | if a in first_keys and b in second_keys: 36 | first_keys.remove(a) 37 | second_keys.remove(b) 38 | matches.append((a, b)) 39 | 40 | matches.sort() 41 | return matches 42 | 43 | 44 | def align(model, data): 45 | """Align two trajectories using the method of Horn (closed-form). 
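The rotation is recovered in closed form from an SVD of the correlation matrix of the zero-centered trajectories (with a determinant check to avoid reflections), and the translation then aligns the two centroids.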
46 | 47 | Input: 48 | model -- first trajectory (3xn) 49 | data -- second trajectory (3xn) 50 | 51 | Output: 52 | rot -- rotation matrix (3x3) 53 | trans -- translation vector (3x1) 54 | trans_error -- translational error per point (1xn) 55 | 56 | """ 57 | numpy.set_printoptions(precision=3, suppress=True) 58 | model_zerocentered = model - model.mean(1) 59 | data_zerocentered = data - data.mean(1) 60 | 61 | W = numpy.zeros((3, 3)) 62 | for column in range(model.shape[1]): 63 | W += numpy.outer(model_zerocentered[:, 64 | column], data_zerocentered[:, column]) 65 | U, d, Vh = numpy.linalg.linalg.svd(W.transpose()) 66 | S = numpy.matrix(numpy.identity(3)) 67 | if(numpy.linalg.det(U) * numpy.linalg.det(Vh) < 0): 68 | S[2, 2] = -1 69 | rot = U*S*Vh 70 | trans = data.mean(1) - rot * model.mean(1) 71 | 72 | model_aligned = rot * model + trans 73 | alignment_error = model_aligned - data 74 | 75 | trans_error = numpy.sqrt(numpy.sum(numpy.multiply( 76 | alignment_error, alignment_error), 0)).A[0] 77 | 78 | return rot, trans, trans_error 79 | 80 | 81 | def plot_traj(ax, stamps, traj, style, color, label): 82 | """ 83 | Plot a trajectory using matplotlib. 84 | 85 | Input: 86 | ax -- the plot 87 | stamps -- time stamps (1xn) 88 | traj -- trajectory (3xn) 89 | style -- line style 90 | color -- line color 91 | label -- plot legend 92 | 93 | """ 94 | stamps.sort() 95 | interval = numpy.median([s-t for s, t in zip(stamps[1:], stamps[:-1])]) 96 | x = [] 97 | y = [] 98 | last = stamps[0] 99 | for i in range(len(stamps)): 100 | if stamps[i]-last < 2*interval: 101 | x.append(traj[i][0]) 102 | y.append(traj[i][1]) 103 | elif len(x) > 0: 104 | ax.plot(x, y, style, color=color, label=label) 105 | label = "" 106 | x = [] 107 | y = [] 108 | last = stamps[i] 109 | if len(x) > 0: 110 | ax.plot(x, y, style, color=color, label=label) 111 | 112 | 113 | def evaluate_ate(first_list, second_list, plot="", _args=""): 114 | # parse command line 115 | parser = argparse.ArgumentParser( 116 | description='This script computes the absolute trajectory error from the ground truth trajectory and the estimated trajectory.') 117 | # parser.add_argument('first_file', help='ground truth trajectory (format: timestamp tx ty tz qx qy qz qw)') 118 | # parser.add_argument('second_file', help='estimated trajectory (format: timestamp tx ty tz qx qy qz qw)') 119 | parser.add_argument( 120 | '--offset', help='time offset added to the timestamps of the second file (default: 0.0)', default=0.0) 121 | parser.add_argument( 122 | '--scale', help='scaling factor for the second trajectory (default: 1.0)', default=1.0) 123 | parser.add_argument( 124 | '--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)', default=0.02) 125 | parser.add_argument( 126 | '--save', help='save aligned second trajectory to disk (format: stamp2 x2 y2 z2)') 127 | parser.add_argument('--save_associations', 128 | help='save associated first and aligned second trajectory to disk (format: stamp1 x1 y1 z1 stamp2 x2 y2 z2)') 129 | parser.add_argument( 130 | '--plot', help='plot the first and the aligned second trajectory to an image (format: png)') 131 | parser.add_argument( 132 | '--verbose', help='print all evaluation data (otherwise, only the RMSE absolute translational error in meters after alignment will be printed)', action='store_true') 133 | args = parser.parse_args(_args) 134 | args.plot = plot 135 | # first_list = associate.read_file_list(args.first_file) 136 | # second_list = associate.read_file_list(args.second_file) 137 | 138 
| matches = associate(first_list, second_list, float( 139 | args.offset), float(args.max_difference)) 140 | if len(matches) < 2: 141 | raise ValueError( 142 | "Couldn't find matching timestamp pairs between groundtruth and estimated trajectory! \ 143 | Did you choose the correct sequence?") 144 | 145 | first_xyz = numpy.matrix( 146 | [[float(value) for value in first_list[a][0:3]] for a, b in matches]).transpose() 147 | second_xyz = numpy.matrix([[float(value)*float(args.scale) 148 | for value in second_list[b][0:3]] for a, b in matches]).transpose() 149 | 150 | rot, trans, trans_error = align(second_xyz, first_xyz) 151 | 152 | second_xyz_aligned = rot * second_xyz + trans 153 | 154 | first_stamps = list(first_list.keys()) 155 | first_stamps.sort() 156 | first_xyz_full = numpy.matrix( 157 | [[float(value) for value in first_list[b][0:3]] for b in first_stamps]).transpose() 158 | 159 | second_stamps = list(second_list.keys()) 160 | second_stamps.sort() 161 | second_xyz_full = numpy.matrix([[float(value)*float(args.scale) 162 | for value in second_list[b][0:3]] for b in second_stamps]).transpose() 163 | second_xyz_full_aligned = rot * second_xyz_full + trans 164 | 165 | if args.verbose: 166 | print("compared_pose_pairs %d pairs" % (len(trans_error))) 167 | 168 | print("absolute_translational_error.rmse %f m" % numpy.sqrt( 169 | numpy.dot(trans_error, trans_error) / len(trans_error))) 170 | print("absolute_translational_error.mean %f m" % 171 | numpy.mean(trans_error)) 172 | print("absolute_translational_error.median %f m" % 173 | numpy.median(trans_error)) 174 | print("absolute_translational_error.std %f m" % numpy.std(trans_error)) 175 | print("absolute_translational_error.min %f m" % numpy.min(trans_error)) 176 | print("absolute_translational_error.max %f m" % numpy.max(trans_error)) 177 | 178 | if args.save_associations: 179 | file = open(args.save_associations, "w") 180 | file.write("\n".join(["%f %f %f %f %f %f %f %f" % (a, x1, y1, z1, b, x2, y2, z2) for ( 181 | a, b), (x1, y1, z1), (x2, y2, z2) in zip(matches, first_xyz.transpose().A, second_xyz_aligned.transpose().A)])) 182 | file.close() 183 | 184 | if args.save: 185 | file = open(args.save, "w") 186 | file.write("\n".join(["%f " % stamp+" ".join(["%f" % d for d in line]) 187 | for stamp, line in zip(second_stamps, second_xyz_full_aligned.transpose().A)])) 188 | file.close() 189 | 190 | if args.plot: 191 | import matplotlib 192 | matplotlib.use('Agg') 193 | import matplotlib.pylab as pylab 194 | import matplotlib.pyplot as plt 195 | from matplotlib.patches import Ellipse 196 | fig = plt.figure() 197 | ax = fig.add_subplot(111) 198 | ATE = numpy.sqrt( 199 | numpy.dot(trans_error, trans_error) / len(trans_error)) 200 | ax.set_title(f'len:{len(trans_error)} ATE RMSE:{ATE} {args.plot[:-3]}') 201 | plot_traj(ax, first_stamps, first_xyz_full.transpose().A, 202 | '-', "black", "ground truth") 203 | plot_traj(ax, second_stamps, second_xyz_full_aligned.transpose( 204 | ).A, '-', "blue", "estimated") 205 | 206 | label = "difference" 207 | for (a, b), (x1, y1, z1), (x2, y2, z2) in zip(matches, first_xyz.transpose().A, second_xyz_aligned.transpose().A): 208 | # ax.plot([x1,x2],[y1,y2],'-',color="red",label=label) 209 | label = "" 210 | ax.legend() 211 | ax.set_xlabel('x [m]') 212 | ax.set_ylabel('y [m]') 213 | plt.savefig(args.plot, dpi=90) 214 | 215 | return { 216 | "compared_pose_pairs": (len(trans_error)), 217 | "absolute_translational_error.rmse": numpy.sqrt(numpy.dot(trans_error, trans_error) / len(trans_error)), 218 | 
"absolute_translational_error.mean": numpy.mean(trans_error), 219 | "absolute_translational_error.median": numpy.median(trans_error), 220 | "absolute_translational_error.std": numpy.std(trans_error), 221 | "absolute_translational_error.min": numpy.min(trans_error), 222 | "absolute_translational_error.max": numpy.max(trans_error), 223 | } 224 | 225 | 226 | def evaluate(poses_gt, poses_est, plot): 227 | 228 | poses_gt = poses_gt.cpu().numpy() 229 | poses_est = poses_est.cpu().numpy() 230 | 231 | N = poses_gt.shape[0] 232 | poses_gt = dict([(i, poses_gt[i]) for i in range(N)]) 233 | poses_est = dict([(i, poses_est[i]) for i in range(N)]) 234 | 235 | results = evaluate_ate(poses_gt, poses_est, plot) 236 | print(results) 237 | 238 | 239 | def convert_poses(c2w_list, N, scale, gt=True): 240 | poses = [] 241 | mask = torch.ones(N+1).bool() 242 | for idx in range(0, N+1): 243 | if gt: 244 | # some frame have `nan` or `inf` in gt pose of ScanNet, 245 | # but our system have estimated camera pose for all frames 246 | # therefore, when calculating the pose error, we need to mask out invalid pose 247 | if torch.isinf(c2w_list[idx]).any(): 248 | mask[idx] = 0 249 | continue 250 | if torch.isnan(c2w_list[idx]).any(): 251 | mask[idx] = 0 252 | continue 253 | c2w_list[idx][:3, 3] /= scale 254 | poses.append(get_tensor_from_camera(c2w_list[idx], Tquad=True)) 255 | poses = torch.stack(poses) 256 | return poses, mask 257 | 258 | 259 | if __name__ == '__main__': 260 | """ 261 | This ATE evaluation code is modified upon the evaluation code in lie-torch. 262 | """ 263 | 264 | parser = argparse.ArgumentParser( 265 | description='Arguments to eval the tracking ATE.' 266 | ) 267 | parser.add_argument('config', type=str, help='Path to config file.') 268 | parser.add_argument('--output', type=str, 269 | help='output folder, this have higher priority, can overwrite the one inconfig file') 270 | nice_parser = parser.add_mutually_exclusive_group(required=False) 271 | nice_parser.add_argument('--nice', dest='nice', action='store_true') 272 | nice_parser.add_argument('--imap', dest='nice', action='store_false') 273 | parser.set_defaults(nice=True) 274 | 275 | args = parser.parse_args() 276 | cfg = config.load_config( 277 | args.config, 'configs/nice_slam.yaml' if args.nice else 'configs/imap.yaml') 278 | scale = cfg['scale'] 279 | output = cfg['data']['output'] if args.output is None else args.output 280 | cofusion = ('cofusion' in args.config) or ('CoFusion' in args.config) 281 | ckptsdir = f'{output}/ckpts' 282 | if os.path.exists(ckptsdir): 283 | ckpts = [os.path.join(ckptsdir, f) 284 | for f in sorted(os.listdir(ckptsdir)) if 'tar' in f] 285 | if len(ckpts) > 0: 286 | ckpt_path = ckpts[-1] 287 | print('Get ckpt :', ckpt_path) 288 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) 289 | estimate_c2w_list = ckpt['estimate_c2w_list'] 290 | gt_c2w_list = ckpt['gt_c2w_list'] 291 | N = ckpt['idx'] 292 | if cofusion: 293 | poses_gt = np.loadtxt( 294 | 'Datasets/CoFusion/room4/trajectories/gt-cam-0.txt') 295 | poses_gt = torch.from_numpy(poses_gt[:, 1:]) 296 | else: 297 | poses_gt, mask = convert_poses(gt_c2w_list, N, scale) 298 | poses_est, _ = convert_poses(estimate_c2w_list, N, scale) 299 | poses_est = poses_est[mask] 300 | evaluate(poses_gt, poses_est, 301 | plot=f'{output}/eval_ate_plot.png') 302 | -------------------------------------------------------------------------------- /src/tools/eval_recon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
import random 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | import torch 7 | import trimesh 8 | from scipy.spatial import cKDTree as KDTree 9 | 10 | def setup_seed(seed): 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed_all(seed) 13 | np.random.seed(seed) 14 | random.seed(seed) 15 | torch.backends.cudnn.deterministic = True 16 | 17 | 18 | 19 | def normalize(x): 20 | return x / np.linalg.norm(x) 21 | 22 | 23 | def viewmatrix(z, up, pos): 24 | vec2 = normalize(z) 25 | vec1_avg = up 26 | vec0 = normalize(np.cross(vec1_avg, vec2)) 27 | vec1 = normalize(np.cross(vec2, vec0)) 28 | m = np.stack([vec0, vec1, vec2, pos], 1) 29 | return m 30 | 31 | 32 | def completion_ratio(gt_points, rec_points, dist_th=0.05): 33 | gen_points_kd_tree = KDTree(rec_points) 34 | distances, _ = gen_points_kd_tree.query(gt_points) 35 | comp_ratio = np.mean((distances < dist_th).astype(np.float)) 36 | return comp_ratio 37 | 38 | 39 | def accuracy(gt_points, rec_points): 40 | gt_points_kd_tree = KDTree(gt_points) 41 | distances, _ = gt_points_kd_tree.query(rec_points) 42 | acc = np.mean(distances) 43 | return acc 44 | 45 | 46 | def completion(gt_points, rec_points): 47 | gt_points_kd_tree = KDTree(rec_points) 48 | distances, _ = gt_points_kd_tree.query(gt_points) 49 | comp = np.mean(distances) 50 | return comp 51 | 52 | 53 | def get_align_transformation(rec_meshfile, gt_meshfile): 54 | """ 55 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh. 56 | """ 57 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 58 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 59 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices) 60 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices) 61 | trans_init = np.eye(4) 62 | threshold = 0.1 63 | reg_p2p = o3d.pipelines.registration.registration_icp( 64 | o3d_rec_pc, o3d_gt_pc, threshold, trans_init, 65 | o3d.pipelines.registration.TransformationEstimationPointToPoint()) 66 | transformation = reg_p2p.transformation 67 | return transformation 68 | 69 | 70 | def check_proj(points, W, H, fx, fy, cx, cy, c2w): 71 | """ 72 | Check if points can be projected into the camera view. 73 | 74 | """ 75 | c2w = c2w.copy() 76 | c2w[:3, 1] *= -1.0 77 | c2w[:3, 2] *= -1.0 78 | points = torch.from_numpy(points).cuda().clone() 79 | w2c = np.linalg.inv(c2w) 80 | w2c = torch.from_numpy(w2c).cuda().float() 81 | K = torch.from_numpy( 82 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda() 83 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda() 84 | homo_points = torch.cat( 85 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float() # (N, 4) 86 | cam_cord_homo = w2c@homo_points # (N, 4, 1)=(4,4)*(N, 4, 1) 87 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1) 88 | cam_cord[:, 0] *= -1 89 | uv = K.float()@cam_cord.float() 90 | z = uv[:, -1:]+1e-5 91 | uv = uv[:, :2]/z 92 | uv = uv.float().squeeze(-1).cpu().numpy() 93 | edge = 0 94 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W - 95 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge) 96 | return mask.sum() > 0 97 | 98 | 99 | def calc_3d_metric(rec_meshfile, gt_meshfile, align=True): 100 | """ 101 | 3D reconstruction metric. 
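Samples 200k surface points on each mesh (after an optional ICP alignment) and reports accuracy (cm), completion (cm) and completion ratio (the fraction of ground-truth points within 5 cm, in %).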
102 | 103 | """ 104 | mesh_rec = trimesh.load(rec_meshfile, process=False) 105 | mesh_gt = trimesh.load(gt_meshfile, process=False) 106 | 107 | if align: 108 | transformation = get_align_transformation(rec_meshfile, gt_meshfile) 109 | mesh_rec = mesh_rec.apply_transform(transformation) 110 | 111 | rec_pc = trimesh.sample.sample_surface(mesh_rec, 200000) 112 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0]) 113 | 114 | gt_pc = trimesh.sample.sample_surface(mesh_gt, 200000) 115 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) 116 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices) 117 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices) 118 | completion_ratio_rec = completion_ratio( 119 | gt_pc_tri.vertices, rec_pc_tri.vertices) 120 | accuracy_rec *= 100 # convert to cm 121 | completion_rec *= 100 # convert to cm 122 | completion_ratio_rec *= 100 # convert to % 123 | print('accuracy: ', accuracy_rec) 124 | print('completion: ', completion_rec) 125 | print('completion ratio: ', completion_ratio_rec) 126 | 127 | 128 | def get_cam_position(gt_meshfile): 129 | mesh_gt = trimesh.load(gt_meshfile) 130 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt) 131 | extents[2] *= 0.7 132 | extents[1] *= 0.7 133 | extents[0] *= 0.3 134 | transform = np.linalg.inv(to_origin) 135 | transform[2, 3] += 0.4 136 | return extents, transform 137 | 138 | 139 | def calc_2d_metric(rec_meshfile, gt_meshfile, align=True, n_imgs=1000): 140 | """ 141 | 2D reconstruction metric, depth L1 loss. 142 | 143 | """ 144 | H = 500 145 | W = 500 146 | focal = 300 147 | fx = focal 148 | fy = focal 149 | cx = H/2.0-0.5 150 | cy = W/2.0-0.5 151 | 152 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 153 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 154 | unseen_gt_pointcloud_file = gt_meshfile.replace('.ply', '_pc_unseen.npy') 155 | pc_unseen = np.load(unseen_gt_pointcloud_file) 156 | if align: 157 | transformation = get_align_transformation(rec_meshfile, gt_meshfile) 158 | rec_mesh = rec_mesh.transform(transformation) 159 | 160 | # get vacant area inside the room 161 | extents, transform = get_cam_position(gt_meshfile) 162 | 163 | vis = o3d.visualization.Visualizer() 164 | vis.create_window(width=W, height=H) 165 | vis.get_render_option().mesh_show_back_face = True 166 | errors = [] 167 | for i in range(n_imgs): 168 | while True: 169 | # sample view, and check if unseen region is not inside the camera view 170 | # if inside, then needs to resample 171 | up = [0, 0, -1] 172 | origin = trimesh.sample.volume_rectangular( 173 | extents, 1, transform=transform) 174 | origin = origin.reshape(-1) 175 | tx = round(random.uniform(-10000, +10000), 2) 176 | ty = round(random.uniform(-10000, +10000), 2) 177 | tz = round(random.uniform(-10000, +10000), 2) 178 | target = [tx, ty, tz] 179 | target = np.array(target)-np.array(origin) 180 | c2w = viewmatrix(target, up, origin) 181 | tmp = np.eye(4) 182 | tmp[:3, :] = c2w 183 | c2w = tmp 184 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w) 185 | if (~seen): 186 | break 187 | 188 | param = o3d.camera.PinholeCameraParameters() 189 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array 190 | 191 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic( 192 | W, H, fx, fy, cx, cy) 193 | 194 | ctr = vis.get_view_control() 195 | ctr.set_constant_z_far(20) 196 | ctr.convert_from_pinhole_camera_parameters(param) 197 | 198 | vis.add_geometry(gt_mesh, reset_bounding_box=True,) 199 | ctr.convert_from_pinhole_camera_parameters(param) 200 | 
vis.poll_events() 201 | vis.update_renderer() 202 | gt_depth = vis.capture_depth_float_buffer(True) 203 | gt_depth = np.asarray(gt_depth) 204 | vis.remove_geometry(gt_mesh, reset_bounding_box=True,) 205 | 206 | vis.add_geometry(rec_mesh, reset_bounding_box=True,) 207 | ctr.convert_from_pinhole_camera_parameters(param) 208 | vis.poll_events() 209 | vis.update_renderer() 210 | ours_depth = vis.capture_depth_float_buffer(True) 211 | ours_depth = np.asarray(ours_depth) 212 | vis.remove_geometry(rec_mesh, reset_bounding_box=True,) 213 | 214 | errors += [np.abs(gt_depth-ours_depth).mean()] 215 | 216 | errors = np.array(errors) 217 | # from m to cm 218 | print('Depth L1: ', errors.mean()*100) 219 | 220 | 221 | if __name__ == '__main__': 222 | #setup_seed(20) 223 | 224 | parser = argparse.ArgumentParser( 225 | description='Arguments to evaluate the reconstruction.' 226 | ) 227 | parser.add_argument('--rec_mesh', type=str, 228 | help='reconstructed mesh file path') 229 | parser.add_argument('--gt_mesh', type=str, 230 | help='ground truth mesh file path') 231 | parser.add_argument('-2d', '--metric_2d', 232 | action='store_true', help='enable 2D metric') 233 | parser.add_argument('-3d', '--metric_3d', 234 | action='store_true', help='enable 3D metric') 235 | args = parser.parse_args() 236 | if args.metric_3d: 237 | calc_3d_metric(args.rec_mesh, args.gt_mesh) 238 | 239 | if args.metric_2d: 240 | calc_2d_metric(args.rec_mesh, args.gt_mesh, n_imgs=1000) 241 | -------------------------------------------------------------------------------- /src/tools/evaluate_scannet.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/zju3dv/manhattan_sdf 2 | import numpy as np 3 | import open3d as o3d 4 | from sklearn.neighbors import KDTree 5 | import trimesh 6 | import torch 7 | import glob 8 | import os 9 | import pyrender 10 | import os 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | import sys 14 | sys.path.append(".") 15 | from src import config 16 | from src.utils.datasets import get_dataset 17 | import argparse 18 | 19 | # os.environ['PYOPENGL_PLATFORM'] = 'egl' 20 | 21 | def nn_correspondance(verts1, verts2): 22 | indices = [] 23 | distances = [] 24 | if len(verts1) == 0 or len(verts2) == 0: 25 | return indices, distances 26 | 27 | kdtree = KDTree(verts1) 28 | distances, indices = kdtree.query(verts2) 29 | distances = distances.reshape(-1) 30 | 31 | return distances 32 | 33 | 34 | def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02): 35 | pcd_trgt = o3d.geometry.PointCloud() 36 | pcd_pred = o3d.geometry.PointCloud() 37 | 38 | pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3]) 39 | pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3]) 40 | 41 | if down_sample: 42 | pcd_pred = pcd_pred.voxel_down_sample(down_sample) 43 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample) 44 | 45 | verts_pred = np.asarray(pcd_pred.points) 46 | verts_trgt = np.asarray(pcd_trgt.points) 47 | 48 | dist1 = nn_correspondance(verts_pred, verts_trgt) 49 | dist2 = nn_correspondance(verts_trgt, verts_pred) 50 | 51 | precision = np.mean((dist2 < threshold).astype('float')) 52 | recal = np.mean((dist1 < threshold).astype('float')) 53 | fscore = 2 * precision * recal / (precision + recal) 54 | metrics = { 55 | 'Acc': np.mean(dist2), 56 | 'Comp': np.mean(dist1), 57 | 'Chamfer': (np.mean(dist1) + np.mean(dist2))/2, 58 | 'Prec': precision, 59 | 'Recal': recal, 60 | 'F-score': fscore, 61 | } 62 | return metrics 
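# Illustrative usage of `evaluate` (a minimal sketch; the file paths below are only examples):
# both inputs are trimesh meshes whose vertices are converted to point clouds, voxel
# down-sampled to 2 cm, and compared with nearest-neighbour distances in both directions.
#
#   mesh_pred = trimesh.load('output/scannet/scans/scene0000_00/mesh/final_mesh_refused.ply')
#   mesh_gt = trimesh.load('Datasets/scannet/GTmesh_lowres/0000_00.obj')
#   print(evaluate(mesh_pred, mesh_gt, threshold=0.05, down_sample=0.02))
#
# Acc, Comp and Chamfer are mean distances in mesh units; Prec, Recal and F-score use the 5 cm threshold.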
63 | 64 | 65 | 66 | def update_cam(cfg): 67 | """ 68 | Update the camera intrinsics according to pre-processing config, 69 | such as resize or edge crop. 70 | """ 71 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][ 72 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy'] 73 | # resize the input images to crop_size (variable name used in lietorch) 74 | if 'crop_size' in cfg['cam']: 75 | crop_size = cfg['cam']['crop_size'] 76 | H, W, fx, fy, cx, cy = cfg['cam']['H'], cfg['cam'][ 77 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy'] 78 | sx = crop_size[1] / W 79 | sy = crop_size[0] / H 80 | fx = sx*fx 81 | fy = sy*fy 82 | cx = sx*cx 83 | cy = sy*cy 84 | W = crop_size[1] 85 | H = crop_size[0] 86 | 87 | # croping will change H, W, cx, cy, so need to change here 88 | if cfg['cam']['crop_edge'] > 0: 89 | H -= cfg['cam']['crop_edge']*2 90 | W -= cfg['cam']['crop_edge']*2 91 | cx -= cfg['cam']['crop_edge'] 92 | cy -= cfg['cam']['crop_edge'] 93 | 94 | return H, W, fx, fy, cx, cy 95 | 96 | # load pose 97 | def get_pose(cfg, args): 98 | scale = cfg['scale'] 99 | H, W, fx, fy, cx, cy = update_cam(cfg) 100 | K = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy).intrinsic_matrix # (3, 3) 101 | 102 | frame_reader = get_dataset(cfg, args, scale) 103 | 104 | pose_ls = [] 105 | for idx in range(len(frame_reader)): 106 | if idx % 10 != 0: continue 107 | _, gt_color, gt_depth, gt_c2w = frame_reader[idx] 108 | 109 | c2w = gt_c2w.cpu().numpy() 110 | 111 | if np.isfinite(c2w).any(): 112 | 113 | c2w[:3, 1] *= -1.0 114 | c2w[:3, 2] *= -1.0 115 | pose_ls.append(c2w) 116 | 117 | return pose_ls, K, H, W 118 | 119 | 120 | class Renderer(): 121 | def __init__(self, height=480, width=640): 122 | self.renderer = pyrender.OffscreenRenderer(width, height) 123 | self.scene = pyrender.Scene() 124 | 125 | def __call__(self, height, width, intrinsics, pose, mesh): 126 | self.renderer.viewport_height = height 127 | self.renderer.viewport_width = width 128 | self.scene.clear() 129 | self.scene.add(mesh) 130 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2], 131 | fx=intrinsics[0, 0], fy=intrinsics[1, 1]) 132 | 133 | self.scene.add(cam, pose=self.fix_pose(pose)) 134 | return self.renderer.render(self.scene) 135 | 136 | def fix_pose(self, pose): 137 | # 3D Rotation about the x-axis. 
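# pyrender expects camera poses in the OpenGL convention (camera looking along -Z with +Y up);
# right-multiplying the pose by this 180 degree rotation about the x-axis, i.e. the
# [[1, 0, 0], [0, -1, 0], [0, 0, -1]] flip noted below, negates the y and z camera axes to account for that.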
138 | t = np.pi 139 | c = np.cos(t) 140 | s = np.sin(t) 141 | R = np.array([[1, 0, 0], 142 | [0, c, -s], 143 | [0, s, c]]) # [[1, 0, 0], [0, -1, 0], [0, 0, -1]] 144 | axis_transform = np.eye(4) 145 | axis_transform[:3, :3] = R 146 | return pose @ axis_transform 147 | 148 | def mesh_opengl(self, mesh): 149 | return pyrender.Mesh.from_trimesh(mesh) 150 | 151 | def delete(self): 152 | self.renderer.delete() 153 | 154 | 155 | def refuse(mesh, poses, K, H, W, cfg): 156 | renderer = Renderer() 157 | mesh_opengl = renderer.mesh_opengl(mesh) 158 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 159 | voxel_length=0.01, 160 | sdf_trunc=3*0.01, 161 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 162 | ) 163 | 164 | idx = 0 165 | H, W, fx, fy, cx, cy = update_cam(cfg) 166 | 167 | for pose in tqdm(poses): 168 | 169 | intrinsic = np.eye(4) 170 | intrinsic[:3, :3] = K 171 | 172 | rgb = np.ones((H, W, 3)) 173 | rgb = (rgb * 255).astype(np.uint8) 174 | rgb = o3d.geometry.Image(rgb) 175 | _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl) 176 | 177 | depth_pred = o3d.geometry.Image(depth_pred) 178 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 179 | rgb, depth_pred, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False 180 | ) 181 | 182 | intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy) 183 | extrinsic = np.linalg.inv(pose) 184 | volume.integrate(rgbd, intrinsic, extrinsic) 185 | 186 | return volume.extract_triangle_mesh() 187 | 188 | def evaluate_mesh(): 189 | """ 190 | Evaluate the scannet mesh. 191 | 192 | """ 193 | parser = argparse.ArgumentParser( 194 | description='Arguments for running the code.' 195 | ) 196 | parser.add_argument('config', type=str, help='Path to config file.') 197 | parser.add_argument('--input_folder', type=str, 198 | help='input folder, this have higher priority, can overwrite the one in config file') 199 | parser.add_argument('--output', type=str, 200 | help='output folder, this have higher priority, can overwrite the one in config file') 201 | parser.add_argument('--space', type=int, default=10, help='the space between frames to integrate into the TSDF volume.') 202 | 203 | args = parser.parse_args() 204 | cfg = config.load_config(args.config, 'configs/df_prior.yaml') 205 | 206 | scene_id = cfg['data']['id'] 207 | 208 | input_file = f"output/scannet/scans/scene{scene_id:04d}_00/mesh/final_mesh.ply" 209 | mesh = trimesh.load_mesh(input_file) 210 | mesh.invert() # change noraml of mesh 211 | 212 | poses, K, H, W = get_pose(cfg, args) 213 | 214 | # refuse mesh 215 | mesh = refuse(mesh, poses, K, H, W, cfg) 216 | 217 | # save mesh 218 | out_mesh_path = f"output/scannet/scans/scene{scene_id:04d}_00/mesh/final_mesh_refused.ply" 219 | o3d.io.write_triangle_mesh(out_mesh_path, mesh) 220 | 221 | mesh = trimesh.load(out_mesh_path) 222 | gt_mesh = os.path.join("./Datasets/scannet/GTmesh_lowres", f"{scene_id:04d}_00.obj") 223 | gt_mesh = trimesh.load(gt_mesh) 224 | 225 | metrics = evaluate(mesh, gt_mesh) 226 | print(metrics) 227 | 228 | 229 | if __name__ == "__main__": 230 | evaluate_mesh() -------------------------------------------------------------------------------- /src/utils/Logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | class Logger(object): 7 | """ 8 | Save checkpoints to file. 
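Each checkpoint stores the shared feature grids, the decoder weights, the ground-truth and estimated camera poses, the keyframe information, the current frame index and the TSDF volume.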
9 | 10 | """ 11 | 12 | def __init__(self, cfg, args, slam 13 | ): 14 | self.verbose = slam.verbose 15 | self.ckptsdir = slam.ckptsdir 16 | self.shared_c = slam.shared_c 17 | self.gt_c2w_list = slam.gt_c2w_list 18 | self.shared_decoders = slam.shared_decoders 19 | self.estimate_c2w_list = slam.estimate_c2w_list 20 | self.tsdf_volume = slam.tsdf_volume_shared 21 | 22 | def log(self, idx, keyframe_dict, keyframe_list, selected_keyframes=None): 23 | path = os.path.join(self.ckptsdir, '{:05d}.tar'.format(idx)) 24 | torch.save({ 25 | 'c': self.shared_c, 26 | 'decoder_state_dict': self.shared_decoders.state_dict(), 27 | 'gt_c2w_list': self.gt_c2w_list, 28 | 'estimate_c2w_list': self.estimate_c2w_list, 29 | 'keyframe_list': keyframe_list, 30 | 'keyframe_dict': keyframe_dict, # to save keyframe_dict into ckpt, uncomment this line 31 | 'selected_keyframes': selected_keyframes, 32 | 'idx': idx, 33 | 'tsdf_volume': self.tsdf_volume, 34 | }, path, _use_new_zipfile_serialization=False) 35 | 36 | if self.verbose: 37 | print('Saved checkpoints at', path) 38 | -------------------------------------------------------------------------------- /src/utils/Mesher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import skimage 4 | import torch 5 | import torch.nn.functional as F 6 | import trimesh 7 | from packaging import version 8 | from src.utils.datasets import get_dataset 9 | from src.common import normalize_3d_coordinate 10 | import matplotlib.pyplot as plt 11 | 12 | class Mesher(object): 13 | 14 | def __init__(self, cfg, args, slam, points_batch_size=500000, ray_batch_size=100000): 15 | """ 16 | Mesher class, given a scene representation, the mesher extracts the mesh from it. 17 | 18 | Args: 19 | cfg (dict): parsed config dict. 20 | args (class 'argparse.Namespace'): argparse arguments. 21 | slam (class DF_Prior): DF_Prior main class. 22 | points_batch_size (int): maximum points size for query in one batch. 23 | Used to alleviate GPU memeory usage. Defaults to 500000. 24 | ray_batch_size (int): maximum ray size for query in one batch. 25 | Used to alleviate GPU memeory usage. Defaults to 100000. 
26 | """ 27 | self.points_batch_size = points_batch_size 28 | self.ray_batch_size = ray_batch_size 29 | self.renderer = slam.renderer 30 | self.scale = cfg['scale'] 31 | self.occupancy = cfg['occupancy'] 32 | 33 | self.resolution = cfg['meshing']['resolution'] 34 | self.level_set = cfg['meshing']['level_set'] 35 | self.clean_mesh_bound_scale = cfg['meshing']['clean_mesh_bound_scale'] 36 | self.remove_small_geometry_threshold = cfg['meshing']['remove_small_geometry_threshold'] 37 | self.color_mesh_extraction_method = cfg['meshing']['color_mesh_extraction_method'] 38 | self.get_largest_components = cfg['meshing']['get_largest_components'] 39 | self.depth_test = cfg['meshing']['depth_test'] 40 | 41 | self.bound = slam.bound 42 | self.verbose = slam.verbose 43 | 44 | 45 | self.marching_cubes_bound = torch.from_numpy( 46 | np.array(cfg['mapping']['marching_cubes_bound']) * self.scale) 47 | 48 | self.frame_reader = get_dataset(cfg, args, self.scale, device='cpu') 49 | self.n_img = len(self.frame_reader) 50 | 51 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = slam.H, slam.W, slam.fx, slam.fy, slam.cx, slam.cy 52 | 53 | self.sample_mode = 'bilinear' 54 | self.tsdf_bnds = slam.tsdf_bnds 55 | 56 | 57 | 58 | def point_masks(self, input_points, keyframe_dict, estimate_c2w_list, 59 | idx, device, get_mask_use_all_frames=False): 60 | """ 61 | Split the input points into seen, unseen, and forcast, 62 | according to the estimated camera pose and depth image. 63 | 64 | Args: 65 | input_points (tensor): input points. 66 | keyframe_dict (list): list of keyframe info dictionary. 67 | estimate_c2w_list (tensor): estimated camera pose. 68 | idx (int): current frame index. 69 | device (str): device name to compute on. 70 | 71 | Returns: 72 | seen_mask (tensor): the mask for seen area. 73 | forecast_mask (tensor): the mask for forecast area. 74 | unseen_mask (tensor): the mask for unseen area. 
75 | """ 76 | H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy 77 | if not isinstance(input_points, torch.Tensor): 78 | input_points = torch.from_numpy(input_points) 79 | input_points = input_points.clone().detach() 80 | seen_mask_list = [] 81 | forecast_mask_list = [] 82 | unseen_mask_list = [] 83 | for i, pnts in enumerate( 84 | torch.split(input_points, self.points_batch_size, dim=0)): 85 | points = pnts.to(device).float() 86 | # should divide the points into three parts, seen and forecast and unseen 87 | # seen: union of all the points in the viewing frustum of keyframes 88 | # forecast: union of all the points in the extended edge of the viewing frustum of keyframes 89 | # unseen: all the other points 90 | 91 | seen_mask = torch.zeros((points.shape[0])).bool().to(device) 92 | forecast_mask = torch.zeros((points.shape[0])).bool().to(device) 93 | if get_mask_use_all_frames: 94 | for i in range(0, idx + 1, 1): 95 | c2w = estimate_c2w_list[i].cpu().numpy() 96 | w2c = np.linalg.inv(c2w) 97 | w2c = torch.from_numpy(w2c).to(device).float() 98 | ones = torch.ones_like( 99 | points[:, 0]).reshape(-1, 1).to(device) 100 | homo_points = torch.cat([points, ones], dim=1).reshape( 101 | -1, 4, 1).to(device).float() # (N, 4) 102 | # (N, 4, 1)=(4,4)*(N, 4, 1) 103 | cam_cord_homo = w2c @ homo_points 104 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1) 105 | 106 | K = torch.from_numpy( 107 | np.array([[fx, .0, cx], [.0, fy, cy], 108 | [.0, .0, 1.0]]).reshape(3, 3)).to(device) 109 | cam_cord[:, 0] *= -1 110 | uv = K.float() @ cam_cord.float() 111 | z = uv[:, -1:] + 1e-8 112 | uv = uv[:, :2] / z 113 | uv = uv.float() 114 | edge = 0 115 | cur_mask_seen = (uv[:, 0] < W - edge) & ( 116 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge) 117 | cur_mask_seen = cur_mask_seen & (z[:, :, 0] < 0) 118 | 119 | edge = -1000 120 | cur_mask_forecast = (uv[:, 0] < W - edge) & ( 121 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge) 122 | cur_mask_forecast = cur_mask_forecast & (z[:, :, 0] < 0) 123 | 124 | # forecast 125 | cur_mask_forecast = cur_mask_forecast.reshape(-1) 126 | # seen 127 | cur_mask_seen = cur_mask_seen.reshape(-1) 128 | 129 | seen_mask |= cur_mask_seen 130 | forecast_mask |= cur_mask_forecast 131 | else: 132 | for keyframe in keyframe_dict: 133 | c2w = keyframe['est_c2w'].cpu().numpy() 134 | w2c = np.linalg.inv(c2w) 135 | w2c = torch.from_numpy(w2c).to(device).float() 136 | ones = torch.ones_like( 137 | points[:, 0]).reshape(-1, 1).to(device) 138 | homo_points = torch.cat([points, ones], dim=1).reshape( 139 | -1, 4, 1).to(device).float() 140 | cam_cord_homo = w2c @ homo_points 141 | cam_cord = cam_cord_homo[:, :3] 142 | 143 | K = torch.from_numpy( 144 | np.array([[fx, .0, cx], [.0, fy, cy], 145 | [.0, .0, 1.0]]).reshape(3, 3)).to(device) 146 | cam_cord[:, 0] *= -1 147 | uv = K.float() @ cam_cord.float() 148 | z = uv[:, -1:] + 1e-8 149 | uv = uv[:, :2] / z 150 | uv = uv.float() 151 | edge = 0 152 | cur_mask_seen = (uv[:, 0] < W - edge) & ( 153 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge) 154 | cur_mask_seen = cur_mask_seen & (z[:, :, 0] < 0) 155 | 156 | edge = -1000 157 | cur_mask_forecast = (uv[:, 0] < W - edge) & ( 158 | uv[:, 0] > edge) & (uv[:, 1] < H - edge) & (uv[:, 1] > edge) 159 | cur_mask_forecast = cur_mask_forecast & (z[:, :, 0] < 0) 160 | 161 | if self.depth_test: 162 | gt_depth = keyframe['depth'].to( 163 | device).reshape(1, 1, H, W) 164 | vgrid = uv.reshape(1, 1, -1, 2) 165 | # normalized to [-1, 1] 166 | vgrid[..., 0] = 
(vgrid[..., 0] / (W-1) * 2.0 - 1.0) 167 | vgrid[..., 1] = (vgrid[..., 1] / (H-1) * 2.0 - 1.0) 168 | depth_sample = F.grid_sample( 169 | gt_depth, vgrid, padding_mode='zeros', align_corners=True) 170 | depth_sample = depth_sample.reshape(-1) 171 | max_depth = torch.max(depth_sample) 172 | # forecast 173 | cur_mask_forecast = cur_mask_forecast.reshape(-1) 174 | proj_depth_forecast = -cam_cord[cur_mask_forecast, 175 | 2].reshape(-1) 176 | cur_mask_forecast[cur_mask_forecast.clone()] &= proj_depth_forecast < max_depth 177 | # seen 178 | cur_mask_seen = cur_mask_seen.reshape(-1) 179 | proj_depth_seen = - cam_cord[cur_mask_seen, 2].reshape(-1) 180 | cur_mask_seen[cur_mask_seen.clone()] &= \ 181 | (proj_depth_seen < depth_sample[cur_mask_seen]+2.4) \ 182 | & (depth_sample[cur_mask_seen]-2.4 < proj_depth_seen) 183 | else: 184 | max_depth = torch.max(keyframe['depth'])*1.1 185 | 186 | # forecast 187 | cur_mask_forecast = cur_mask_forecast.reshape(-1) 188 | proj_depth_forecast = -cam_cord[cur_mask_forecast, 189 | 2].reshape(-1) 190 | cur_mask_forecast[ 191 | cur_mask_forecast.clone()] &= proj_depth_forecast < max_depth 192 | 193 | # seen 194 | cur_mask_seen = cur_mask_seen.reshape(-1) 195 | proj_depth_seen = - \ 196 | cam_cord[cur_mask_seen, 2].reshape(-1) 197 | cur_mask_seen[cur_mask_seen.clone( 198 | )] &= proj_depth_seen < max_depth 199 | 200 | seen_mask |= cur_mask_seen 201 | forecast_mask |= cur_mask_forecast 202 | 203 | forecast_mask &= ~seen_mask 204 | unseen_mask = ~(seen_mask | forecast_mask) 205 | 206 | seen_mask = seen_mask.cpu().numpy() 207 | forecast_mask = forecast_mask.cpu().numpy() 208 | unseen_mask = unseen_mask.cpu().numpy() 209 | 210 | seen_mask_list.append(seen_mask) 211 | forecast_mask_list.append(forecast_mask) 212 | unseen_mask_list.append(unseen_mask) 213 | 214 | seen_mask = np.concatenate(seen_mask_list, axis=0) 215 | forecast_mask = np.concatenate(forecast_mask_list, axis=0) 216 | unseen_mask = np.concatenate(unseen_mask_list, axis=0) 217 | return seen_mask, forecast_mask, unseen_mask 218 | 219 | def get_bound_from_frames(self, keyframe_dict, scale=1): 220 | """ 221 | Get the scene bound (convex hull), 222 | using sparse estimated camera poses and corresponding depth images. 223 | 224 | Args: 225 | keyframe_dict (list): list of keyframe info dictionary. 226 | scale (float): scene scale. 227 | 228 | Returns: 229 | return_mesh (trimesh.Trimesh): the convex hull. 
230 | """ 231 | 232 | H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy 233 | 234 | if version.parse(o3d.__version__) >= version.parse('0.13.0'): 235 | # for new version as provided in environment.yaml 236 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 237 | voxel_length=4.0 * scale / 512.0, 238 | sdf_trunc=0.04 * scale, 239 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8) 240 | else: 241 | # for lower version 242 | volume = o3d.integration.ScalableTSDFVolume( 243 | voxel_length=4.0 * scale / 512.0, 244 | sdf_trunc=0.04 * scale, 245 | color_type=o3d.integration.TSDFVolumeColorType.RGB8) 246 | cam_points = [] 247 | for keyframe in keyframe_dict: 248 | c2w = keyframe['est_c2w'].cpu().numpy() 249 | # convert to open3d camera pose 250 | c2w[:3, 1] *= -1.0 251 | c2w[:3, 2] *= -1.0 252 | w2c = np.linalg.inv(c2w) 253 | cam_points.append(c2w[:3, 3]) 254 | depth = keyframe['depth'].cpu().numpy() 255 | color = keyframe['color'].cpu().numpy() 256 | 257 | depth = o3d.geometry.Image(depth.astype(np.float32)) 258 | color = o3d.geometry.Image(np.array( 259 | (color * 255).astype(np.uint8))) 260 | 261 | intrinsic = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy) 262 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 263 | color, 264 | depth, 265 | depth_scale=1, 266 | depth_trunc=1000, 267 | convert_rgb_to_intensity=False) 268 | volume.integrate(rgbd, intrinsic, w2c) 269 | 270 | cam_points = np.stack(cam_points, axis=0) 271 | mesh = volume.extract_triangle_mesh() 272 | mesh_points = np.array(mesh.vertices) 273 | points = np.concatenate([cam_points, mesh_points], axis=0) 274 | o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points)) 275 | mesh, _ = o3d_pc.compute_convex_hull() 276 | mesh.compute_vertex_normals() 277 | if version.parse(o3d.__version__) >= version.parse('0.13.0'): 278 | mesh = mesh.scale(self.clean_mesh_bound_scale, mesh.get_center()) 279 | else: 280 | mesh = mesh.scale(self.clean_mesh_bound_scale, center=True) 281 | points = np.array(mesh.vertices) 282 | faces = np.array(mesh.triangles) 283 | return_mesh = trimesh.Trimesh(vertices=points, faces=faces) 284 | return return_mesh 285 | 286 | def eval_points(self, p, decoders, tsdf_volume, tsdf_bnds, c=None, stage='color', device='cuda:0'): 287 | """ 288 | Evaluates the occupancy and/or color value for the points. 289 | 290 | Args: 291 | p (tensor, N*3): point coordinates. 292 | decoders (nn.module decoders): decoders. 293 | tsdf_volume (tensor): tsdf volume. 294 | tsdf_bnds (tensor): tsdf volume bounds. 295 | c (dicts, optional): feature grids. Defaults to None. 296 | stage (str, optional): query stage, corresponds to different levels. Defaults to 'color'. 297 | device (str, optional): device name to compute on. Defaults to 'cuda:0'. 298 | 299 | Returns: 300 | ret (tensor): occupancy (and color) value of input points. 
301 | """ 302 | 303 | p_split = torch.split(p, self.points_batch_size) 304 | bound = self.bound 305 | rets = [] 306 | 307 | for pi in p_split: 308 | # mask for points out of bound 309 | mask_x = (pi[:, 0] < bound[0][1]) & (pi[:, 0] > bound[0][0]) 310 | mask_y = (pi[:, 1] < bound[1][1]) & (pi[:, 1] > bound[1][0]) 311 | mask_z = (pi[:, 2] < bound[2][1]) & (pi[:, 2] > bound[2][0]) 312 | mask = mask_x & mask_y & mask_z 313 | 314 | pi = pi.unsqueeze(0) 315 | ret, _ = decoders(pi, c_grid=c, tsdf_volume=tsdf_volume, tsdf_bnds=tsdf_bnds, stage=stage) 316 | 317 | ret = ret.squeeze(0) 318 | if len(ret.shape) == 1 and ret.shape[0] == 4: 319 | ret = ret.unsqueeze(0) 320 | 321 | ret[~mask, 3] = 100 322 | rets.append(ret) 323 | 324 | ret = torch.cat(rets, dim=0) 325 | 326 | return ret 327 | 328 | def sample_grid_tsdf(self, p, tsdf_volume, device='cuda:0'): 329 | 330 | p_nor = normalize_3d_coordinate(p.clone(), self.tsdf_bnds) 331 | p_nor = p_nor.unsqueeze(0) 332 | vgrid = p_nor[:, :, None, None].float() 333 | # acutally trilinear interpolation if mode = 'bilinear' 334 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True, 335 | mode='bilinear').squeeze(-1).squeeze(-1) 336 | return tsdf_value 337 | 338 | 339 | def eval_points_tsdf(self, p, tsdf_volume, device='cuda:0'): 340 | """ 341 | Evaluates the occupancy and/or color value for the points. 342 | 343 | Args: 344 | p (tensor, N*3): Point coordinates. 345 | tsdf_volume (tensor): tsdf volume. 346 | 347 | Returns: 348 | ret (tensor): tsdf value of input points. 349 | """ 350 | 351 | p_split = torch.split(p, self.points_batch_size) 352 | tsdf_vals = [] 353 | for pi in p_split: 354 | pi = pi.unsqueeze(0) 355 | tsdf_volume_tensor = tsdf_volume 356 | 357 | tsdf_val = self.sample_grid_tsdf(pi, tsdf_volume_tensor, device) 358 | tsdf_val = tsdf_val.squeeze(0) 359 | tsdf_vals.append(tsdf_val) 360 | 361 | tsdf_values = torch.cat(tsdf_vals, dim=1) 362 | return tsdf_values 363 | 364 | 365 | def get_grid_uniform(self, resolution): 366 | """ 367 | Get query point coordinates for marching cubes. 368 | 369 | Args: 370 | resolution (int): marching cubes resolution. 371 | 372 | Returns: 373 | (dict): points coordinates and sampled coordinates for each axis. 374 | """ 375 | bound = self.marching_cubes_bound 376 | 377 | padding = 0.05 378 | x = np.linspace(bound[0][0] - padding, bound[0][1] + padding, 379 | resolution) 380 | y = np.linspace(bound[1][0] - padding, bound[1][1] + padding, 381 | resolution) 382 | z = np.linspace(bound[2][0] - padding, bound[2][1] + padding, 383 | resolution) 384 | 385 | xx, yy, zz = np.meshgrid(x, y, z) 386 | grid_points = np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T 387 | grid_points = torch.tensor(np.vstack( 388 | [xx.ravel(), yy.ravel(), zz.ravel()]).T, 389 | dtype=torch.float) 390 | 391 | 392 | 393 | return {"grid_points": grid_points, "xyz": [x, y, z]} 394 | 395 | def get_mesh(self, 396 | mesh_out_file, 397 | c, 398 | decoders, 399 | keyframe_dict, 400 | estimate_c2w_list, 401 | idx, 402 | tsdf_volume, 403 | device='cuda:0', 404 | color=True, 405 | clean_mesh=True, 406 | get_mask_use_all_frames=False): 407 | """ 408 | Extract mesh from scene representation and save mesh to file. 409 | 410 | Args: 411 | mesh_out_file (str): output mesh filename. 412 | c (dicts): feature grids. 413 | decoders (nn.module): decoders. 414 | keyframe_dict (list): list of keyframe info. 415 | estimate_c2w_list (tensor): estimated camera pose. 416 | idx (int): current processed camera ID. 
417 | tsdf volume (tensor): tsdf volume. 418 | device (str, optional): device name to compute on. Defaults to 'cuda:0'. 419 | color (bool, optional): whether to extract colored mesh. Defaults to True. 420 | clean_mesh (bool, optional): whether to clean the output mesh 421 | (remove outliers outside the convexhull and small geometry noise). 422 | Defaults to True. 423 | get_mask_use_all_frames (bool, optional): 424 | whether to use all frames or just keyframes when getting the seen/unseen mask. Defaults to False. 425 | """ 426 | with torch.no_grad(): 427 | 428 | grid = self.get_grid_uniform(self.resolution) 429 | points = grid['grid_points'] 430 | points = points.to(device) 431 | eval_tsdf_volume = tsdf_volume 432 | 433 | mesh_bound = self.get_bound_from_frames( 434 | keyframe_dict, self.scale) 435 | z = [] 436 | mask = [] 437 | for i, pnts in enumerate(torch.split(points, self.points_batch_size, dim=0)): 438 | mask.append(mesh_bound.contains(pnts.cpu().numpy())) 439 | mask = np.concatenate(mask, axis=0) 440 | for i, pnts in enumerate(torch.split(points, self.points_batch_size, dim=0)): 441 | eval_tsdf = self.eval_points_tsdf(pnts, eval_tsdf_volume, device) 442 | eval_tsdf_mask = ((eval_tsdf > -1.0+1e-4) & (eval_tsdf < 1.0-1e-4)).cpu().numpy() 443 | ret = self.eval_points(pnts, decoders, tsdf_volume, self.tsdf_bnds, c, 'high', device) 444 | ret = ret.cpu().numpy()[:, -1] 445 | 446 | eval_tsdf_mask = eval_tsdf_mask.reshape(ret.shape) 447 | z.append(ret) 448 | 449 | z = np.concatenate(z, axis=0) 450 | z[~mask] = 100 451 | z = z.astype(np.float32) 452 | 453 | z_uni_m = z.reshape( 454 | grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 455 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]) 456 | 457 | print('begin marching cube...') 458 | combine_occ_tsdf = z_uni_m 459 | 460 | try: 461 | if version.parse( 462 | skimage.__version__) > version.parse('0.15.0'): 463 | # for new version as provided in environment.yaml 464 | verts, faces, normals, values = skimage.measure.marching_cubes( 465 | volume=combine_occ_tsdf, 466 | level=self.level_set, 467 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 468 | grid['xyz'][1][2] - grid['xyz'][1][1], 469 | grid['xyz'][2][2] - grid['xyz'][2][1])) 470 | else: 471 | # for lower version 472 | verts, faces, normals, values = skimage.measure.marching_cubes_lewiner( 473 | volume=combine_occ_tsdf, 474 | level=self.level_set, 475 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 476 | grid['xyz'][1][2] - grid['xyz'][1][1], 477 | grid['xyz'][2][2] - grid['xyz'][2][1])) 478 | except: 479 | print( 480 | 'marching_cubes error. Possibly no surface extracted from the level set.' 
481 | ) 482 | return 483 | 484 | # convert back to world coordinates 485 | vertices = verts + np.array( 486 | [grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 487 | 488 | if clean_mesh: 489 | points = vertices 490 | mesh = trimesh.Trimesh(vertices=vertices, 491 | faces=faces, 492 | process=False) 493 | seen_mask, _, unseen_mask = self.point_masks( 494 | points, keyframe_dict, estimate_c2w_list, idx, device=device, 495 | get_mask_use_all_frames=get_mask_use_all_frames) 496 | unseen_mask = ~seen_mask 497 | face_mask = unseen_mask[mesh.faces].all(axis=1) 498 | mesh.update_faces(~face_mask) 499 | 500 | # get connected components 501 | components = mesh.split(only_watertight=False) 502 | if self.get_largest_components: 503 | areas = np.array([c.area for c in components], dtype=np.float) 504 | mesh = components[areas.argmax()] 505 | else: 506 | new_components = [] 507 | for comp in components: 508 | if comp.area > self.remove_small_geometry_threshold * self.scale * self.scale: 509 | new_components.append(comp) 510 | mesh = trimesh.util.concatenate(new_components) 511 | vertices = mesh.vertices 512 | faces = mesh.faces 513 | 514 | if color: 515 | if self.color_mesh_extraction_method == 'direct_point_query': 516 | # color is extracted by passing the coordinates of mesh vertices through the network 517 | points = torch.from_numpy(vertices) 518 | z = [] 519 | for i, pnts in enumerate( 520 | torch.split(points, self.points_batch_size, dim=0)): 521 | ret = self.eval_points( 522 | pnts.to(device).float(), decoders, tsdf_volume, self.tsdf_bnds, c, 'color', 523 | device) 524 | z_color = ret.cpu()[..., :3] 525 | z.append(z_color) 526 | z = torch.cat(z, axis=0) 527 | vertex_colors = z.numpy() 528 | 529 | vertex_colors = np.clip(vertex_colors, 0, 1) * 255 530 | vertex_colors = vertex_colors.astype(np.uint8) 531 | 532 | 533 | else: 534 | vertex_colors = None 535 | 536 | vertices /= self.scale 537 | mesh = trimesh.Trimesh(vertices, faces, vertex_colors=vertex_colors) 538 | mesh.export(mesh_out_file) 539 | if self.verbose: 540 | print('Saved mesh at', mesh_out_file) 541 | 542 | return z_uni_m 543 | -------------------------------------------------------------------------------- /src/utils/Renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.common import get_rays, raw2outputs_nerf_color, sample_pdf, normalize_3d_coordinate 3 | import torch.nn.functional as F 4 | 5 | 6 | class Renderer(object): 7 | def __init__(self, cfg, args, slam, points_batch_size=500000, ray_batch_size=100000): 8 | self.ray_batch_size = ray_batch_size 9 | self.points_batch_size = points_batch_size 10 | 11 | self.lindisp = cfg['rendering']['lindisp'] 12 | self.perturb = cfg['rendering']['perturb'] 13 | self.N_samples = cfg['rendering']['N_samples'] 14 | self.N_surface = cfg['rendering']['N_surface'] 15 | self.N_importance = cfg['rendering']['N_importance'] 16 | 17 | self.scale = cfg['scale'] 18 | self.occupancy = cfg['occupancy'] 19 | self.bound = slam.bound 20 | self.sample_mode = 'bilinear' 21 | self.tsdf_bnds = slam.vol_bnds 22 | 23 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = slam.H, slam.W, slam.fx, slam.fy, slam.cx, slam.cy 24 | 25 | self.resolution = cfg['meshing']['resolution'] 26 | 27 | def eval_points(self, p, decoders, tsdf_volume, tsdf_bnds, c=None, stage='color', device='cuda:0'): 28 | """ 29 | Evaluates the occupancy and/or color value for the points. 30 | 31 | Args: 32 | p (tensor, N*3): Point coordinates. 
33 | decoders (nn.module decoders): Decoders. 34 | tsdf_volume (tensor): tsdf volume. 35 | tsdf_bnds (tensor): tsdf volume bounds. 36 | c (dicts, optional): Feature grids. Defaults to None. 37 | stage (str, optional): Query stage, corresponds to different levels. Defaults to 'color'. 38 | device (str, optional): CUDA device. Defaults to 'cuda:0'. 39 | 40 | Returns: 41 | ret (tensor): occupancy (and color) value of input points. 42 | """ 43 | 44 | p_split = torch.split(p, self.points_batch_size) 45 | bound = self.bound 46 | rets = [] 47 | weights = [] 48 | 49 | for pi in p_split: 50 | # mask for points out of bound 51 | mask_x = (pi[:, 0] < bound[0][1]) & (pi[:, 0] > bound[0][0]) 52 | mask_y = (pi[:, 1] < bound[1][1]) & (pi[:, 1] > bound[1][0]) 53 | mask_z = (pi[:, 2] < bound[2][1]) & (pi[:, 2] > bound[2][0]) 54 | mask = mask_x & mask_y & mask_z 55 | 56 | pi = pi.unsqueeze(0) 57 | ret, w = decoders(pi, c_grid=c, tsdf_volume=tsdf_volume, tsdf_bnds=tsdf_bnds, stage=stage) 58 | ret = ret.squeeze(0) 59 | 60 | 61 | if len(ret.shape) == 1 and ret.shape[0] == 4: 62 | ret = ret.unsqueeze(0) 63 | 64 | ret[~mask, 3] = 100 65 | rets.append(ret) 66 | weights.append(w) 67 | 68 | ret = torch.cat(rets, dim=0) 69 | weight = torch.cat(weights, dim=0) 70 | 71 | return ret, weight 72 | 73 | def sample_grid_tsdf(self, p, tsdf_volume, device='cuda:0'): 74 | 75 | p_nor = normalize_3d_coordinate(p.clone(), self.tsdf_bnds) 76 | p_nor = p_nor.unsqueeze(0) 77 | vgrid = p_nor[:, :, None, None].float() 78 | # acutally trilinear interpolation if mode = 'bilinear' 79 | tsdf_value = F.grid_sample(tsdf_volume.to(device), vgrid.to(device), padding_mode='border', align_corners=True, 80 | mode='bilinear').squeeze(-1).squeeze(-1) 81 | return tsdf_value 82 | 83 | 84 | def eval_points_tsdf(self, p, tsdf_volume, device='cuda:0'): 85 | """ 86 | Evaluates the occupancy and/or color value for the points. 87 | 88 | Args: 89 | p (tensor, N*3): Point coordinates. 90 | 91 | 92 | Returns: 93 | ret (tensor): tsdf value of input points. 94 | """ 95 | 96 | p_split = torch.split(p, self.points_batch_size) 97 | tsdf_vals = [] 98 | for pi in p_split: 99 | pi = pi.unsqueeze(0) 100 | tsdf_volume_tensor = tsdf_volume 101 | 102 | tsdf_val = self.sample_grid_tsdf(pi, tsdf_volume_tensor, device) 103 | tsdf_val = tsdf_val.squeeze(0) 104 | tsdf_vals.append(tsdf_val) 105 | 106 | tsdf_values = torch.cat(tsdf_vals, dim=1) 107 | return tsdf_values 108 | 109 | 110 | def render_batch_ray(self, c, decoders, rays_d, rays_o, device, tsdf_volume, tsdf_bnds, stage, gt_depth=None): 111 | """ 112 | Render color, depth and uncertainty of a batch of rays. 113 | 114 | Args: 115 | c (dict): feature grids. 116 | decoders (nn.module): decoders. 117 | rays_d (tensor, N*3): rays direction. 118 | rays_o (tensor, N*3): rays origin. 119 | device (str): device name to compute on. 120 | tsdf_volume (tensor): tsdf volume. 121 | tsdf_bnds (tensor): tsdf volume bounds. 122 | stage (str): query stage. 123 | gt_depth (tensor, optional): sensor depth image. Defaults to None. 124 | 125 | Returns: 126 | depth (tensor): rendered depth. 127 | uncertainty (tensor): rendered uncertainty. 128 | color (tensor): rendered color. 129 | weight (tensor): attention weight. 
130 | """ 131 | eval_tsdf_volume = tsdf_volume 132 | 133 | 134 | N_samples = self.N_samples 135 | N_surface = self.N_surface 136 | N_importance = self.N_importance 137 | 138 | N_rays = rays_o.shape[0] 139 | 140 | if gt_depth is None: 141 | N_surface = 0 142 | near = 0.01 143 | else: 144 | gt_depth = gt_depth.reshape(-1, 1) 145 | gt_depth_samples = gt_depth.repeat(1, N_samples) 146 | near = gt_depth_samples*0.01 147 | 148 | with torch.no_grad(): 149 | det_rays_o = rays_o.clone().detach().unsqueeze(-1) # (N, 3, 1) 150 | det_rays_d = rays_d.clone().detach().unsqueeze(-1) # (N, 3, 1) 151 | t = (self.bound.unsqueeze(0).to(device) - 152 | det_rays_o)/det_rays_d # (N, 3, 2) 153 | far_bb, _ = torch.min(torch.max(t, dim=2)[0], dim=1) 154 | far_bb = far_bb.unsqueeze(-1) 155 | far_bb += 0.01 156 | 157 | if gt_depth is not None: 158 | # in case the bound is too large 159 | far = torch.clamp(far_bb, 0, torch.max(gt_depth*1.2)) 160 | 161 | else: 162 | far = far_bb 163 | if N_surface > 0: 164 | if False: 165 | # this naive implementation downgrades performance 166 | gt_depth_surface = gt_depth.repeat(1, N_surface) 167 | t_vals_surface = torch.linspace( 168 | 0., 1., steps=N_surface).to(device) 169 | z_vals_surface = 0.95*gt_depth_surface * \ 170 | (1.-t_vals_surface) + 1.05 * \ 171 | gt_depth_surface * (t_vals_surface) 172 | else: 173 | # since we want to colorize even on regions with no depth sensor readings, 174 | # meaning colorize on interpolated geometry region, 175 | # we sample all pixels (not using depth mask) for color loss. 176 | # Therefore, for pixels with non-zero depth value, we sample near the surface, 177 | # since it is not a good idea to sample 16 points near (half even behind) camera, 178 | # for pixels with zero depth value, we sample uniformly from camera to max_depth. 
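                # (added note, not part of the original implementation) for a pixel
                # with measured depth d, the surface samples computed below follow
                # z(t) = 0.95*d*(1-t) + 1.05*d*t with t linearly spaced in [0, 1],
                # i.e. N_surface points spread uniformly over the +/-5% band
                # [0.95*d, 1.05*d] around the sensor depth; zero-depth pixels fall
                # back to uniform samples between near_surface and max(gt_depth).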
179 | gt_none_zero_mask = gt_depth > 0 180 | gt_none_zero = gt_depth[gt_none_zero_mask] 181 | gt_none_zero = gt_none_zero.unsqueeze(-1) 182 | gt_depth_surface = gt_none_zero.repeat(1, N_surface) 183 | t_vals_surface = torch.linspace( 184 | 0., 1., steps=N_surface).double().to(device) 185 | # emperical range 0.05*depth 186 | z_vals_surface_depth_none_zero = 0.95*gt_depth_surface * \ 187 | (1.-t_vals_surface) + 1.05 * \ 188 | gt_depth_surface * (t_vals_surface) 189 | z_vals_surface = torch.zeros( 190 | gt_depth.shape[0], N_surface).to(device).double() 191 | gt_none_zero_mask = gt_none_zero_mask.squeeze(-1) 192 | z_vals_surface[gt_none_zero_mask, 193 | :] = z_vals_surface_depth_none_zero 194 | near_surface = 0.001 195 | far_surface = torch.max(gt_depth) 196 | z_vals_surface_depth_zero = near_surface * \ 197 | (1.-t_vals_surface) + far_surface * (t_vals_surface) 198 | z_vals_surface_depth_zero.unsqueeze( 199 | 0).repeat((~gt_none_zero_mask).sum(), 1) 200 | z_vals_surface[~gt_none_zero_mask, 201 | :] = z_vals_surface_depth_zero 202 | 203 | t_vals = torch.linspace(0., 1., steps=N_samples, device=device) 204 | 205 | if not self.lindisp: 206 | z_vals = near * (1.-t_vals) + far * (t_vals) 207 | else: 208 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 209 | 210 | if self.perturb > 0.: 211 | # get intervals between samples 212 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 213 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 214 | lower = torch.cat([z_vals[..., :1], mids], -1) 215 | # stratified samples in those intervals 216 | t_rand = torch.rand(z_vals.shape).to(device) 217 | z_vals = lower + (upper - lower) * t_rand 218 | 219 | if N_surface > 0: 220 | z_vals, _ = torch.sort( 221 | torch.cat([z_vals, z_vals_surface.double()], -1), -1) 222 | 223 | pts = rays_o[..., None, :] + rays_d[..., None, :] * \ 224 | z_vals[..., :, None] # [N_rays, N_samples+N_surface, 3] 225 | pointsf = pts.reshape(-1, 3) 226 | 227 | raw, weight = self.eval_points(pointsf, decoders, tsdf_volume, tsdf_bnds, c, stage, device) 228 | raw = raw.reshape(N_rays, N_samples+N_surface, -1) 229 | weight = weight.reshape(N_rays, N_samples+N_surface, -1) 230 | 231 | 232 | depth, uncertainty, color, weights = raw2outputs_nerf_color( 233 | raw, z_vals, rays_d, occupancy=self.occupancy, device=device) 234 | 235 | if N_importance > 0: 236 | z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 237 | z_samples = sample_pdf( 238 | z_vals_mid, weights[..., 1:-1], N_importance, det=(self.perturb == 0.), device=device) 239 | z_samples = z_samples.detach() 240 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 241 | 242 | pts = rays_o[..., None, :] + \ 243 | rays_d[..., None, :] * z_vals[..., :, None] 244 | pts = pts.reshape(-1, 3) 245 | 246 | raw, weight = self.eval_points(pointsf, decoders, tsdf_volume, tsdf_bnds, c, stage, device) 247 | raw = raw.reshape(N_rays, N_samples+N_surface, -1) 248 | weight = weight.reshape(N_rays, N_samples+N_surface, -1) 249 | 250 | depth, uncertainty, color, weights = raw2outputs_nerf_color( 251 | raw, z_vals, rays_d, occupancy=self.occupancy, device=device) 252 | return depth, uncertainty, color, weight 253 | 254 | 255 | return depth, uncertainty, color, weight 256 | 257 | 258 | def render_img(self, c, decoders, c2w, device, tsdf_volume, tsdf_bnds, stage, gt_depth=None): 259 | """ 260 | Renders out depth, uncertainty, and color images. 261 | 262 | Args: 263 | c (dict): feature grids. 264 | decoders (nn.module): decoders. 265 | c2w (tensor): camera to world matrix of current frame. 
266 | device (str): device name to compute on. 267 | tsdf_volume (tensor): tsdf volume. 268 | tsdf_bnds (tensor): tsdf volume bounds. 269 | stage (str): query stage. 270 | gt_depth (tensor, optional): sensor depth image. Defaults to None. 271 | 272 | Returns: 273 | depth (tensor, H*W): rendered depth image. 274 | uncertainty (tensor, H*W): rendered uncertainty image. 275 | color (tensor, H*W*3): rendered color image. 276 | """ 277 | 278 | with torch.no_grad(): 279 | H = self.H 280 | W = self.W 281 | rays_o, rays_d = get_rays( 282 | H, W, self.fx, self.fy, self.cx, self.cy, c2w, device) 283 | rays_o = rays_o.reshape(-1, 3) 284 | rays_d = rays_d.reshape(-1, 3) 285 | 286 | depth_list = [] 287 | uncertainty_list = [] 288 | color_list = [] 289 | 290 | 291 | ray_batch_size = self.ray_batch_size 292 | gt_depth = gt_depth.reshape(-1) 293 | 294 | for i in range(0, rays_d.shape[0], ray_batch_size): 295 | rays_d_batch = rays_d[i:i+ray_batch_size] 296 | rays_o_batch = rays_o[i:i+ray_batch_size] 297 | 298 | iter = 10 299 | 300 | if gt_depth is None: 301 | ret = self.render_batch_ray( 302 | c, decoders, rays_d_batch, rays_o_batch, device, tsdf_volume, tsdf_bnds, stage, gt_depth=None) 303 | else: 304 | gt_depth_batch = gt_depth[i:i+ray_batch_size] 305 | ret = self.render_batch_ray( 306 | c, decoders, rays_d_batch, rays_o_batch, device, tsdf_volume, tsdf_bnds, stage, gt_depth=gt_depth_batch) 307 | 308 | depth, uncertainty, color, _= ret 309 | 310 | 311 | depth_list.append(depth.double()) 312 | uncertainty_list.append(uncertainty.double()) 313 | color_list.append(color) 314 | 315 | 316 | 317 | 318 | 319 | depth = torch.cat(depth_list, dim=0) 320 | uncertainty = torch.cat(uncertainty_list, dim=0) 321 | color = torch.cat(color_list, dim=0) 322 | 323 | depth = depth.reshape(H, W) 324 | uncertainty = uncertainty.reshape(H, W) 325 | color = color.reshape(H, W, 3) 326 | 327 | return depth, uncertainty, color 328 | 329 | 330 | -------------------------------------------------------------------------------- /src/utils/Visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from src.common import get_camera_from_tensor 6 | import open3d as o3d 7 | 8 | class Visualizer(object): 9 | """ 10 | Visualize intermediate results, render out depth, color and depth uncertainty images. 11 | It can be called per iteration, which is good for debugging (to see how each tracking/mapping iteration performs). 12 | 13 | """ 14 | 15 | def __init__(self, freq, inside_freq, vis_dir, renderer, verbose, device='cuda:0'): 16 | self.freq = freq 17 | self.device = device 18 | self.vis_dir = vis_dir 19 | self.verbose = verbose 20 | self.renderer = renderer 21 | self.inside_freq = inside_freq 22 | os.makedirs(f'{vis_dir}', exist_ok=True) 23 | 24 | def vis(self, idx, iter, gt_depth, gt_color, c2w_or_camera_tensor, c, 25 | decoders, tsdf_volume, tsdf_bnds): 26 | """ 27 | Visualization of depth, color images and save to file. 28 | 29 | Args: 30 | idx (int): current frame index. 31 | iter (int): the iteration number. 32 | gt_depth (tensor): ground truth depth image of the current frame. 33 | gt_color (tensor): ground truth color image of the current frame. 34 | c2w_or_camera_tensor (tensor): camera pose, represented in 35 | camera to world matrix or quaternion and translation tensor. 36 | c (dicts): feature grids. 37 | decoders (nn.module): decoders. 38 | tsdf_volume (tensor): tsdf volume. 
39 | tsdf_bnds (tensor): tsdf volume bounds. 40 | """ 41 | with torch.no_grad(): 42 | if (idx % self.freq == 0) and (iter % self.inside_freq == 0): 43 | gt_depth_np = gt_depth.cpu().numpy() 44 | gt_color_np = gt_color.cpu().numpy() 45 | if len(c2w_or_camera_tensor.shape) == 1: 46 | bottom = torch.from_numpy( 47 | np.array([0, 0, 0, 1.]).reshape([1, 4])).type( 48 | torch.float32).to(self.device) 49 | c2w = get_camera_from_tensor( 50 | c2w_or_camera_tensor.clone().detach()) 51 | c2w = torch.cat([c2w, bottom], dim=0) 52 | else: 53 | c2w = c2w_or_camera_tensor 54 | 55 | depth, _, color = self.renderer.render_img( 56 | c, 57 | decoders, 58 | c2w, 59 | self.device, 60 | tsdf_volume, 61 | tsdf_bnds, 62 | stage='color', 63 | gt_depth=gt_depth) 64 | 65 | # convert to open3d camera pose 66 | c2w = c2w.cpu().numpy() 67 | c2w[:3, 1] *= -1.0 68 | c2w[:3, 2] *= -1.0 69 | 70 | 71 | depth_np = depth.detach().cpu().numpy() 72 | color_np = color.detach().cpu().numpy() 73 | depth = depth_np.astype(np.float32) 74 | color = np.array((color_np * 255).astype(np.uint8)) 75 | 76 | depth_residual = np.abs(gt_depth_np - depth_np) 77 | depth_residual[gt_depth_np == 0.0] = 0.0 78 | color_residual = np.abs(gt_color_np - color_np) 79 | color_residual[gt_depth_np == 0.0] = 0.0 80 | 81 | 82 | fig, axs = plt.subplots(2, 3) 83 | fig.tight_layout() 84 | max_depth = np.max(gt_depth_np) 85 | axs[0, 0].imshow(gt_depth_np, cmap="plasma", 86 | vmin=0, vmax=max_depth) 87 | axs[0, 0].set_title('Input Depth') 88 | axs[0, 0].set_xticks([]) 89 | axs[0, 0].set_yticks([]) 90 | axs[0, 1].imshow(depth_np, cmap="plasma", 91 | vmin=0, vmax=max_depth) 92 | axs[0, 1].set_title('Generated Depth') 93 | axs[0, 1].set_xticks([]) 94 | axs[0, 1].set_yticks([]) 95 | axs[0, 2].imshow(depth_residual, cmap="plasma", 96 | vmin=0, vmax=max_depth) 97 | axs[0, 2].set_title('Depth Residual') 98 | axs[0, 2].set_xticks([]) 99 | axs[0, 2].set_yticks([]) 100 | gt_color_np = np.clip(gt_color_np, 0, 1) 101 | color_np = np.clip(color_np, 0, 1) 102 | color_residual = np.clip(color_residual, 0, 1) 103 | axs[1, 0].imshow(gt_color_np, cmap="plasma") 104 | axs[1, 0].set_title('Input RGB') 105 | axs[1, 0].set_xticks([]) 106 | axs[1, 0].set_yticks([]) 107 | axs[1, 1].imshow(color_np, cmap="plasma") 108 | axs[1, 1].set_title('Generated RGB') 109 | axs[1, 1].set_xticks([]) 110 | axs[1, 1].set_yticks([]) 111 | axs[1, 2].imshow(color_residual, cmap="plasma") 112 | axs[1, 2].set_title('RGB Residual') 113 | axs[1, 2].set_xticks([]) 114 | axs[1, 2].set_yticks([]) 115 | plt.subplots_adjust(wspace=0, hspace=0) 116 | plt.savefig( 117 | f'{self.vis_dir}/{idx:05d}_{iter:04d}.jpg', bbox_inches='tight', pad_inches=0.2) 118 | plt.clf() 119 | 120 | if self.verbose: 121 | print( 122 | f'Saved rendering visualization of color/depth image at {self.vis_dir}/{idx:05d}_{iter:04d}.jpg') 123 | 124 | -------------------------------------------------------------------------------- /src/utils/__pycache__/Logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Logger.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Logger.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Mesher.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Mesher.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Mesher.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Mesher.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Renderer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Renderer.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Renderer.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/Visualizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/Visualizer.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MachinePerceptionLab/Attentive_DFPrior/401c71384ba511a5def8bcf8657ad11eef4e24eb/src/utils/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from src.common import as_intrinsics_matrix 9 | from torch.utils.data import Dataset 10 | 11 | 12 | def 
readEXR_onlydepth(filename): 13 | """ 14 | Read depth data from EXR image file. 15 | 16 | Args: 17 | filename (str): File path. 18 | 19 | Returns: 20 | Y (numpy.array): Depth buffer in float32 format. 21 | """ 22 | # move the import here since only CoFusion needs these package 23 | # sometimes installation of openexr is hard, you can run all other datasets 24 | # even without openexr 25 | import Imath 26 | import OpenEXR as exr 27 | 28 | exrfile = exr.InputFile(filename) 29 | header = exrfile.header() 30 | dw = header['dataWindow'] 31 | isize = (dw.max.y - dw.min.y + 1, dw.max.x - dw.min.x + 1) 32 | 33 | channelData = dict() 34 | 35 | for c in header['channels']: 36 | C = exrfile.channel(c, Imath.PixelType(Imath.PixelType.FLOAT)) 37 | C = np.fromstring(C, dtype=np.float32) 38 | C = np.reshape(C, isize) 39 | 40 | channelData[c] = C 41 | 42 | Y = None if 'Y' not in header['channels'] else channelData['Y'] 43 | 44 | return Y 45 | 46 | 47 | def get_dataset(cfg, args, scale, device='cuda:0'): 48 | return dataset_dict[cfg['dataset']](cfg, args, scale, device=device) 49 | 50 | 51 | class BaseDataset(Dataset): 52 | def __init__(self, cfg, args, scale, device='cuda:0' 53 | ): 54 | super(BaseDataset, self).__init__() 55 | self.name = cfg['dataset'] 56 | self.device = device 57 | self.scale = scale 58 | self.png_depth_scale = cfg['cam']['png_depth_scale'] 59 | 60 | self.H, self.W, self.fx, self.fy, self.cx, self.cy = cfg['cam']['H'], cfg['cam'][ 61 | 'W'], cfg['cam']['fx'], cfg['cam']['fy'], cfg['cam']['cx'], cfg['cam']['cy'] 62 | 63 | self.distortion = np.array( 64 | cfg['cam']['distortion']) if 'distortion' in cfg['cam'] else None 65 | self.crop_size = cfg['cam']['crop_size'] if 'crop_size' in cfg['cam'] else None 66 | 67 | if args.input_folder is None: 68 | self.input_folder = cfg['data']['input_folder'] 69 | else: 70 | self.input_folder = args.input_folder 71 | 72 | self.crop_edge = cfg['cam']['crop_edge'] 73 | 74 | def __len__(self): 75 | return self.n_img 76 | 77 | def __getitem__(self, index): 78 | color_path = self.color_paths[index] 79 | depth_path = self.depth_paths[index] 80 | color_data = cv2.imread(color_path) 81 | if '.png' in depth_path: 82 | depth_data = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) 83 | elif '.exr' in depth_path: 84 | depth_data = readEXR_onlydepth(depth_path) 85 | if self.distortion is not None: 86 | K = as_intrinsics_matrix([self.fx, self.fy, self.cx, self.cy]) 87 | # undistortion is only applied on color image, not depth! 88 | color_data = cv2.undistort(color_data, K, self.distortion) 89 | 90 | color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB) 91 | color_data = color_data / 255. 
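        # (added note) the raw PNG depth below is divided by png_depth_scale to
        # convert sensor units to metres; the result is then multiplied by the
        # global scale factor a few lines further down, matching the scaling
        # applied to the pose translation at the end of __getitem__.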
92 | depth_data = depth_data.astype(np.float32) / self.png_depth_scale 93 | H, W = depth_data.shape 94 | color_data = cv2.resize(color_data, (W, H)) 95 | color_data = torch.from_numpy(color_data) 96 | depth_data = torch.from_numpy(depth_data)*self.scale 97 | if self.crop_size is not None: 98 | # follow the pre-processing step in lietorch, actually is resize 99 | color_data = color_data.permute(2, 0, 1) 100 | color_data = F.interpolate( 101 | color_data[None], self.crop_size, mode='bilinear', align_corners=True)[0] 102 | depth_data = F.interpolate( 103 | depth_data[None, None], self.crop_size, mode='nearest')[0, 0] 104 | color_data = color_data.permute(1, 2, 0).contiguous() 105 | 106 | edge = self.crop_edge 107 | if edge > 0: 108 | # crop image edge, there are invalid value on the edge of the color image 109 | color_data = color_data[edge:-edge, edge:-edge] 110 | depth_data = depth_data[edge:-edge, edge:-edge] 111 | pose = self.poses[index] 112 | pose[:3, 3] *= self.scale 113 | return index, color_data.to(self.device), depth_data.to(self.device), pose.to(self.device) 114 | 115 | 116 | class Replica(BaseDataset): 117 | def __init__(self, cfg, args, scale, device='cuda:0' 118 | ): 119 | super(Replica, self).__init__(cfg, args, scale, device) 120 | self.color_paths = sorted( 121 | glob.glob(f'{self.input_folder}/results/frame*.jpg')) 122 | self.depth_paths = sorted( 123 | glob.glob(f'{self.input_folder}/results/depth*.png')) 124 | self.n_img = len(self.color_paths) 125 | self.load_poses(f'{self.input_folder}/traj.txt') 126 | 127 | def load_poses(self, path): 128 | self.poses = [] 129 | with open(path, "r") as f: 130 | lines = f.readlines() 131 | for i in range(self.n_img): 132 | line = lines[i] 133 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 134 | c2w[:3, 1] *= -1 135 | c2w[:3, 2] *= -1 136 | c2w = torch.from_numpy(c2w).float() 137 | self.poses.append(c2w) 138 | 139 | 140 | class Azure(BaseDataset): 141 | def __init__(self, cfg, args, scale, device='cuda:0' 142 | ): 143 | super(Azure, self).__init__(cfg, args, scale, device) 144 | self.color_paths = sorted( 145 | glob.glob(os.path.join(self.input_folder, 'color', '*.jpg'))) 146 | self.depth_paths = sorted( 147 | glob.glob(os.path.join(self.input_folder, 'depth', '*.png'))) 148 | self.n_img = len(self.color_paths) 149 | self.load_poses(os.path.join( 150 | self.input_folder, 'scene', 'trajectory.log')) 151 | 152 | def load_poses(self, path): 153 | self.poses = [] 154 | if os.path.exists(path): 155 | with open(path) as f: 156 | content = f.readlines() 157 | 158 | # Load .log file. 
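            # (added note) each record of the trajectory.log file spans five lines:
            # a "src tgt fitness" header followed by a 4x4 camera-to-world matrix
            # written row-major over the next four lines, which is why the loop
            # below advances in steps of five.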
159 | for i in range(0, len(content), 5): 160 | # format %d (src) %d (tgt) %f (fitness) 161 | data = list(map(float, content[i].strip().split(' '))) 162 | ids = (int(data[0]), int(data[1])) 163 | fitness = data[2] 164 | 165 | # format %f x 16 166 | c2w = np.array( 167 | list(map(float, (''.join( 168 | content[i + 1:i + 5])).strip().split()))).reshape((4, 4)) 169 | 170 | c2w[:3, 1] *= -1 171 | c2w[:3, 2] *= -1 172 | c2w = torch.from_numpy(c2w).float() 173 | self.poses.append(c2w) 174 | else: 175 | for i in range(self.n_img): 176 | c2w = np.eye(4) 177 | c2w = torch.from_numpy(c2w).float() 178 | self.poses.append(c2w) 179 | 180 | 181 | class ScanNet(BaseDataset): 182 | def __init__(self, cfg, args, scale, device='cuda:0' 183 | ): 184 | super(ScanNet, self).__init__(cfg, args, scale, device) 185 | self.input_folder = os.path.join(self.input_folder, 'frames') 186 | self.color_paths = sorted(glob.glob(os.path.join( 187 | self.input_folder, 'color', '*.jpg')), key=lambda x: int(os.path.basename(x)[:-4])) 188 | self.depth_paths = sorted(glob.glob(os.path.join( 189 | self.input_folder, 'depth', '*.png')), key=lambda x: int(os.path.basename(x)[:-4])) 190 | self.load_poses(os.path.join(self.input_folder, 'pose')) 191 | self.n_img = len(self.color_paths) 192 | 193 | def load_poses(self, path): 194 | self.poses = [] 195 | pose_paths = sorted(glob.glob(os.path.join(path, '*.txt')), 196 | key=lambda x: int(os.path.basename(x)[:-4])) 197 | for pose_path in pose_paths: 198 | with open(pose_path, "r") as f: 199 | lines = f.readlines() 200 | ls = [] 201 | for line in lines: 202 | l = list(map(float, line.split(' '))) 203 | ls.append(l) 204 | c2w = np.array(ls).reshape(4, 4) 205 | c2w[:3, 1] *= -1 206 | c2w[:3, 2] *= -1 207 | c2w = torch.from_numpy(c2w).float() 208 | self.poses.append(c2w) 209 | 210 | 211 | class CoFusion(BaseDataset): 212 | def __init__(self, cfg, args, scale, device='cuda:0' 213 | ): 214 | super(CoFusion, self).__init__(cfg, args, scale, device) 215 | self.input_folder = os.path.join(self.input_folder) 216 | self.color_paths = sorted( 217 | glob.glob(os.path.join(self.input_folder, 'colour', '*.png'))) 218 | self.depth_paths = sorted(glob.glob(os.path.join( 219 | self.input_folder, 'depth_noise', '*.exr'))) 220 | self.n_img = len(self.color_paths) 221 | self.load_poses(os.path.join(self.input_folder, 'trajectories')) 222 | 223 | def load_poses(self, path): 224 | # We tried, but cannot align the coordinate frame of cofusion to ours. 225 | # So here we provide identity matrix as proxy. 226 | # But it will not affect the calculation of ATE since camera trajectories can be aligned. 
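        # (added note) ATE is typically computed after a least-squares alignment of
        # the estimated trajectory to the ground truth, so any fixed world-frame
        # offset introduced by these identity placeholders cancels out in the metric.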
227 | self.poses = [] 228 | for i in range(self.n_img): 229 | c2w = np.eye(4) 230 | c2w = torch.from_numpy(c2w).float() 231 | self.poses.append(c2w) 232 | 233 | 234 | class TUM_RGBD(BaseDataset): 235 | def __init__(self, cfg, args, scale, device='cuda:0' 236 | ): 237 | super(TUM_RGBD, self).__init__(cfg, args, scale, device) 238 | self.color_paths, self.depth_paths, self.poses = self.loadtum( 239 | self.input_folder, frame_rate=32) 240 | self.n_img = len(self.color_paths) 241 | 242 | def parse_list(self, filepath, skiprows=0): 243 | """ read list data """ 244 | data = np.loadtxt(filepath, delimiter=' ', 245 | dtype=np.unicode_, skiprows=skiprows) 246 | return data 247 | 248 | def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08): 249 | """ pair images, depths, and poses """ 250 | associations = [] 251 | for i, t in enumerate(tstamp_image): 252 | if tstamp_pose is None: 253 | j = np.argmin(np.abs(tstamp_depth - t)) 254 | if (np.abs(tstamp_depth[j] - t) < max_dt): 255 | associations.append((i, j)) 256 | 257 | else: 258 | j = np.argmin(np.abs(tstamp_depth - t)) 259 | k = np.argmin(np.abs(tstamp_pose - t)) 260 | 261 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \ 262 | (np.abs(tstamp_pose[k] - t) < max_dt): 263 | associations.append((i, j, k)) 264 | 265 | return associations 266 | 267 | def loadtum(self, datapath, frame_rate=-1): 268 | """ read video data in tum-rgbd format """ 269 | if os.path.isfile(os.path.join(datapath, 'groundtruth.txt')): 270 | pose_list = os.path.join(datapath, 'groundtruth.txt') 271 | elif os.path.isfile(os.path.join(datapath, 'pose.txt')): 272 | pose_list = os.path.join(datapath, 'pose.txt') 273 | 274 | image_list = os.path.join(datapath, 'rgb.txt') 275 | depth_list = os.path.join(datapath, 'depth.txt') 276 | 277 | image_data = self.parse_list(image_list) 278 | depth_data = self.parse_list(depth_list) 279 | pose_data = self.parse_list(pose_list, skiprows=1) 280 | pose_vecs = pose_data[:, 1:].astype(np.float64) 281 | 282 | tstamp_image = image_data[:, 0].astype(np.float64) 283 | tstamp_depth = depth_data[:, 0].astype(np.float64) 284 | tstamp_pose = pose_data[:, 0].astype(np.float64) 285 | associations = self.associate_frames( 286 | tstamp_image, tstamp_depth, tstamp_pose) 287 | 288 | indicies = [0] 289 | for i in range(1, len(associations)): 290 | t0 = tstamp_image[associations[indicies[-1]][0]] 291 | t1 = tstamp_image[associations[i][0]] 292 | if t1 - t0 > 1.0 / frame_rate: 293 | indicies += [i] 294 | 295 | images, poses, depths, intrinsics = [], [], [], [] 296 | inv_pose = None 297 | for ix in indicies: 298 | (i, j, k) = associations[ix] 299 | images += [os.path.join(datapath, image_data[i, 1])] 300 | depths += [os.path.join(datapath, depth_data[j, 1])] 301 | c2w = self.pose_matrix_from_quaternion(pose_vecs[k]) 302 | if inv_pose is None: 303 | inv_pose = np.linalg.inv(c2w) 304 | c2w = np.eye(4) 305 | else: 306 | c2w = inv_pose@c2w 307 | c2w[:3, 1] *= -1 308 | c2w[:3, 2] *= -1 309 | c2w = torch.from_numpy(c2w).float() 310 | poses += [c2w] 311 | 312 | return images, depths, poses 313 | 314 | def pose_matrix_from_quaternion(self, pvec): 315 | """ convert 4x4 pose matrix to (t, q) """ 316 | from scipy.spatial.transform import Rotation 317 | 318 | pose = np.eye(4) 319 | pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix() 320 | pose[:3, 3] = pvec[:3] 321 | return pose 322 | 323 | 324 | dataset_dict = { 325 | "replica": Replica, 326 | "scannet": ScanNet, 327 | "cofusion": CoFusion, 328 | "azure": Azure, 329 | "tumrgbd": TUM_RGBD 330 | } 331 
| --------------------------------------------------------------------------------
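A minimal usage sketch (not part of the repository) of the dataset registry defined in `src/utils/datasets.py` above. The `cfg` keys mirror the ones read by `BaseDataset`/`Replica`; the concrete camera values, the data path, and the CPU device are placeholders rather than the project's actual settings, and the snippet assumes the referenced Replica sequence exists on disk.

```python
from types import SimpleNamespace

from src.utils.datasets import get_dataset

# Placeholder config: key names follow BaseDataset/Replica, values are illustrative only.
cfg = {
    'dataset': 'replica',
    'scale': 1.0,
    'cam': {'H': 680, 'W': 1200, 'fx': 600.0, 'fy': 600.0,
            'cx': 599.5, 'cy': 339.5, 'png_depth_scale': 6553.5, 'crop_edge': 0},
    'data': {'input_folder': 'Datasets/Replica/room0'},
}
args = SimpleNamespace(input_folder=None)  # None -> fall back to cfg['data']['input_folder']

dataset = get_dataset(cfg, args, scale=cfg['scale'], device='cpu')
idx, color, depth, c2w = dataset[0]  # HxWx3 RGB in [0, 1], HxW metric depth, 4x4 pose
print(len(dataset), color.shape, depth.shape)
```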