├── README.md ├── assets ├── PLV.png ├── cover_2.png ├── encoder.png └── toy_result.png ├── configs └── replica │ ├── office0.yaml │ ├── office0_custom.yaml │ ├── office0_w_clip.yaml │ └── office0_w_slam.yaml ├── demo.py ├── example ├── office0 │ ├── depth000000.png │ ├── depth000020.png │ ├── frame000000.jpg │ ├── frame000020.jpg │ └── traj.txt ├── render_w_lim.py ├── toy.py └── util.py ├── external └── openseg │ └── openseg_api.py ├── scripts └── download_replica.sh ├── setup.py ├── uni ├── __init__.py ├── dataset │ ├── 3dscene.py │ ├── NICE_SLAM_config │ │ └── demo.yaml │ ├── NICE_SLAM_dataset.py │ ├── __init__.py │ ├── aug_icl.py │ ├── azure.py │ ├── bpnet_scannet.py │ ├── custom.py │ ├── custom_w_slam.py │ ├── fountain.py │ ├── icl_nuim.py │ ├── latent_map.py │ ├── matterport3d.py │ ├── replica.py │ ├── scannet.py │ └── tum.py ├── encoder │ ├── __init__.py │ ├── position_encoder.pth │ ├── uni_encoder_v2.py │ └── utility.py ├── ext │ ├── __init__.py │ ├── imgproc │ │ ├── common.cuh │ │ ├── imgproc.cpp │ │ ├── imgproc.cu │ │ └── photometric.cu │ ├── indexing │ │ ├── indexing.cpp │ │ └── indexing.cu │ ├── marching_cubes │ │ ├── mc.cpp │ │ ├── mc_data.cuh │ │ └── mc_interp_kernel.cu │ └── pcproc │ │ ├── cuda_kdtree.cu │ │ ├── cuda_kdtree.cuh │ │ ├── cutil_math.h │ │ ├── pcproc.cpp │ │ └── pcproc.cu ├── mapper │ ├── __init__.py │ ├── base_map.py │ ├── context_map_v2.py │ ├── latent_map.py │ └── surface_map.py ├── tracker │ ├── __init__.py │ ├── cicp.py │ └── tracker_custom.py └── utils │ ├── __init__.py │ ├── exp_util.py │ ├── linalg_util.py │ ├── motion_util.py │ ├── pt_util.py │ ├── ray_cast.py │ ├── torch_scatter.py │ └── vis_util.py └── vis_LIMs.py /README.md: -------------------------------------------------------------------------------- 1 | # [Uni-Fusion: Universal Continuous Mapping](https://jarrome.github.io/Uni-Fusion/) 2 | 3 | [Yijun Yuan](https://jarrome.github.io/), [Andreas Nüchter](https://www.informatik.uni-wuerzburg.de/robotics/team/nuechter/) 4 | 5 | [Preprint](https://arxiv.org/abs/2303.12678) | [website](https://jarrome.github.io/Uni-Fusion/) 6 | 7 | #### Uni-Fusion is *nothing to do with NeRF!* 8 | #### It is a Fusion method (only forward and fusion)! 9 | 10 |

11 | 12 | 13 |

14 |
15 | *The universal encoder **needs no training data** | The picture on the right shows the voxel grid used for mapping*
16 |
17 | *Therefore, it supports **any mapping**:*
18 |
19 |

20 | 21 |

22 | 23 | 24 |
25 | Table of Contents
26 |
27 |   1. Installation
28 |   2. Demo
29 |   3. TODO
30 |   4. Citation
31 |   5. Acknowledgement
32 |
43 |
44 |
45 | ## Environment setup and installation
46 |
47 | Unfold this for installation instructions
48 |
49 | * Create the environment
50 | ```bash
51 | conda create -n uni python=3.8
52 | conda activate uni
53 |
54 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
55 | pip install torch-scatter torch-sparse torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
56 | pip install ninja functorch==0.2.1 numba open3d opencv-python trimesh torchfile
57 | ```
58 |
59 | * Install the package
60 | ```bash
61 | git clone https://github.com/Jarrome/Uni-Fusion.git && cd Uni-Fusion
62 | # install the uni package
63 | python setup.py install
64 | # build the CUDA extensions; this may take several minutes, use `top` or `ps` to check progress
65 | python uni/ext/__init__.py
66 | ```
67 |
68 | * Train a Uni-Fusion encoder from scratch in one second
69 | ```bash
70 | python uni/encoder/uni_encoder_v2.py
71 | ```
72 |
73 |
74 |
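The encoder can also be constructed directly from Python. Below is a minimal sketch following how `get_uni_model` is used in `example/util.py` and `example/render_w_lim.py`; the printed code lengths are the latent sizes that the surface and color mappers consume.

```python
# Minimal sketch: build the universal encoder on the fly (no pretrained data needed),
# using the same entry point as example/util.py.
import torch
from uni.encoder.uni_encoder_v2 import get_uni_model

device = torch.device("cuda", index=0)
uni_model = get_uni_model(device)

# latent sizes passed to SurfaceMap / ContextMap in example/util.py
print(uni_model.surface_code_length, uni_model.color_code_length)
```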
75 | Optionally, you can install the [ORB-SLAM2](https://github.com/Jarrome/Uni-Fusion-use-ORB-SLAM2) that we use for tracking
76 |
77 | ```bash
78 | cd external
79 | git clone https://github.com/Jarrome/Uni-Fusion-use-ORB-SLAM2
80 | cd [this_folder]
81 | # this_folder is the absolute path to the orbslam2 directory
82 | # Add ORB_SLAM2/lib to the PYTHONPATH and LD_LIBRARY_PATH environment variables
83 | # (we suggest putting this in ~/.bashrc)
84 | export PYTHONPATH=$PYTHONPATH:[this_folder]/lib
85 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:[this_folder]/lib
86 |
87 | ./build.sh && ./build_python.sh
88 | ```
89 |
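Once ORB-SLAM2 is built, SLAM-based tracking is switched on through the config rather than through code. The excerpt below reproduces the relevant keys from `configs/replica/office0_w_slam.yaml`; the mode table in the demo section lists the other combinations.

```yaml
# excerpt from configs/replica/office0_w_slam.yaml
sequence_kwargs:
  load_gt: False   # do not use ground-truth poses
slam: True         # track with the external ORB-SLAM2
```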
90 |
91 |
92 | ## Demo
93 |
94 | ### 0. Quick try
95 | We provide a toy example for quickly trying our algorithm.
96 | You can either run `python example/toy.py` or use the following code:
97 | ```python
98 | import torch
99 | import numpy as np
100 |
101 | from example.util import get_modules, get_example_data
102 |
103 | device = torch.device("cuda", index=0)
104 |
105 | # get mapper and tracker
106 | sm, cm, tracker, config = get_modules(device)
107 |
108 | # prepare data
109 | colors, depths, customs, calib, poses = get_example_data(device)
110 |
111 | for i in [0, 1]:
112 |     # preprocess rgbd to point cloud
113 |     frame_pose = tracker.track_camera(colors[i], depths[i], customs, calib, poses[i], scene = config.sequence_type)
114 |     # transform data
115 |     tracker_pc, tracker_normal, tracker_customs = tracker.last_processed_pc
116 |     opt_depth = frame_pose @ tracker_pc
117 |     opt_normal = frame_pose.rotation @ tracker_normal
118 |     color_pc, color, color_normal = tracker.last_colored_pc
119 |     color_pc = frame_pose @ color_pc
120 |     color_normal = frame_pose.rotation @ color_normal if color_normal is not None else None
121 |
122 |     # mapping pc
123 |     sm.integrate_keyframe(opt_depth, opt_normal)
124 |     cm.integrate_keyframe(color_pc, color, color_normal)
125 |
126 | # mesh extraction
127 | map_mesh = sm.extract_mesh(config.resolution, int(4e7), max_std=0.15, extract_async=False, interpolate=True)
128 |
129 | import open3d as o3d
130 | o3d.io.write_triangle_mesh('example/mesh.ply', map_mesh)
131 |
132 | ```
133 | You will get a mesh that looks like this:
134 |
135 |

136 | 137 |
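Building on the toy example above, the color LIM can also be queried at the extracted mesh vertices. This is a sketch only: it follows the `ContextMap.infer` usage from `example/render_w_lim.py`, while the vertex-coloring step itself is our illustration, not code shipped with the repository. It assumes `device`, `cm`, and `map_mesh` from the toy example are still in scope.

```python
# Sketch: paint the extracted mesh with colors queried from the color LIM (cm),
# following the ContextMap.infer call used in example/render_w_lim.py.
import numpy as np
import open3d as o3d
import torch

verts = torch.from_numpy(np.asarray(map_mesh.vertices)).to(device).float()
colors, _ = cm.infer(verts)               # query the color LIM at the mesh vertices
colors = torch.clip(colors, 0., 1.)
map_mesh.vertex_colors = o3d.utility.Vector3dVector(colors.cpu().numpy().astype(np.float64))
o3d.io.write_triangle_mesh('example/colored_mesh.ply', map_mesh)
```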

138 |
139 |
140 |
141 |
142 | ---
143 | Then
144 |
145 | * **All demos can be run with ```python demo.py [config]```**
146 | * **Meshes with color, style, infrared, or semantic properties can be extracted with ```python vis_LIMs.py [config]```**
147 | * **RGB and depth renderings can be extracted with ```python example/render_w_lim.py [config] [optionally a traj with GT poses]```**
148 |
149 | ### 1. Reconstruction Demo
150 | ```bash
151 | # download replica data
152 | source scripts/download_replica.sh
153 |
154 | # with gt pose
155 | python demo.py configs/replica/office0.yaml
156 |
157 | # with slam
158 | python demo.py configs/replica/office0_w_slam.yaml
159 | ```
160 |
161 | Then you can find the results in `output/replica/office0`, as specified in the `[config]` file:
162 | ```console
163 | $ ls output/replica/office0
164 |
165 | surface.lim
166 | color.lim
167 | final_recons.ply
168 | pred_traj.txt
169 | ```
170 |
171 | * *in [scene_w_slam.yaml], we can choose among 3 modes*
172 |
173 | |Usage| load_gt| slam|
174 | |---|---|---|
175 | |use SLAM track|False|True|
176 | |use SLAM pred pose|True|True|
177 | |use GT pose|True|False|
178 |
179 | * *you can set ```vis=True``` for online visualization (```False``` by default), which is closer to DI-Fusion. In the GUI, press ',' to step and '.' to continue running*
180 |
181 | * *LIM extraction for mesh*
182 | ```
183 | python vis_LIMs.py configs/replica/office0.yaml
184 | ```
185 |
186 | This will generate `output/replica/office0/color_recons.ply`
187 |
188 | * *LIM rendering given the resulting LIMs*
189 | ```
190 | # with gt pose
191 | python example/render_w_lim.py configs/replica/office0.yaml data/replica/office0/traj.txt
192 |
193 | # otherwise
194 | python example/render_w_lim.py configs/replica/office0_w_slam.yaml
195 | ```
196 |
197 | This will create a `render` folder under `output/replica/office0`, as specified in the `[config]` file:
198 |
199 | ```console
200 | $ ls output/replica/office0
201 |
202 | surface.lim
203 | color.lim
204 | final_recons.ply
205 | pred_traj.txt
206 | render/ # contains the rendered RGB and depth images
207 | ```
208 |
209 |
210 | ### 2. Custom context Demo
211 |
212 | [```office0_custom.yaml```](https://github.com/Jarrome/Uni-Fusion/blob/main/configs/replica/office0_custom.yaml) contains all the mappings you need
213 |
214 | ```bash
215 | # if you need saliency
216 | pip install transparent-background numba
217 | # if you need style
218 | cd external
219 | git clone https://github.com/Jarrome/PyTorch-Multi-Style-Transfer.git
220 | mv PyTorch-Multi-Style-Transfer style_transfer
221 | cd style_transfer/experiments
222 | bash models/download_model.sh
223 | cd ../../../
224 |
225 | # run demo
226 | python demo.py configs/replica/office0_custom.yaml
227 |
228 |
229 | # LIM extraction of custom properties shown on the mesh
230 | python vis_LIMs.py configs/replica/office0_custom.yaml
231 | ```
232 |
233 |
234 | ### 3. Open Vocabulary Scene Understanding Demo
235 | The Text-Visual CLIP model is from [OpenSeg](https://github.com/tensorflow/tpu/tree/641c1ac6e26ed788327b973582cbfa297d7d31e7/models/official/detection/projects/openseg)
236 | ```bash
237 | # install requirements
238 | pip install tensorflow==2.5.0
239 | pip install git+https://github.com/openai/CLIP.git
240 |
241 | # download openseg ckpt
242 | # can use `sudo snap install google-cloud-cli --classic` to install gsutil
243 | gsutil cp -r gs://cloud-tpu-checkpoints/detection/projects/openseg/colab/exported_model ./external/openseg/
244 |
245 | python demo.py configs/replica/office0_w_clip.yaml
246 |
247 | # LIM extraction of semantics shown on the mesh
248 | python vis_LIMs.py configs/replica/office0_w_clip.yaml
249 | ```
250 |
251 | ### 4. Self-captured data
252 | #### Azure capturing
253 | We provide a script to extract RGB, D and IR from azure.mp4: [azure_process](https://github.com/Jarrome/azure_process).
254 |
255 | The captured apartment data is stored [here](https://robotik.informatik.uni-wuerzburg.de/telematics/download/appartment2.tgz).
256 |
257 | ---
258 | ## TODO:
259 | - [x] Upload the uni-encoder src (Jan.3)
260 | - [x] Upload the env script (Jan.4)
261 | - [x] Upload the recon. application (By Jan.8)
262 | - [x] Upload the used ORB-SLAM2 support (Jan.8)
263 | - [x] Upload the azure process for RGB,D,IR (Jan.8)
264 | - [x] Upload the seman. application (Jan.14)
265 | - [x] Upload the Custom context demo (Jan.14)
266 | - [x] Toy example for quickly assembling Uni-Fusion into a custom project
267 | - [x] Extraction of meshes with properties from Latent Implicit Maps (LIMs) (Jun.26) [Sorry for the delay... Yijun just got some free time...]
268 | - [x] Rendering of RGB and Depth images from Latent Implicit Maps (LIMs) (Jun.26)
269 | - [ ] Our new project [SceneFactory](https://jarrome.github.io/SceneFactory/) has a better option; we plan to replace ORB-SLAM2 with it once that work is openly released.
270 |
271 | ---
272 | ## Citation
273 | If you find this work interesting, please cite us:
274 | ```bibtex
275 | @article{yuan2024uni,
276 |   title={Uni-Fusion: Universal Continuous Mapping},
277 |   author={Yuan, Yijun and N{\"u}chter, Andreas},
278 |   journal={IEEE Transactions on Robotics},
279 |   year={2024},
280 |   publisher={IEEE}
281 | }
282 | ```
283 |
284 | ## Acknowledgement
285 | * This implementation is built on top of [DI-Fusion](https://github.com/huangjh-pub/di-fusion).
286 | * We also borrow some dataset code from [NICE-SLAM](https://github.com/cvg/nice-slam).
287 | * We thank Kejie Li, Björn Michele, Songyou Peng and Golnaz Ghiasi for their detailed responses to our questions.
288 | -------------------------------------------------------------------------------- /assets/PLV.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/assets/PLV.png -------------------------------------------------------------------------------- /assets/cover_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/assets/cover_2.png -------------------------------------------------------------------------------- /assets/encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/assets/encoder.png -------------------------------------------------------------------------------- /assets/toy_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/assets/toy_result.png -------------------------------------------------------------------------------- /configs/replica/office0.yaml: -------------------------------------------------------------------------------- 1 | # Sequence parameters 2 | sequence_type: "replica.ReplicaRGBDDataset" 3 | sequence_kwargs: 4 | path: "data/Replica/office0/results/" 5 | start_frame: 0 6 | end_frame: -1 # Run all frames 7 | first_tq: [0, 0, 0.0, 0.0, -1.0, 0.0, 0.0] 8 | load_gt: True # use gt traj or not 9 | 10 | mesh_gt: "data/Replica/office0_mesh.ply" 11 | 12 | # Network parameters (network structure, etc. will be inherited from the training config) 13 | training_hypers: "ckpt/default/hyper.json" 14 | using_epoch: 300 15 | 16 | 17 | # Separate tracking and meshing. 18 | run_async: True 19 | # Enable visualization 20 | vis: False 21 | resolution: 4 22 | 23 | # These two define the range of depth observations to be cropped. Unit is meter. 24 | depth_cut_min: 0.1 25 | depth_cut_max: 10.0 26 | 27 | meshing_interval: 10 28 | integrate_interval: 10 29 | track_interval: 5 30 | #color_integrate_interval: 20 31 | 32 | 33 | # Mapping parameters 34 | surface_mapping: 35 | GPIS_mode: "sample" 36 | margin: .1 37 | 38 | # Bound of the scene to be reconstructed 39 | bound_min: [-10., -5., -10.] 40 | bound_max: [10., 5., 10.] 41 | 42 | voxel_size: 0.05 43 | # Prune observations if detected as noise. 44 | prune_min_vox_obs: 1 45 | ignore_count_th: 1.0 46 | encoder_count_th: 60000.0 47 | 48 | # Mapping parameters 49 | context_mapping: 50 | # Bound of the scene to be reconstructed 51 | bound_min: [-10., -5., -10.] 52 | bound_max: [10., 5., 10.] 53 | voxel_size: .02 54 | # Prune observations if detected as noise. 55 | prune_min_vox_obs: 1 56 | ignore_count_th: 1.0 57 | encoder_count_th: 60000.0 58 | 59 | outdir: "./output/replica/office0/" 60 | 61 | 62 | # Tracking parameters 63 | tracking: 64 | # An array defining how the camera pose is optimized. 65 | # Each element is a dictionary: 66 | # For example {"n": 2, "type": [['sdf'], ['rgb', 1]]} means to optimize the summation of sdf term and rgb term 67 | # at the 1st level pyramid for 2 iterations. 
68 | iter_config: 69 | #- {"n": 10, "type": [['rgb', 2]]} 70 | - {"n": 5, "type": [['sdf'], ['rgb', 1]]} 71 | - {"n": 10, "type": [['sdf'], ['rgb', 0]]} 72 | sdf: 73 | robust_kernel: "huber" 74 | robust_k: 5.0 75 | subsample: 0.5 76 | rgb: 77 | weight: 500.0 78 | robust_kernel: null 79 | robust_k: 0.01 80 | min_grad_scale: 0.0 81 | max_depth_delta: 0.2 82 | -------------------------------------------------------------------------------- /configs/replica/office0_custom.yaml: -------------------------------------------------------------------------------- 1 | # Sequence parameters 2 | sequence_type: "custom_w_slam.CustomReplicawSLAM" 3 | sequence_kwargs: 4 | path: "data/Replica/office0/results/" 5 | start_frame: 0 6 | end_frame: -1 # Run all frames 7 | #first_tq: [-1.2, 1.3, 1.0, 0.0, -1.0, 0.0, 0.0] # Starting pose 8 | first_tq: [0, 0, 0.0, 0.0, -1.0, 0.0, 0.0] 9 | load_gt: True 10 | 11 | mesh_gt: "data/Replica/office0_mesh.ply" 12 | 13 | outdir: "./output/custom/office0/" 14 | slam: False 15 | # Network parameters (network structure, etc. will be inherited from the training config) 16 | training_hypers: "ckpt/default/hyper.json" 17 | using_epoch: 300 18 | 19 | 20 | # Separate tracking and meshing. 21 | run_async: True 22 | # Enable visualization 23 | vis: False 24 | resolution: 4 25 | 26 | # These two define the range of depth observations to be cropped. Unit is meter. 27 | depth_cut_min: 0.1 28 | depth_cut_max: 10.0 29 | 30 | meshing_interval: 10 31 | integrate_interval: 10 32 | track_interval: 5 33 | #color_integrate_interval: 20 34 | 35 | 36 | # Mapping parameters 37 | surface_mapping: 38 | GPIS_mode: "sample" 39 | margin: .1 40 | 41 | # Bound of the scene to be reconstructed 42 | bound_min: [-10., -5., -10.] 43 | bound_max: [10., 5., 10.] 44 | 45 | voxel_size: 0.05 46 | # Prune observations if detected as noise. 47 | prune_min_vox_obs: 1 48 | ignore_count_th: 1.0 49 | encoder_count_th: 60000.0 50 | 51 | # Mapping parameters 52 | context_mapping: 53 | # Bound of the scene to be reconstructed 54 | bound_min: [-10., -5., -10.] 55 | bound_max: [10., 5., 10.] 56 | voxel_size: .02 57 | # Prune observations if detected as noise. 58 | prune_min_vox_obs: 1 59 | ignore_count_th: 1.0 60 | encoder_count_th: 60000.0 61 | 62 | # Mapping parameters 63 | saliency_mapping: 64 | # Bound of the scene to be reconstructed 65 | bound_min: [-10., -5., -10.] 66 | bound_max: [10., 5., 10.] 67 | voxel_size: .1 68 | # Prune observations if detected as noise. 69 | prune_min_vox_obs: 1 70 | ignore_count_th: 1.0 71 | encoder_count_th: 60000.0 72 | 73 | # Mapping parameters 74 | style_mapping: 75 | # Bound of the scene to be reconstructed 76 | bound_min: [-10., -5., -10.] 77 | bound_max: [10., 5., 10.] 78 | voxel_size: .1 79 | # Prune observations if detected as noise. 80 | prune_min_vox_obs: 1 81 | ignore_count_th: 1.0 82 | encoder_count_th: 60000.0 83 | 84 | 85 | # Mapping parameters 86 | nlatent_mapping: 87 | # Bound of the scene to be reconstructed 88 | bound_min: [-10., -5., -10.] 89 | bound_max: [10., 5., 10.] 90 | voxel_size: .1 91 | # Prune observations if detected as noise. 92 | prune_min_vox_obs: 1 93 | ignore_count_th: 1.0 94 | encoder_count_th: 60000.0 95 | 96 | 97 | 98 | 99 | 100 | # Tracking parameters 101 | tracking: 102 | # An array defining how the camera pose is optimized. 103 | # Each element is a dictionary: 104 | # For example {"n": 2, "type": [['sdf'], ['rgb', 1]]} means to optimize the summation of sdf term and rgb term 105 | # at the 1st level pyramid for 2 iterations. 
106 | iter_config: 107 | #- {"n": 10, "type": [['rgb', 2]]} 108 | - {"n": 5, "type": [['sdf'], ['rgb', 1]]} 109 | - {"n": 10, "type": [['sdf'], ['rgb', 0]]} 110 | sdf: 111 | robust_kernel: "huber" 112 | robust_k: 5.0 113 | subsample: 0.5 114 | rgb: 115 | weight: 500.0 116 | robust_kernel: null 117 | robust_k: 0.01 118 | min_grad_scale: 0.0 119 | max_depth_delta: 0.2 120 | -------------------------------------------------------------------------------- /configs/replica/office0_w_clip.yaml: -------------------------------------------------------------------------------- 1 | # Sequence parameters 2 | sequence_type: "custom_w_slam.CustomReplicawSLAM" 3 | sequence_kwargs: 4 | path: "data/Replica/office0/results/" 5 | start_frame: 0 6 | end_frame: -1 # Run all frames 7 | first_tq: [0, 0, 0.0, 0.0, -1.0, 0.0, 0.0] 8 | load_gt: True 9 | 10 | mesh_gt: "data/Replica/office0_mesh.ply" 11 | 12 | outdir: "./output/w_clip/replica/office0/" 13 | 14 | slam: False 15 | 16 | # Network parameters (network structure, etc. will be inherited from the training config) 17 | training_hypers: "ckpt/default/hyper.json" 18 | using_epoch: 300 19 | 20 | 21 | # Separate tracking and meshing. 22 | run_async: True 23 | # Enable visualization 24 | vis: False 25 | resolution: 4 26 | 27 | # These two define the range of depth observations to be cropped. Unit is meter. 28 | depth_cut_min: 0.1 29 | depth_cut_max: 10.0 30 | 31 | meshing_interval: 10 32 | integrate_interval: 10 33 | track_interval: 10 34 | #color_integrate_interval: 20 35 | 36 | 37 | # Mapping parameters 38 | surface_mapping: 39 | GPIS_mode: "sample" 40 | margin: .1 41 | 42 | # Bound of the scene to be reconstructed 43 | bound_min: [-10., -5., -10.] 44 | bound_max: [10., 5., 10.] 45 | 46 | voxel_size: 0.05 47 | # Prune observations if detected as noise. 48 | prune_min_vox_obs: 1 49 | ignore_count_th: 1.0 50 | encoder_count_th: 60000.0 51 | 52 | # Mapping parameters 53 | context_mapping: 54 | # Bound of the scene to be reconstructed 55 | bound_min: [-10., -5., -10.] 56 | bound_max: [10., 5., 10.] 57 | voxel_size: .02 58 | # Prune observations if detected as noise. 59 | prune_min_vox_obs: 1 60 | ignore_count_th: 1.0 61 | encoder_count_th: 60000.0 62 | 63 | # Mapping parameters 64 | latent_mapping: 65 | # Bound of the scene to be reconstructed 66 | bound_min: [-10., -5., -10.] 67 | bound_max: [10., 5., 10.] 68 | voxel_size: .1 69 | # Prune observations if detected as noise. 70 | prune_min_vox_obs: 1 71 | ignore_count_th: 1.0 72 | encoder_count_th: 60000000.0 73 | 74 | 75 | 76 | 77 | 78 | # Tracking parameters 79 | tracking: 80 | # An array defining how the camera pose is optimized. 81 | # Each element is a dictionary: 82 | # For example {"n": 2, "type": [['sdf'], ['rgb', 1]]} means to optimize the summation of sdf term and rgb term 83 | # at the 1st level pyramid for 2 iterations. 
84 | iter_config: 85 | #- {"n": 10, "type": [['rgb', 2]]} 86 | - {"n": 5, "type": [['sdf'], ['rgb', 1]]} 87 | - {"n": 10, "type": [['sdf'], ['rgb', 0]]} 88 | sdf: 89 | robust_kernel: "huber" 90 | robust_k: 5.0 91 | subsample: 0.5 92 | rgb: 93 | weight: 500.0 94 | robust_kernel: null 95 | robust_k: 0.01 96 | min_grad_scale: 0.0 97 | max_depth_delta: 0.2 98 | -------------------------------------------------------------------------------- /configs/replica/office0_w_slam.yaml: -------------------------------------------------------------------------------- 1 | # Sequence parameters 2 | sequence_type: "custom_w_slam.CustomReplicawSLAM" 3 | sequence_kwargs: 4 | path: "data/Replica/office0/results/" 5 | start_frame: 0 6 | end_frame: -1 # Run all frames 7 | first_tq: [0, 0, 0.0, 0.0, -1.0, 0.0, 0.0] 8 | load_gt: False 9 | 10 | mesh_gt: "data/Replica/office0_mesh.ply" 11 | 12 | outdir: "./output/w_slam/replica/office0/" 13 | 14 | slam: True 15 | 16 | # Network parameters (network structure, etc. will be inherited from the training config) 17 | training_hypers: "ckpt/default/hyper.json" 18 | using_epoch: 300 19 | 20 | 21 | # Separate tracking and meshing. 22 | run_async: True 23 | # Enable visualization 24 | vis: False 25 | resolution: 4 26 | 27 | # These two define the range of depth observations to be cropped. Unit is meter. 28 | depth_cut_min: 0.1 29 | depth_cut_max: 10.0 30 | 31 | meshing_interval: 10 32 | integrate_interval: 10 33 | track_interval: 10 34 | #color_integrate_interval: 20 35 | 36 | 37 | # Mapping parameters 38 | surface_mapping: 39 | GPIS_mode: "sample" 40 | margin: .1 41 | 42 | # Bound of the scene to be reconstructed 43 | bound_min: [-10., -5., -10.] 44 | bound_max: [10., 5., 10.] 45 | 46 | voxel_size: 0.05 47 | # Prune observations if detected as noise. 48 | prune_min_vox_obs: 1 49 | ignore_count_th: 1.0 50 | encoder_count_th: 60000.0 51 | 52 | # Mapping parameters 53 | context_mapping: 54 | # Bound of the scene to be reconstructed 55 | bound_min: [-10., -5., -10.] 56 | bound_max: [10., 5., 10.] 57 | voxel_size: .02 58 | # Prune observations if detected as noise. 59 | prune_min_vox_obs: 1 60 | ignore_count_th: 1.0 61 | encoder_count_th: 60000.0 62 | 63 | 64 | 65 | # Tracking parameters 66 | tracking: 67 | # An array defining how the camera pose is optimized. 68 | # Each element is a dictionary: 69 | # For example {"n": 2, "type": [['sdf'], ['rgb', 1]]} means to optimize the summation of sdf term and rgb term 70 | # at the 1st level pyramid for 2 iterations. 
71 | iter_config: 72 | #- {"n": 10, "type": [['rgb', 2]]} 73 | - {"n": 5, "type": [['sdf'], ['rgb', 1]]} 74 | - {"n": 10, "type": [['sdf'], ['rgb', 0]]} 75 | sdf: 76 | robust_kernel: "huber" 77 | robust_k: 5.0 78 | subsample: 0.5 79 | rgb: 80 | weight: 500.0 81 | robust_kernel: null 82 | robust_k: 0.01 83 | min_grad_scale: 0.0 84 | max_depth_delta: 0.2 85 | -------------------------------------------------------------------------------- /example/office0/depth000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/example/office0/depth000000.png -------------------------------------------------------------------------------- /example/office0/depth000020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/example/office0/depth000020.png -------------------------------------------------------------------------------- /example/office0/frame000000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/example/office0/frame000000.jpg -------------------------------------------------------------------------------- /example/office0/frame000020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/example/office0/frame000020.jpg -------------------------------------------------------------------------------- /example/render_w_lim.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import pathlib 3 | import importlib 4 | import open3d as o3d 5 | import trimesh 6 | import argparse 7 | from pathlib import Path 8 | import logging 9 | from time import time 10 | import torch 11 | 12 | import cv2 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | 17 | 18 | p = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 19 | sys.path.append(p) 20 | 21 | 22 | from uni.encoder.uni_encoder_v2 import get_uni_model 23 | from uni.mapper.context_map_v2 import ContextMap 24 | from uni.utils import exp_util, vis_util, motion_util 25 | from pyquaternion import Quaternion 26 | from uni.utils.ray_cast import RayCaster 27 | 28 | import pdb 29 | 30 | 31 | 32 | 33 | #from utils.ray_cast import RayCaster 34 | 35 | def depth2pc(depth_im, calib_mat): 36 | H,W = depth_im.shape 37 | 38 | d = depth_im.reshape(-1) 39 | 40 | fx = calib_mat[0,0] 41 | fy = calib_mat[1,1] 42 | cx = calib_mat[0,2] 43 | cy = calib_mat[1,2] 44 | 45 | x = np.arange(W) 46 | y = np.arange(H) 47 | yv, xv = np.meshgrid(y, x, indexing='ij') # HxW 48 | 49 | yv = yv.reshape(-1) # HW 50 | xv = xv.reshape(-1) # HW 51 | 52 | pc = np.zeros((H*W,3)) 53 | pc[:,0] = (xv - cx) / fx * d 54 | pc[:,1] = (yv - cy) / fy * d 55 | pc[:,2] = d 56 | return pc 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | if __name__ == '__main__': 71 | config_path = sys.argv[1] 72 | 73 | args = exp_util.parse_config_yaml(Path(config_path)) 74 | mesh_path = args.outdir+'/final_recons.ply' 75 | render_path = args.outdir+'/render/' 76 | 77 | 78 | if not hasattr(args,'slam'): 79 | slamed = False 80 | else: 81 | slamed = args.slam 82 | use_gt = args.sequence_kwargs['load_gt'] and not slamed 83 | 84 | if 
use_gt: 85 | traj_path = sys.argv[2] 86 | else: 87 | traj_path = args.outdir+'pred_traj.txt' 88 | 89 | 90 | 91 | 92 | args.context_mapping = exp_util.dict_to_args(args.context_mapping) 93 | main_device = torch.device('cuda',index=0) 94 | uni_model = get_uni_model(main_device) 95 | context_map = ContextMap(uni_model, 96 | args.context_mapping, uni_model.color_code_length, device=main_device, 97 | enable_async=args.run_async) 98 | 99 | context_map.load(args.outdir+'/color.lim') 100 | 101 | # Load in sequence. 102 | seq_package, seq_class = args.sequence_type.split(".") 103 | sequence_module = importlib.import_module("uni.dataset." + seq_package) 104 | sequence_module = getattr(sequence_module, seq_class) 105 | sequence = sequence_module(**args.sequence_kwargs) 106 | 107 | 108 | mesh = o3d.io.read_triangle_mesh(mesh_path) 109 | traj_data = np.genfromtxt(traj_path) 110 | 111 | render_path = pathlib.Path(render_path) 112 | render_path.mkdir(parents=True, exist_ok=True) 113 | 114 | 115 | calib_matrix = np.eye(3) 116 | calib_matrix[0,0] = 600. 117 | calib_matrix[1,1] = 600. 118 | calib_matrix[0,2] = 599.5 119 | calib_matrix[1,2] = 339.5 120 | 121 | H = 680 122 | W = 1200 123 | 124 | 125 | # if using gt_traj for trajectory, change_mat == I 126 | # else will need inv(traj_data[0]) 127 | if use_gt: 128 | change_mat = np.eye(4) 129 | else: 130 | change_mat = np.linalg.inv(traj_data[0].reshape(4,4)) 131 | 132 | ray_caster = RayCaster(mesh, H, W, calib_matrix) 133 | for id, pose in tqdm(enumerate(traj_data)): 134 | pose = pose.reshape((4,4)) 135 | pose = change_mat.dot(pose) 136 | # 2. predict 137 | ans, ray_direction = ray_caster.ray_cast(pose) # N,3 138 | depth_on_ray = ans['t_hit'].numpy().reshape((H,W)) 139 | facing_direction = pose[:3,:3].dot(np.array([[0.,0.,1.]]).T).T # 1,3 140 | facing_direction = facing_direction / np.linalg.norm(facing_direction) 141 | # depth_im is on z axis 142 | depth_im = (ray_direction * facing_direction).sum(-1).reshape((H,W)) * depth_on_ray 143 | 144 | mask_valid = ~np.isinf(depth_im.reshape(-1)) 145 | 146 | pc = depth2pc(depth_im, calib_matrix) 147 | 148 | 149 | pose_ = sequence.first_iso.matrix.dot(np.linalg.inv(traj_data[0].reshape(4,4)).dot(pose)) 150 | 151 | pc = (pose_[:3,:3].dot(pc[mask_valid,:].T) + pose_[:3,(3,)]).T 152 | color, pinds = context_map.infer(torch.from_numpy(pc).to(main_device).float()) 153 | color = torch.clip(color, 0., 1.) 
154 | 155 | color_im = np.zeros((H*W,3)) 156 | color_im[mask_valid,:] = color.cpu().numpy() * 255 157 | color_im = color_im.reshape((H,W,3)) 158 | 159 | 160 | # inpainting to fill the hole 161 | mask = (~mask_valid.reshape((H,W,1))).astype(np.uint8) 162 | color_im = color_im.astype(np.uint8) 163 | color_im = cv2.inpaint(color_im, mask,3,cv2.INPAINT_TELEA) 164 | 165 | 166 | #cv2.imwrite(str(render_path)+'/%d.jpg'%(id), color_im[:,:,::-1]) 167 | cv2.imwrite(str(render_path)+'/%d.jpg'%(id), color_im[:,:,::-1]) 168 | cv2.imwrite(str(render_path)+'/%d.png'%(id), (depth_im*6553.5).astype(np.uint16)) 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /example/toy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from example.util import get_modules, get_example_data 5 | import pdb 6 | 7 | device = torch.device("cuda", index=0) 8 | 9 | # get mapper and tracker 10 | sm, cm, tracker, config = get_modules(device) 11 | 12 | # prepare data 13 | colors, depths, customs, calib, poses = get_example_data(device) 14 | 15 | for i in [0, 1]: 16 | # preprocess rgbd to point cloud 17 | frame_pose = tracker.track_camera(colors[i], depths[i], customs, calib, poses[i], scene = config.sequence_type) 18 | # transform data 19 | tracker_pc, tracker_normal, tracker_customs= tracker.last_processed_pc 20 | opt_depth = frame_pose @ tracker_pc 21 | opt_normal = frame_pose.rotation @ tracker_normal 22 | color_pc, color, color_normal = tracker.last_colored_pc 23 | color_pc = frame_pose @ color_pc 24 | color_normal = frame_pose.rotation @ color_normal if color_normal is not None else None 25 | 26 | # mapping pc 27 | sm.integrate_keyframe(opt_depth, opt_normal) 28 | cm.integrate_keyframe(color_pc, color, color_normal) 29 | 30 | # mesh extraction 31 | map_mesh = sm.extract_mesh(config.resolution, int(4e7), max_std=0.15, extract_async=False, interpolate=True) 32 | 33 | import open3d as o3d 34 | o3d.io.write_triangle_mesh('example/mesh.ply', map_mesh) 35 | 36 | 37 | -------------------------------------------------------------------------------- /example/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import yaml 5 | import argparse 6 | from pyquaternion import Quaternion 7 | from uni.encoder.uni_encoder_v2 import get_uni_model 8 | from uni.mapper.surface_map import SurfaceMap 9 | from uni.mapper.context_map_v2 import ContextMap # 8 points 10 | import uni.tracker.tracker_custom as tracker 11 | 12 | from uni.dataset import FrameIntrinsic 13 | from uni.utils import motion_util 14 | 15 | import pdb 16 | 17 | 18 | def get_modules(main_device='cuda:0'): 19 | with open('configs/replica/office0.yaml') as f: 20 | configs = yaml.load(f, Loader=yaml.FullLoader) 21 | args = argparse.Namespace(**configs) 22 | 23 | args.surface_mapping = argparse.Namespace(**(args.surface_mapping)) 24 | args.context_mapping = argparse.Namespace(**(args.context_mapping)) 25 | args.tracking = argparse.Namespace(**(args.tracking)) 26 | 27 | uni_model = get_uni_model(main_device) 28 | cm = ContextMap(uni_model, 29 | args.context_mapping, 30 | uni_model.color_code_length, 31 | device=main_device, 32 | enable_async=False) 33 | 34 | sm = SurfaceMap(uni_model, 35 | cm, 36 | args.surface_mapping, 37 | uni_model.surface_code_length, 38 | device=main_device, 39 | 
enable_async=False) 40 | 41 | tk = tracker.SDFTracker(sm, args.tracking) 42 | 43 | return sm, cm, tk, args 44 | 45 | def get_example_data(main_device='cuda:0'): 46 | 47 | colors, depths, poses = [], [], [] 48 | for name_rgb, name_depth in [('example/office0/frame000000.jpg', 'example/office0/depth000000.png'), 49 | ('example/office0/frame000020.jpg', 'example/office0/depth000020.png')]: 50 | rgb = cv2.imread(name_rgb,-1) 51 | depth = cv2.imread(name_depth,-1) 52 | 53 | color = torch.from_numpy(rgb).to(main_device).float() / 255. 54 | depth = torch.from_numpy(depth.astype(np.float32)).to(main_device).float() / 6553.5 55 | 56 | colors.append(color) 57 | depths.append(depth) 58 | 59 | 60 | 61 | customs = [None] * 4 62 | calib = FrameIntrinsic(600., 600., 599.5, 339.5, 6553.5) 63 | 64 | first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 65 | traj_mat = np.genfromtxt('example/office0/traj.txt').reshape((-1,4,4)) 66 | for i in [0,20]: 67 | T = traj_mat[i,:,:] 68 | pose = first_iso.dot(motion_util.Isometry.from_matrix(T)) 69 | poses.append(pose) 70 | 71 | return colors, depths, customs, calib, poses 72 | -------------------------------------------------------------------------------- /external/openseg/openseg_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | #sys.path.append(os.path.dirname(__file__)) 4 | 5 | import numpy as np 6 | import torch 7 | import clip 8 | 9 | import tensorflow.compat.v1 as tf 10 | import tensorflow as tf2 11 | 12 | from tqdm import tqdm 13 | 14 | import pdb 15 | ''' 16 | gpus = tf2.config.experimental.list_physical_devices('GPU') 17 | for gpu in gpus: 18 | tf2.config.experimental.set_memory_growth(gpu, True) 19 | ''' 20 | gpus = tf.config.experimental.list_physical_devices('GPU') 21 | if gpus: 22 | # Restrict TensorFlow to only allocate 1GB of memory on the first GPU 23 | try: 24 | tf.config.experimental.set_virtual_device_configuration( 25 | gpus[1], 26 | [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024*12)]) # Notice here 27 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 28 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 29 | 30 | #tf.config.experimental.set_memory_growth(gpus[1], True) 31 | 32 | except RuntimeError as e: 33 | # Virtual devices must be set before GPUs have been initialized 34 | print(e) 35 | 36 | 37 | 38 | clip.available_models() 39 | model, preprocess = clip.load("ViT-L/14@336px") 40 | ''' 41 | def f_tx(label_src, device): 42 | #args.label_src = 'plant,grass,cat,stone,other' 43 | 44 | 45 | labels = [] 46 | print('** Input label value: {} **'.format(label_src)) 47 | lines = label_src.split(',') 48 | for line in lines: 49 | label = line 50 | labels.append(label) 51 | 52 | outputs = model.net.text_encode(labels, device) 53 | return outputs 54 | ''' 55 | 56 | def build_text_embedding(categories): 57 | run_on_gpu = torch.cuda.is_available() 58 | with torch.no_grad(): 59 | all_text_embeddings = [] 60 | print("Building text embeddings...") 61 | for category in tqdm(categories): 62 | texts = clip.tokenize(category) #tokenize 63 | if run_on_gpu: 64 | texts = texts.cuda(1) 65 | text_embeddings = model.encode_text(texts) #embed with text encoder 66 | 67 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 68 | 69 | text_embedding = text_embeddings.mean(dim=0) 70 | 71 | text_embedding /= text_embedding.norm() 72 | 73 | all_text_embeddings.append(text_embedding) 74 | 75 | all_text_embeddings = 
torch.stack(all_text_embeddings, dim=1) 76 | 77 | if run_on_gpu: 78 | all_text_embeddings = all_text_embeddings.cuda(0) 79 | return all_text_embeddings.cpu().numpy().T 80 | 81 | 82 | def f_tx(label_src): 83 | categories = label_src.split(',') 84 | feat = build_text_embedding(categories) 85 | return feat 86 | 87 | 88 | saved_model_dir = os.path.dirname(__file__)+'/exported_model' #@param {type:"string"} 89 | with tf.device('/GPU:1'): 90 | openseg_model = tf2.saved_model.load(saved_model_dir, tags=[tf.saved_model.tag_constants.SERVING],) 91 | 92 | classes = [ 93 | 'unannotated', # 0 94 | 'wall', # 1 95 | 'floor', # 2 96 | 'chair', # 3 97 | 'table', # 4 98 | 'desk', # 5 99 | 'bed', # 6 100 | 'bookshelf', # 7 101 | 'sofa', # 8 102 | 'sink', # 9 103 | 'bathtub', # 10 104 | 'toilet', # 11 105 | 'curtain', # 12 106 | 'counter', # 13 107 | 'door', # 14 108 | 'window', # 15 109 | 'shower curtain', # 16 110 | 'refrigerator', # 17 111 | 'picture', # 18 112 | 'cabinet', # 19 113 | 'otherfurniture', # 20 114 | ] 115 | 116 | 117 | 118 | text = classes 119 | text[0] = 'other' 120 | text = ','.join(text) 121 | 122 | 123 | text_embedding = f_tx(text)#f_tx('desk,table') 124 | num_words_per_category = 1 125 | with tf.device('/GPU:1'): 126 | text_embedding = tf.reshape( 127 | text_embedding, [-1, num_words_per_category, text_embedding.shape[-1]]) 128 | text_embedding = tf.cast(text_embedding, tf.float32) 129 | 130 | def f_im(np_str,H=320,W=240): 131 | with tf.device('/GPU:1'): 132 | output = openseg_model.signatures['serving_default']( 133 | inp_image_bytes=tf.convert_to_tensor(np_str[0]), 134 | inp_text_emb=text_embedding) 135 | #feat = output['image_embedding_feat'][0,:480,:,:] # 1,640,640,768 -> 480,640,768 136 | 137 | # if scannet 138 | ''' 139 | feat = output['ppixel_ave_feat'][0,:480,:,:] # 1,640,640,768 -> 480,640,768 140 | feat = tf.image.resize(feat, [240, 320]) # 240,320,768 141 | ''' 142 | # if 2D-3D-S 143 | feat = output['ppixel_ave_feat'][0,:,:,:] 144 | feat_h, feat_w = feat.shape[:2] 145 | H_ov_W = float(H)/float(W) 146 | feat_cropped = feat[:int(H_ov_W*feat_h),:] 147 | 148 | feat = tf.image.resize(feat_cropped, [H,W]) 149 | 150 | 151 | #feat = tf.image.resize(feat, [320, 320]) 152 | 153 | feat = feat.numpy() 154 | #feat = feat / np.linalg.norm(feat, axis=-1, keepdims=True) 155 | 156 | return feat 157 | 158 | def classify(image_features, text_features): 159 | ''' 160 | both in np 161 | F_im is N,c 162 | F_tx is k,c where k is the classes 163 | ''' 164 | #image_features = image_features / image_features.norm(dim=-1, keepdim=True) 165 | #text_features = text_features / text_features.norm(dim=-1, keepdim=True) 166 | logits_per_image = image_features.half() @ text_features.T # N,k 167 | return logits_per_image 168 | 169 | def get_api(): 170 | return f_im, f_tx, classify, 768 171 | 172 | 173 | 174 | 175 | 176 | ''' 177 | For each category you can list different names. Use ';' to separate different categories and use ',' to separate different names of a category. 178 | E.g. 'lady, ladies, girl, girls; book' creates two categories of 'lady or ladies or girl or girls' and 'book'. 
179 | ''' 180 | 181 | -------------------------------------------------------------------------------- /scripts/download_replica.sh: -------------------------------------------------------------------------------- 1 | # script from NICE-SLAM 2 | 3 | mkdir -p data 4 | cd data 5 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip 6 | unzip Replica.zip 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='uni', 4 | version='0.1', 5 | description='Uni-Fusion', 6 | author='Yijun Yuan', 7 | url='https://github.com/Jarrome/Uni-Fusion', 8 | packages=find_packages(), 9 | ) 10 | -------------------------------------------------------------------------------- /uni/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/uni/__init__.py -------------------------------------------------------------------------------- /uni/dataset/3dscene.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import torch 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from threading import Thread, Lock 10 | from dataset.production import * 11 | from utils import motion_util 12 | from pyquaternion import Quaternion 13 | 14 | import open3d as o3d 15 | 16 | import pdb 17 | 18 | 19 | class ImageReader(object): 20 | def __init__(self, ids, timestamps=None, cam=None, is_rgb=False): 21 | self.ids = ids 22 | self.timestamps = timestamps 23 | self.cam = cam 24 | self.cache = dict() 25 | self.idx = 0 26 | 27 | self.is_rgb = is_rgb 28 | 29 | self.ahead = 10 # 10 images ahead of current index 30 | self.waiting = 1.5 # waiting time 31 | 32 | self.preload_thread = Thread(target=self.preload) 33 | self.thread_started = False 34 | 35 | def read(self, path): 36 | img = cv2.imread(path, -1) 37 | if self.is_rgb: 38 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 39 | 40 | if self.cam is None: 41 | return img 42 | else: 43 | return self.cam.rectify(img) 44 | 45 | def preload(self): 46 | idx = self.idx 47 | t = float('inf') 48 | while True: 49 | if time.time() - t > self.waiting: 50 | return 51 | if self.idx == idx: 52 | time.sleep(1e-2) 53 | continue 54 | 55 | for i in range(self.idx, self.idx + self.ahead): 56 | if i not in self.cache and i < len(self.ids): 57 | self.cache[i] = self.read(self.ids[i]) 58 | if self.idx + self.ahead > len(self.ids): 59 | return 60 | idx = self.idx 61 | t = time.time() 62 | 63 | def __len__(self): 64 | return len(self.ids) 65 | 66 | def __getitem__(self, idx): 67 | self.idx = idx 68 | # if not self.thread_started: 69 | # self.thread_started = True 70 | # self.preload_thread.start() 71 | 72 | if idx in self.cache: 73 | img = self.cache[idx] 74 | del self.cache[idx] 75 | else: 76 | img = self.read(self.ids[idx]) 77 | 78 | return img 79 | 80 | def __iter__(self): 81 | for i, timestamp in enumerate(self.timestamps): 82 | yield timestamp, self[i] 83 | 84 | @property 85 | def dtype(self): 86 | return self[0].dtype 87 | @property 88 | def shape(self): 89 | return self[0].shape 90 | 91 | 92 | 93 | 94 | 95 | 96 | def make_pair(matrix, threshold=1): 97 | assert (matrix >= 0).all() 98 | pairs = [] 99 | base = defaultdict(int) 100 | while True: 101 | i = matrix[:, 0].argmin() 102 
| min0 = matrix[i, 0] 103 | j = matrix[0, :].argmin() 104 | min1 = matrix[0, j] 105 | 106 | if min0 < min1: 107 | i, j = i, 0 108 | else: 109 | i, j = 0, j 110 | if min(min1, min0) < threshold: 111 | pairs.append((i + base['i'], j + base['j'])) 112 | 113 | matrix = matrix[i + 1:, j + 1:] 114 | base['i'] += (i + 1) 115 | base['j'] += (j + 1) 116 | 117 | if min(matrix.shape) == 0: 118 | break 119 | return pairs 120 | 121 | 122 | class ThreeDSceneRGBDDataset: 123 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, register=True, mesh_gt: str = None): 124 | path = os.path.expanduser(path) 125 | 126 | self.calib = FrameIntrinsic(525., 525., 319.5, 239.5, 1000) 127 | 128 | 129 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 130 | 131 | 132 | if not register: 133 | rgb_ids, rgb_timestamps = self.listdir(path, 'color') 134 | depth_ids, depth_timestamps = self.listdir(path, 'depth') 135 | else: 136 | rgb_imgs, rgb_timestamps = self.listdir(path, 'color', ext='.png') 137 | depth_imgs, depth_timestamps = self.listdir(path, 'depth') 138 | 139 | interval = (rgb_timestamps[1:] - rgb_timestamps[:-1]).mean() * 2/3 140 | matrix = np.abs(rgb_timestamps[:, np.newaxis] - depth_timestamps) 141 | pairs = make_pair(matrix, interval) 142 | 143 | rgb_ids = [] 144 | depth_ids = [] 145 | for i, j in pairs: 146 | rgb_ids.append(rgb_imgs[i]) 147 | depth_ids.append(depth_imgs[j]) 148 | 149 | self.rgb = ImageReader(rgb_ids, rgb_timestamps, is_rgb=True) 150 | self.depth = ImageReader(depth_ids, depth_timestamps) 151 | self.timestamps = rgb_timestamps 152 | 153 | self.frame_id = 0 154 | 155 | if load_gt: 156 | path = path[:-1] if path[-1]=='/' else path 157 | scene = path.split('/')[-1] 158 | gt_traj_path = os.path.join(path,scene+'_trajectory.log') 159 | self.gt_trajectory = self._parse_traj_file(gt_traj_path) 160 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 161 | #change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 162 | #self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 163 | assert len(self.gt_trajectory) == len(self.rgb) 164 | self.T_gt2uni = np.eye(4) #change_iso.matrix 165 | 166 | else: 167 | self.gt_trajectory = None 168 | self.T_gt2uni = self.first_iso.matrix 169 | 170 | 171 | 172 | if mesh_gt is None: 173 | print("using reconstruction mesh") 174 | else: 175 | if mesh_gt != '' and load_gt: 176 | self.gt_mesh = self.get_ground_truth_mesh(mesh_gt)#, gt_traj_path) 177 | def get_ground_truth_mesh(self, mesh_path):#, gt_traj_path): 178 | import trimesh 179 | ''' 180 | with open(gt_traj_path, 'r') as f: 181 | ls = f.readlines() 182 | traj_data = [] 183 | for i in range(1): 184 | mat = [] 185 | for j in range(1,5): 186 | mat.append([float(item) for item in ls[i*5+j].strip().split('\t')]) 187 | traj_data.append(np.array(mat)) 188 | 189 | 190 | T0 = traj_data[0].reshape((4,4)) 191 | change_mat = (self.first_iso.matrix.dot(np.linalg.inv(T0))) 192 | ''' 193 | 194 | change_mat = self.T_gt2uni 195 | 196 | 197 | 198 | mesh_gt = trimesh.load(mesh_path) 199 | mesh_gt.apply_transform(change_mat) 200 | return mesh_gt.as_open3d 201 | 202 | 203 | 204 | 205 | 206 | def _parse_traj_file(self,traj_path): 207 | camera_ext = {} 208 | with open(traj_path, 'r') as f: 209 | ls = f.readlines() 210 | traj_data = [] 211 | for i in range(int(len(ls)/5)): 212 | mat = [] 213 | for j in range(1,5): 214 | mat.append([float(item) for item in ls[i*5+j].strip().split('\t')]) 215 | 
traj_data.append(np.array(mat)) 216 | 217 | #cano_quat = motion_util.Isometry(q=Quaternion(axis=[0.0, 0.0, 1.0], degrees=180.0)) 218 | for id, cur_p in enumerate(traj_data): 219 | T = cur_p.reshape((4,4)) 220 | #cur_q = T[:3,:3] 221 | #cur_t = T[:3, 3] 222 | cur_iso = motion_util.Isometry.from_matrix(T, ortho=True) #(q=Quaternion(matrix=cur_q), t=cur_t) 223 | camera_ext[id] = cur_iso #cano_quat.dot(cur_iso) 224 | camera_ext[len(camera_ext)] = camera_ext[len(camera_ext)-1] 225 | return [camera_ext[t] for t in range(len(camera_ext))] 226 | 227 | 228 | 229 | 230 | 231 | def sort(self, xs, st = 3): 232 | return sorted(xs, key=lambda x:float(x[st:-4])) 233 | 234 | def listdir(self, path, split='rgb', ext='.png'): 235 | imgs, timestamps = [], [] 236 | files = [x for x in os.listdir(os.path.join(path, split)) if x.endswith(ext)] 237 | st = 0 238 | for name in self.sort(files,st): 239 | imgs.append(os.path.join(path, split, name)) 240 | timestamp = float(name[st:-len(ext)].rstrip('.')) 241 | timestamps.append(timestamp) 242 | 243 | return imgs, np.array(timestamps) 244 | 245 | def __getitem__(self, idx): 246 | frame_data = FrameData() 247 | if self.gt_trajectory is not None: 248 | frame_data.gt_pose = self.gt_trajectory[idx] 249 | else: 250 | frame_data.gt_pose = None 251 | frame_data.calib = self.calib 252 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda().float() / 1000. 253 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda().float() / 255. 254 | return frame_data 255 | 256 | 257 | def __next__(self): 258 | frame_data = FrameData() 259 | if self.gt_trajectory is not None: 260 | frame_data.gt_pose = self.gt_trajectory[self.frame_id] 261 | else: 262 | frame_data.gt_pose = None 263 | frame_data.calib = self.calib 264 | frame_data.depth = torch.from_numpy(self.depth[self.frame_id].astype(np.float32)).cuda().float() / 1000. 265 | frame_data.rgb = torch.from_numpy(self.rgb[self.frame_id]).cuda().float() / 255. 266 | self.frame_id += 1 267 | return frame_data 268 | 269 | def __len__(self): 270 | return len(self.rgb) 271 | 272 | 273 | -------------------------------------------------------------------------------- /uni/dataset/NICE_SLAM_config/demo.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'scannet' 2 | sync_method: loose 3 | coarse: True 4 | verbose: False 5 | meshing: 6 | resolution: 256 7 | tracking: 8 | vis_freq: 50 9 | vis_inside_freq: 25 10 | ignore_edge_W: 20 11 | ignore_edge_H: 20 12 | seperate_LR: False 13 | const_speed_assumption: True 14 | lr: 0.0005 15 | pixels: 1000 16 | iters: 30 17 | mapping: 18 | every_frame: 10 19 | vis_freq: 50 20 | vis_inside_freq: 30 21 | mesh_freq: 50 22 | ckpt_freq: 500 23 | keyframe_every: 50 24 | mapping_window_size: 10 25 | pixels: 1000 26 | iters_first: 400 27 | iters: 10 28 | bound: [[0.0,6.5],[0.0,4.0],[0,3.5]] 29 | marching_cubes_bound: [[0.0,6.5],[0.0,4.0],[0,3.5]] 30 | cam: 31 | H: 480 32 | W: 640 33 | fx: 577.590698 34 | fy: 578.729797 35 | cx: 318.905426 36 | cy: 242.683609 37 | png_depth_scale: 1000. 
#for depth image in png format 38 | crop_edge: 0 39 | data: 40 | input_folder: Datasets/Demo 41 | output: output/Demo 42 | 43 | -------------------------------------------------------------------------------- /uni/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class FrameIntrinsic: 5 | def __init__(self, fx, fy, cx, cy, dscale): 6 | self.cx = cx 7 | self.cy = cy 8 | self.fx = fx 9 | self.fy = fy 10 | self.dscale = dscale 11 | 12 | def to_K(self): 13 | return np.asarray([ 14 | [self.fx, 0.0, self.cx], 15 | [0.0, self.fy, self.cy], 16 | [0.0, 0.0, 1.0] 17 | ]) 18 | 19 | 20 | class FrameData: 21 | def __init__(self): 22 | self.rgb = None 23 | self.depth = None 24 | self.gt_pose = None 25 | self.calib = None 26 | 27 | 28 | class RGBDSequence: 29 | def __init__(self): 30 | self.frame_id = 0 31 | 32 | def __iter__(self): 33 | return self 34 | 35 | def __len__(self): 36 | raise NotImplementedError 37 | 38 | def __next__(self): 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /uni/dataset/aug_icl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import torch 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from threading import Thread, Lock 10 | from dataset.production import * 11 | from utils import motion_util 12 | from pyquaternion import Quaternion 13 | 14 | import open3d as o3d 15 | 16 | import pdb 17 | 18 | 19 | class ImageReader(object): 20 | def __init__(self, ids, timestamps=None, cam=None, is_rgb=False): 21 | self.ids = ids 22 | self.timestamps = timestamps 23 | self.cam = cam 24 | self.cache = dict() 25 | self.idx = 0 26 | 27 | self.is_rgb = is_rgb 28 | 29 | self.ahead = 10 # 10 images ahead of current index 30 | self.waiting = 1.5 # waiting time 31 | 32 | self.preload_thread = Thread(target=self.preload) 33 | self.thread_started = False 34 | 35 | def read(self, path): 36 | img = cv2.imread(path, -1) 37 | if self.is_rgb: 38 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 39 | 40 | if self.cam is None: 41 | return img 42 | else: 43 | return self.cam.rectify(img) 44 | 45 | def preload(self): 46 | idx = self.idx 47 | t = float('inf') 48 | while True: 49 | if time.time() - t > self.waiting: 50 | return 51 | if self.idx == idx: 52 | time.sleep(1e-2) 53 | continue 54 | 55 | for i in range(self.idx, self.idx + self.ahead): 56 | if i not in self.cache and i < len(self.ids): 57 | self.cache[i] = self.read(self.ids[i]) 58 | if self.idx + self.ahead > len(self.ids): 59 | return 60 | idx = self.idx 61 | t = time.time() 62 | 63 | def __len__(self): 64 | return len(self.ids) 65 | 66 | def __getitem__(self, idx): 67 | self.idx = idx 68 | # if not self.thread_started: 69 | # self.thread_started = True 70 | # self.preload_thread.start() 71 | 72 | if idx in self.cache: 73 | img = self.cache[idx] 74 | del self.cache[idx] 75 | else: 76 | img = self.read(self.ids[idx]) 77 | 78 | return img 79 | 80 | def __iter__(self): 81 | for i, timestamp in enumerate(self.timestamps): 82 | yield timestamp, self[i] 83 | 84 | @property 85 | def dtype(self): 86 | return self[0].dtype 87 | @property 88 | def shape(self): 89 | return self[0].shape 90 | 91 | 92 | 93 | 94 | 95 | 96 | def make_pair(matrix, threshold=1): 97 | assert (matrix >= 0).all() 98 | pairs = [] 99 | base = defaultdict(int) 100 | while True: 101 | i = matrix[:, 0].argmin() 102 | min0 = 
matrix[i, 0] 103 | j = matrix[0, :].argmin() 104 | min1 = matrix[0, j] 105 | 106 | if min0 < min1: 107 | i, j = i, 0 108 | else: 109 | i, j = 0, j 110 | if min(min1, min0) < threshold: 111 | pairs.append((i + base['i'], j + base['j'])) 112 | 113 | matrix = matrix[i + 1:, j + 1:] 114 | base['i'] += (i + 1) 115 | base['j'] += (j + 1) 116 | 117 | if min(matrix.shape) == 0: 118 | break 119 | return pairs 120 | 121 | 122 | class AugICLRGBDDataset: 123 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, register=False, mesh_gt: str = None): 124 | path = os.path.expanduser(path) 125 | 126 | self.calib = FrameIntrinsic(525., 525., 319.5, 239.5, 1000) 127 | 128 | 129 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 130 | 131 | 132 | if not register: 133 | rgb_ids, rgb_timestamps = self.listdir(path, 'color',ext='.jpg') 134 | depth_ids, depth_timestamps = self.listdir(path, 'depth_w_noise') 135 | else: 136 | assert False, "not implemented" 137 | 138 | self.rgb = ImageReader(rgb_ids, rgb_timestamps, is_rgb=True) 139 | self.depth = ImageReader(depth_ids, depth_timestamps) 140 | self.timestamps = rgb_timestamps 141 | 142 | self.frame_id = 0 143 | 144 | if load_gt: 145 | path = path[:-1] if path[-1]=='/' else path 146 | scene = path.split('/')[-1] 147 | gt_traj_path = os.path.join(path,scene+'-traj.txt') 148 | self.gt_trajectory = self._parse_traj_file(gt_traj_path) 149 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 150 | change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 151 | self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 152 | assert len(self.gt_trajectory) == len(self.rgb) 153 | self.T_gt2uni = change_iso.matrix 154 | 155 | else: 156 | self.gt_trajectory = None 157 | self.T_gt2uni = self.first_iso.matrix 158 | 159 | 160 | 161 | if mesh_gt is None: 162 | print("using reconstruction mesh") 163 | else: 164 | if mesh_gt != '' and load_gt: 165 | self.gt_mesh = self.get_ground_truth_mesh(mesh_gt)#, gt_traj_path) 166 | def get_ground_truth_mesh(self, mesh_path):#, gt_traj_path): 167 | import trimesh 168 | ''' 169 | with open(gt_traj_path, 'r') as f: 170 | ls = f.readlines() 171 | traj_data = [] 172 | for i in range(1): 173 | mat = [] 174 | for j in range(1,5): 175 | mat.append([float(item) for item in ls[i*5+j].strip().split(' ')]) 176 | traj_data.append(np.array(mat)) 177 | 178 | 179 | T0 = traj_data[0].reshape((4,4)) 180 | change_mat = (self.first_iso.matrix.dot(np.linalg.inv(T0))) 181 | ''' 182 | change_mat = self.T_gt2uni 183 | 184 | mesh_gt = trimesh.load(mesh_path) 185 | mesh_gt.apply_transform(change_mat) 186 | return mesh_gt#.as_open3d 187 | 188 | 189 | 190 | 191 | 192 | def _parse_traj_file(self,traj_path): 193 | camera_ext = {} 194 | with open(traj_path, 'r') as f: 195 | ls = f.readlines() 196 | traj_data = [] 197 | for i in range(int(len(ls)/5)): 198 | mat = [] 199 | for j in range(1,5): 200 | mat.append([float(item) for item in ls[i*5+j].strip().split(' ')]) 201 | traj_data.append(np.array(mat)) 202 | 203 | #cano_quat = motion_util.Isometry(q=Quaternion(axis=[0.0, 0.0, 1.0], degrees=180.0)) 204 | for id, cur_p in enumerate(traj_data): 205 | T = cur_p.reshape((4,4)) 206 | #cur_q = T[:3,:3] 207 | #cur_t = T[:3, 3] 208 | cur_iso = motion_util.Isometry.from_matrix(T, ortho=True) #(q=Quaternion(matrix=cur_q), t=cur_t) 209 | camera_ext[id] = cur_iso #cano_quat.dot(cur_iso) 210 | camera_ext[len(camera_ext)] = camera_ext[len(camera_ext)-1] 211 | return 
[camera_ext[t] for t in range(len(camera_ext))] 212 | 213 | 214 | 215 | 216 | 217 | def sort(self, xs, st = 3): 218 | return sorted(xs, key=lambda x:float(x[st:-4])) 219 | 220 | def listdir(self, path, split='rgb', ext='.png'): 221 | imgs, timestamps = [], [] 222 | files = [x for x in os.listdir(os.path.join(path, split)) if x.endswith(ext)] 223 | st = 0 224 | for name in self.sort(files,st): 225 | imgs.append(os.path.join(path, split, name)) 226 | timestamp = float(name[st:-len(ext)].rstrip('.')) 227 | timestamps.append(timestamp) 228 | 229 | return imgs, np.array(timestamps) 230 | 231 | def __getitem__(self, idx): 232 | return self.rgb[idx], self.depth[idx] 233 | def __next__(self): 234 | frame_data = FrameData() 235 | if self.gt_trajectory is not None: 236 | frame_data.gt_pose = self.gt_trajectory[self.frame_id] 237 | else: 238 | frame_data.gt_pose = None 239 | frame_data.calib = self.calib 240 | frame_data.depth = torch.from_numpy(self.depth[self.frame_id].astype(np.float32)).cuda().float() / 1000. 241 | frame_data.rgb = torch.from_numpy(self.rgb[self.frame_id]).cuda().float() / 255. 242 | self.frame_id += 1 243 | return frame_data 244 | 245 | def __len__(self): 246 | return len(self.rgb) 247 | 248 | 249 | -------------------------------------------------------------------------------- /uni/dataset/azure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import torch 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from threading import Thread, Lock 10 | from uni.dataset import * 11 | from uni.utils import motion_util 12 | from pyquaternion import Quaternion 13 | 14 | import open3d as o3d 15 | 16 | from PIL import Image 17 | 18 | 19 | import pdb 20 | 21 | 22 | class ImageReader(object): 23 | def __init__(self, ids, timestamps=None, cam=None, is_rgb=False): 24 | self.ids = ids 25 | self.timestamps = timestamps 26 | self.cam = cam 27 | self.cache = dict() 28 | self.idx = 0 29 | 30 | self.is_rgb = is_rgb 31 | 32 | self.ahead = 10 # 10 images ahead of current index 33 | self.waiting = 1.5 # waiting time 34 | 35 | self.preload_thread = Thread(target=self.preload) 36 | self.thread_started = False 37 | 38 | def read(self, path): 39 | img = cv2.imread(path, -1) 40 | if self.is_rgb: 41 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 42 | 43 | if self.cam is None: 44 | return img 45 | else: 46 | return self.cam.rectify(img) 47 | 48 | def preload(self): 49 | idx = self.idx 50 | t = float('inf') 51 | while True: 52 | if time.time() - t > self.waiting: 53 | return 54 | if self.idx == idx: 55 | time.sleep(1e-2) 56 | continue 57 | 58 | for i in range(self.idx, self.idx + self.ahead): 59 | if i not in self.cache and i < len(self.ids): 60 | self.cache[i] = self.read(self.ids[i]) 61 | if self.idx + self.ahead > len(self.ids): 62 | return 63 | idx = self.idx 64 | t = time.time() 65 | 66 | def __len__(self): 67 | return len(self.ids) 68 | 69 | def __getitem__(self, idx): 70 | self.idx = idx 71 | # if not self.thread_started: 72 | # self.thread_started = True 73 | # self.preload_thread.start() 74 | 75 | if idx in self.cache: 76 | img = self.cache[idx] 77 | del self.cache[idx] 78 | else: 79 | img = self.read(self.ids[idx]) 80 | 81 | return img 82 | 83 | def __iter__(self): 84 | for i, timestamp in enumerate(self.timestamps): 85 | yield timestamp, self[i] 86 | 87 | @property 88 | def dtype(self): 89 | return self[0].dtype 90 | @property 91 | def shape(self): 92 | return self[0].shape 93 | 94 
| 95 | 96 | def read_orbslam2_file(traj_file): 97 | with open(traj_file) as f: 98 | lines = f.readlines() 99 | poses = [] 100 | frame_ids = [] 101 | for line_id, line in enumerate(lines): 102 | vs = [float(v) for v in line.strip().split(' ')] 103 | frame_id = round(vs[0]*30) 104 | #frame_id = round(vs[0]) 105 | v_t = vs[1:4] 106 | #v_q = vs[4:] # xyzw 107 | v_q = Quaternion(vs[-1],*vs[4:-1]) 108 | pose = v_q.transformation_matrix 109 | pose[:3,3] = np.array(v_t) 110 | poses.append(pose) 111 | frame_ids.append(frame_id) 112 | return frame_ids, poses 113 | 114 | 115 | 116 | 117 | class AzureRGBDIDataset(object): 118 | 119 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, register=True, mesh_gt: str = None): 120 | path = os.path.expanduser(path) 121 | 122 | 123 | cam = np.genfromtxt(path+'/intrinsic.txt') 124 | self.cam = namedtuple('camera', 'fx fy cx cy scale')( 125 | *(cam.tolist()), 1000) 126 | 127 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 128 | 129 | 130 | self.start_frame = start_frame 131 | self.end_frame = end_frame 132 | 133 | rgb_ids, rgb_timestamps = self.listdir(path, 'color') 134 | depth_ids, depth_timestamps = self.listdir(path, 'depth') 135 | ir_ids, ir_timestamps = self.listdir(path, 'ir') 136 | 137 | 138 | if load_gt: 139 | traj_path = path+'/traj_orbslam2.txt' 140 | frame_ids, poses = read_orbslam2_file(traj_path) 141 | rgb_ids = [rgb_ids[frame_id] for frame_id in frame_ids] 142 | depth_ids = [depth_ids[frame_id] for frame_id in frame_ids] 143 | ir_ids = [ir_ids[frame_id] for frame_id in frame_ids] 144 | rgb_timestamps = [rgb_timestamps[frame_id] for frame_id in frame_ids] 145 | ''' 146 | now_id = 0 147 | frame_ids_ = [] 148 | poses_ = [] 149 | for idx in range(0,frame_ids[-1]+1): 150 | frame_ids_.append(idx) 151 | if idx in frame_ids: 152 | poses_.append(poses[now_id]) 153 | now_id += 1 154 | else: 155 | poses_.append(np.eye(4)) 156 | frame_ids = frame_ids_ 157 | poses = poses_ 158 | ''' 159 | 160 | self.gt_trajectory = self._parse_traj(poses) 161 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 162 | change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 163 | self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 164 | self.T_gt2uni = change_iso.matrix 165 | 166 | else: 167 | self.gt_trajectory = None 168 | self.T_gt2uni = self.first_iso.matrix 169 | 170 | 171 | 172 | self.rgb_ids = rgb_ids 173 | self.rgb = ImageReader(rgb_ids, rgb_timestamps, is_rgb=True) 174 | self.depth = ImageReader(depth_ids, depth_timestamps) 175 | self.ir = ImageReader(ir_ids, ir_timestamps) 176 | 177 | self.timestamps = rgb_timestamps 178 | 179 | self.frame_id = 0 180 | 181 | if mesh_gt is None: 182 | print("using reconstruction mesh") 183 | else: 184 | if mesh_gt != '' and load_gt: 185 | self.gt_mesh = self.get_ground_truth_mesh(mesh_gt)#, gt_traj_path) 186 | 187 | # saliency 188 | from transparent_background import Remover 189 | self.saliency_detector = Remover() 190 | 191 | # style 192 | from thirdparts.style_transfer.experiments import style_api 193 | self.style_painting = style_api.get_api() 194 | 195 | 196 | def _parse_traj(self,traj_data): 197 | camera_ext = {} 198 | for id, cur_p in enumerate(traj_data): 199 | T = cur_p.reshape((4,4)) 200 | #cur_q = Quaternion(imaginary=cur_p[4:7], real=cur_p[-1]).rotation_matrix 201 | cur_q = T[:3,:3] 202 | cur_t = T[:3, 3] 203 | cur_iso = motion_util.Isometry(q=Quaternion(matrix=cur_q), t=cur_t) 204 | 
camera_ext[id] = cur_iso #cano_quat.dot(cur_iso) 205 | camera_ext[len(camera_ext)] = camera_ext[len(camera_ext)-1] 206 | return [camera_ext[t] for t in range(len(camera_ext))] 207 | 208 | 209 | 210 | 211 | 212 | def sort(self, xs, st = 3): 213 | return sorted(xs, key=lambda x:float(x[st:-4])) 214 | 215 | def listdir(self, path, split='rgb', ext='.png'): 216 | imgs, timestamps = [], [] 217 | files = [x for x in os.listdir(os.path.join(path, split)) if x.endswith(ext)] 218 | st = 0 219 | for name in self.sort(files,st): 220 | imgs.append(os.path.join(path, split, name)) 221 | timestamp = float(name[st:-len(ext)].rstrip('.')) 222 | timestamps.append(timestamp) 223 | 224 | imgs = imgs[self.start_frame: self.end_frame] 225 | timestamps = timestamps[self.start_frame: self.end_frame] 226 | 227 | return imgs, np.array(timestamps) 228 | 229 | def __getitem__(self, idx): 230 | frame_data = FrameData() 231 | if self.gt_trajectory is not None: 232 | frame_data.gt_pose = self.gt_trajectory[idx] 233 | else: 234 | frame_data.gt_pose = None 235 | frame_data.calib = FrameIntrinsic(self.cam.fx, self.cam.fy, self.cam.cx, self.cam.cy, self.cam.scale) 236 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda().float() / self.cam.scale 237 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda().float() / 255. 238 | frame_data.ir = torch.from_numpy(self.ir[idx].astype(np.float32)).cuda().float().unsqueeze(-1) 239 | 240 | img = Image.fromarray(self.rgb[idx]).convert('RGB') 241 | frame_data.saliency = torch.from_numpy( 242 | self.saliency_detector.process(img,type='map') 243 | ).cuda().float() /255 244 | 245 | 246 | frame_data.style = torch.from_numpy( 247 | self.style_painting(img)).cuda().float() / 255. 248 | 249 | 250 | return frame_data 251 | 252 | 253 | def __next__(self): 254 | frame_data = FrameData() 255 | if self.gt_trajectory is not None: 256 | frame_data.gt_pose = self.gt_trajectory[self.frame_id] 257 | else: 258 | frame_data.gt_pose = None 259 | frame_data.calib = FrameIntrinsic(self.cam.fx, self.cam.fy, self.cam.cx, self.cam.cy, self.cam.scale) 260 | frame_data.depth = torch.from_numpy(self.depth[self.frame_id].astype(np.float32)).cuda().float() / self.cam.scale 261 | frame_data.rgb = torch.from_numpy(self.rgb[self.frame_id]).cuda().float() / 255. 262 | frame_data.ir = torch.from_numpy(self.ir[self.frame_id].astype(np.float32)).cuda().float() 263 | 264 | img = Image.fromarray(frame).convert('RGB') 265 | frame_data.saliency = torch.from_numpy( 266 | self.saliency_detector.process(img,type='map').astype(np.float32) 267 | ).cuda().float() 268 | 269 | 270 | frame_data.style = torch.from_numpy( 271 | self.style_painting(img)).cuda().float() / 255. 
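# Per-frame payload (descriptive note): rgb is scaled to [0, 1], depth is the raw
# value divided by cam.scale, ir is the raw infrared image, saliency comes from
# transparent_background's Remover, and style is the style-transferred RGB in [0, 1].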
272 | 273 | 274 | self.frame_id += 1 275 | return frame_data 276 | 277 | def __len__(self): 278 | return len(self.rgb) 279 | 280 | -------------------------------------------------------------------------------- /uni/dataset/custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | import open3d as o3d 5 | from PIL import Image 6 | import tensorflow.compat.v1 as tf 7 | 8 | 9 | from dataset.production import * 10 | from dataset.production.latent_map import ScanNetDataset 11 | from dataset.production.replica import ReplicaRGBDDataset 12 | from dataset.production.bpnet_scannet import ScanNetLatentDataset 13 | from dataset.production.azure import AzureRGBDIDataset 14 | 15 | from tqdm import tqdm 16 | 17 | 18 | import pdb 19 | 20 | 21 | class CustomScanNet(ScanNetDataset): 22 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, train=True, mesh_gt = None, style_idx=1, has_style=True, has_ir=False, has_saliency=True, has_latent=False, f_im=None): 23 | super().__init__(path, start_frame, end_frame, first_tq, load_gt, train, mesh_gt) 24 | self.has_ir = has_ir 25 | self.has_saliency = has_saliency 26 | self.has_style = has_style 27 | self.has_latent = has_latent 28 | # saliency 29 | from transparent_background import Remover 30 | self.saliency_detector = Remover() 31 | 32 | # style 33 | from thirdparts.style_transfer.experiments import style_api 34 | self.style_painting = style_api.get_api(style_idx) 35 | 36 | # latent 37 | self.latent_func = f_im 38 | 39 | # np_str 40 | self.np_image_strings = [] 41 | for rgb_id in self.inner_dataset.color_paths: 42 | with tf.gfile.GFile(rgb_id, 'rb') as f: 43 | np_image_string = np.array([f.read()]) 44 | self.np_image_strings.append(np_image_string) 45 | 46 | 47 | 48 | def __getitem__(self, idx): 49 | index, rgb, depth, pose = self.inner_dataset[idx] 50 | 51 | frame_data = FrameData() 52 | frame_data.calib = FrameIntrinsic(self.fx, self.fy, self.cx, self.cy, self.depth_scaling_factor) 53 | frame_data.depth = depth.cuda(0).float()# torch.from_numpy(self.depth[idx_id].astype(np.float32)).cuda().float()# / self.depth_scaling_factor 54 | frame_data.rgb = rgb.cuda(0).float() #torch.from_numpy(self.rgb[idx_id]).cuda().float() #/ 255. 55 | 56 | frame_data.gt_pose = self.gt_trajectory[idx] 57 | 58 | 59 | frame_data.ir = None 60 | img = Image.fromarray((rgb.cpu().numpy()*255).astype(np.ubyte)).convert('RGB') 61 | frame_data.saliency = torch.from_numpy( 62 | self.saliency_detector.process(img,type='map').astype(np.float32) 63 | ).cuda(0).float() / 255. if self.has_saliency else None 64 | 65 | frame_data.style = torch.from_numpy( 66 | self.style_painting(img)).cuda(0).float() / 255. 
if self.has_style else None 67 | 68 | H,W,_ = frame_data.rgb.shape 69 | 70 | frame_data.latent = torch.from_numpy(self.latent_func(self.np_image_strings[idx], H, W)).cuda(0).float() if self.has_latent else None 71 | 72 | 73 | 74 | frame_data.customs = [frame_data.ir, frame_data.saliency, frame_data.style, frame_data.latent] 75 | return frame_data 76 | 77 | class CustomAzure(AzureRGBDIDataset): 78 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, train=True, mesh_gt = None, style_idx=1, has_style=True, has_ir=False, has_saliency=True, has_latent=False, f_im=None): 79 | super().__init__(path, start_frame, end_frame, first_tq, load_gt, train, mesh_gt) 80 | self.has_ir = has_ir 81 | self.has_saliency = has_saliency 82 | self.has_style = has_style 83 | self.has_latent = has_latent 84 | # saliency 85 | from transparent_background import Remover 86 | self.saliency_detector = Remover() 87 | 88 | # style 89 | from thirdparts.style_transfer.experiments import style_api 90 | self.style_painting = style_api.get_api(style_idx) 91 | 92 | # latent 93 | self.latent_func = f_im 94 | 95 | # np_str 96 | if has_latent: 97 | self.np_image_strings = [] 98 | for rgb_id in tqdm(self.rgb_ids): 99 | with tf.gfile.GFile(rgb_id, 'rb') as f: 100 | np_image_string = np.array([f.read()]) 101 | self.np_image_strings.append(np_image_string) 102 | 103 | 104 | 105 | def __getitem__(self, idx): 106 | frame_data = FrameData() 107 | if self.gt_trajectory is not None: 108 | frame_data.gt_pose = self.gt_trajectory[idx] 109 | else: 110 | frame_data.gt_pose = None 111 | frame_data.calib = FrameIntrinsic(self.cam.fx, self.cam.fy, self.cam.cx, self.cam.cy, self.cam.scale) 112 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda(0).float() / self.cam.scale 113 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda(0).float() / 255. 114 | frame_data.ir = torch.from_numpy(self.ir[idx].astype(np.float32)).cuda(0).float().unsqueeze(-1) if self.has_ir else None 115 | 116 | img = Image.fromarray(self.rgb[idx]).convert('RGB') 117 | frame_data.saliency = torch.from_numpy( 118 | self.saliency_detector.process(img,type='map') 119 | ).cuda(0).float() /255. if self.has_saliency else None 120 | 121 | 122 | 123 | frame_data.style = torch.from_numpy( 124 | self.style_painting(img)).cuda(0).float() / 255. 
if self.has_style else None 125 | 126 | H,W,_ = frame_data.rgb.shape 127 | 128 | frame_data.latent = torch.from_numpy(self.latent_func(self.np_image_strings[idx], H, W)).cuda(0).float() if self.has_latent else None 129 | 130 | 131 | frame_data.customs = [frame_data.ir, frame_data.saliency, frame_data.style, frame_data.latent] 132 | return frame_data 133 | 134 | 135 | 136 | class CustomReplica(ReplicaRGBDDataset): 137 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, train=True, mesh_gt = None, style_idx=1, has_style=True, has_ir=False, has_saliency=True, has_latent=False, f_im=None): 138 | super().__init__(path, start_frame, end_frame, first_tq, load_gt, train, mesh_gt) 139 | self.has_ir = has_ir 140 | self.has_saliency = has_saliency 141 | self.has_style = has_style 142 | self.has_latent = has_latent 143 | # saliency 144 | from transparent_background import Remover 145 | self.saliency_detector = Remover() 146 | 147 | # style 148 | from thirdparts.style_transfer.experiments import style_api 149 | self.style_painting = style_api.get_api(style_idx) 150 | 151 | # latent 152 | self.latent_func = f_im 153 | 154 | 155 | def __getitem__(self, idx): 156 | #return self.rgb[idx], self.depth[idx] 157 | frame_data = FrameData() 158 | if self.gt_trajectory is not None: 159 | frame_data.gt_pose = self.gt_trajectory[idx] 160 | else: 161 | frame_data.gt_pose = None 162 | 163 | frame_data.calib = FrameIntrinsic(600., 600., 599.5, 339.5, 6553.5) 164 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda().float() / 6553.5 165 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda().float() / 255. 166 | 167 | img = Image.fromarray((frame_data.rgb.cpu().numpy()*255).astype(np.ubyte)).convert('RGB') 168 | 169 | frame_data.ir = None 170 | frame_data.saliency = torch.from_numpy( 171 | self.saliency_detector.process(img,type='map').astype(np.float32) 172 | ).cuda().float() / 255. if self.has_saliency else None 173 | 174 | frame_data.style = torch.from_numpy( 175 | self.style_painting(img)).cuda().float() / 255. 
if self.has_style else None 176 | 177 | frame_data.latent = self.latent_func(frame_data.rgb).permute(2,3,1,0).squeeze(-1) if self.has_latent else None 178 | 179 | frame_data.customs = [frame_data.ir, frame_data.saliency, frame_data.style, frame_data.latent] 180 | 181 | return frame_data 182 | 183 | 184 | class CustomBPNetScanNet(ScanNetLatentDataset): 185 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, train=True, mesh_gt = None, style_idx=1, has_style=True, has_ir=False, has_saliency=True, has_latent=False, f_im=None): 186 | super().__init__(path, f_im, start_frame, end_frame, first_tq, load_gt, train, mesh_gt) 187 | self.has_ir = has_ir 188 | self.has_saliency = has_saliency 189 | self.has_style = has_style 190 | self.has_latent = has_latent 191 | # saliency 192 | from transparent_background import Remover 193 | self.saliency_detector = Remover() 194 | 195 | # style 196 | from thirdparts.style_transfer.experiments import style_api 197 | self.style_painting = style_api.get_api(style_idx) 198 | 199 | # latent 200 | self.latent_func = f_im 201 | 202 | 203 | def __getitem__(self, idx): 204 | frame_data = FrameData() 205 | if self.gt_trajectory is not None: 206 | frame_data.gt_pose = self.gt_trajectory[idx] 207 | else: 208 | frame_data.gt_pose = None 209 | frame_data.calib = FrameIntrinsic(*self.calib) 210 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda(0).float() / 1000 211 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda(0).float() / 255. 212 | 213 | frame_data.ir = None 214 | img = Image.fromarray((frame_data.rgb.cpu().numpy()*255).astype(np.ubyte)).convert('RGB') 215 | frame_data.saliency = torch.from_numpy( 216 | self.saliency_detector.process(img,type='map').astype(np.float32) 217 | ).cuda().float() / 255. if self.has_saliency else None 218 | 219 | frame_data.style = torch.from_numpy( 220 | self.style_painting(img)).cuda().float() / 255. 
if self.has_style else None 221 | 222 | H,W,_ = self.rgb.shape 223 | 224 | frame_data.latent = torch.from_numpy(self.f_im(self.np_image_strings[idx], H, W)).cuda(0).float() if self.has_latent else None 225 | 226 | 227 | 228 | frame_data.customs = [frame_data.ir, frame_data.saliency, frame_data.style, frame_data.latent] 229 | return frame_data 230 | 231 | 232 | -------------------------------------------------------------------------------- /uni/dataset/fountain.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | from dataset.production import * 5 | from pyquaternion import Quaternion 6 | from pathlib import Path 7 | from utils import motion_util 8 | import pdb 9 | 10 | 11 | class FountainSequence(RGBDSequence): 12 | def __init__(self, path: str, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False): 13 | super().__init__() 14 | self.path = Path(path) 15 | 16 | with open(self.path/"depth_vga.match", 'r') as f: 17 | lines = f.readlines() 18 | pairs = [l.strip().split(' ') for l in lines] 19 | self.color_names = [] 20 | self.depth_names = [] 21 | for pair in pairs: 22 | self.depth_names.append(pair[0]) 23 | self.color_names.append(pair[-1]) 24 | 25 | #self.color_names = sorted([f"rgb/{t}" for t in os.listdir(self.path / "rgb")], key=lambda t: int(t[4:].split(".")[0])) 26 | #self.depth_names = [f"depth/{t}.png" for t in range(len(self.color_names))] 27 | self.calib = [525., 525.0, 319.50, 239.50, 1000.0] 28 | if first_tq is not None: 29 | self.first_iso = motion_util.Isometry(q=Quaternion(array=first_tq[3:]), t=np.array(first_tq[:3])) 30 | else: 31 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 32 | 33 | if end_frame == -1: 34 | end_frame = len(self.color_names) 35 | 36 | self.color_names = self.color_names[start_frame:end_frame] 37 | self.depth_names = self.depth_names[start_frame:end_frame] 38 | 39 | if load_gt: 40 | gt_traj_path = list(self.path.glob("fountain_key.log"))[0] 41 | self.gt_trajectory = self._parse_traj_file(gt_traj_path) 42 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 43 | change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 44 | self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 45 | assert len(self.gt_trajectory) == len(self.color_names) 46 | else: 47 | self.gt_trajectory = None 48 | 49 | def _parse_traj_file(self, traj_path): 50 | camera_ext = {} 51 | traj_data = []#np.genfromtxt(traj_path) 52 | with open(traj_path, 'r') as f: 53 | ls = f.readlines() 54 | 55 | 56 | cano_quat = motion_util.Isometry(q=Quaternion(axis=[0.0, 0.0, 1.0], degrees=180.0)) 57 | for cur_p in traj_data: 58 | cur_q = Quaternion(imaginary=cur_p[4:7], real=cur_p[-1]).rotation_matrix 59 | cur_t = cur_p[1:4] 60 | ''' 61 | cur_q[1] = -cur_q[1] 62 | cur_q[:, 1] = -cur_q[:, 1] 63 | cur_t[1] = -cur_t[1] 64 | ''' 65 | cur_iso = motion_util.Isometry(q=Quaternion(matrix=cur_q), t=cur_t) 66 | camera_ext[cur_p[0]] = cano_quat.dot(cur_iso) 67 | camera_ext[0] = camera_ext[1] 68 | return [camera_ext[t] for t in range(len(camera_ext))] 69 | 70 | def __len__(self): 71 | return len(self.color_names) 72 | 73 | def __next__(self): 74 | if self.frame_id >= len(self): 75 | raise StopIteration 76 | depth_img_path = self.path / self.depth_names[self.frame_id] 77 | rgb_img_path = self.path / self.color_names[self.frame_id] 78 | 79 | # Convert depth image into point cloud. 
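# (Note) The depth image is read unchanged and divided by the depth scale
# self.calib[4] (1000 here), converting the stored values to metres; the colour
# image is converted BGR -> RGB and scaled to [0, 1].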
80 | depth_data = cv2.imread(str(depth_img_path), cv2.IMREAD_UNCHANGED) 81 | depth_data = torch.from_numpy(depth_data.astype(np.float32)).cuda() / self.calib[4] 82 | rgb_data = cv2.imread(str(rgb_img_path)) 83 | rgb_data = cv2.cvtColor(rgb_data, cv2.COLOR_BGR2RGB) 84 | rgb_data = torch.from_numpy(rgb_data).cuda().float() / 255. 85 | 86 | frame_data = FrameData() 87 | frame_data.gt_pose = self.gt_trajectory[self.frame_id] if self.gt_trajectory is not None else None 88 | frame_data.calib = FrameIntrinsic(self.calib[0], self.calib[1], self.calib[2], self.calib[3], self.calib[4]) 89 | frame_data.depth = depth_data 90 | frame_data.rgb = rgb_data 91 | 92 | self.frame_id += 1 93 | return frame_data 94 | -------------------------------------------------------------------------------- /uni/dataset/icl_nuim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | from dataset.production import * 5 | from pyquaternion import Quaternion 6 | from pathlib import Path 7 | from utils import motion_util 8 | import pdb 9 | 10 | 11 | class ICLNUIMSequence(RGBDSequence): 12 | def __init__(self, path: str, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False): 13 | super().__init__() 14 | self.path = Path(path) 15 | self.color_names = sorted([f"rgb/{t}" for t in os.listdir(self.path / "rgb")], key=lambda t: int(t[4:].split(".")[0])) 16 | self.depth_names = [f"depth/{t}.png" for t in range(len(self.color_names))] 17 | self.calib = [481.2, 480.0, 319.50, 239.50, 5000.0] 18 | if first_tq is not None: 19 | self.first_iso = motion_util.Isometry(q=Quaternion(array=first_tq[3:]), t=np.array(first_tq[:3])) 20 | else: 21 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 22 | 23 | if end_frame == -1: 24 | end_frame = len(self.color_names) 25 | 26 | self.color_names = self.color_names[start_frame:end_frame] 27 | self.depth_names = self.depth_names[start_frame:end_frame] 28 | 29 | if load_gt: 30 | gt_traj_path = (list(self.path.glob("*.freiburg")) + list(self.path.glob("groundtruth.txt")))[0] 31 | self.gt_trajectory = self._parse_traj_file(gt_traj_path) 32 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 33 | change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 34 | self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 35 | assert len(self.gt_trajectory) == len(self.color_names) 36 | else: 37 | self.gt_trajectory = None 38 | 39 | def _parse_traj_file(self, traj_path): 40 | camera_ext = {} 41 | traj_data = np.genfromtxt(traj_path) 42 | cano_quat = motion_util.Isometry(q=Quaternion(axis=[0.0, 0.0, 1.0], degrees=180.0)) 43 | for cur_p in traj_data: 44 | cur_q = Quaternion(imaginary=cur_p[4:7], real=cur_p[-1]).rotation_matrix 45 | cur_t = cur_p[1:4] 46 | cur_q[1] = -cur_q[1] 47 | cur_q[:, 1] = -cur_q[:, 1] 48 | cur_t[1] = -cur_t[1] 49 | cur_iso = motion_util.Isometry(q=Quaternion(matrix=cur_q), t=cur_t) 50 | camera_ext[cur_p[0]] = cano_quat.dot(cur_iso) 51 | camera_ext[0] = camera_ext[1] 52 | return [camera_ext[t] for t in range(len(camera_ext))] 53 | 54 | def __len__(self): 55 | return len(self.color_names) 56 | 57 | def __next__(self): 58 | if self.frame_id >= len(self): 59 | raise StopIteration 60 | 61 | depth_img_path = self.path / self.depth_names[self.frame_id] 62 | rgb_img_path = self.path / self.color_names[self.frame_id] 63 | 64 | # Convert depth image into point cloud. 
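# (Note) For ICL-NUIM the depth scale self.calib[4] is 5000, so the division
# below converts the stored depth values to metres.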
65 | depth_data = cv2.imread(str(depth_img_path), cv2.IMREAD_UNCHANGED) 66 | depth_data = torch.from_numpy(depth_data.astype(np.float32)).cuda() / self.calib[4] 67 | rgb_data = cv2.imread(str(rgb_img_path)) 68 | rgb_data = cv2.cvtColor(rgb_data, cv2.COLOR_BGR2RGB) 69 | rgb_data = torch.from_numpy(rgb_data).cuda().float() / 255. 70 | 71 | frame_data = FrameData() 72 | frame_data.gt_pose = self.gt_trajectory[self.frame_id] if self.gt_trajectory is not None else None 73 | frame_data.calib = FrameIntrinsic(self.calib[0], self.calib[1], self.calib[2], self.calib[3], self.calib[4]) 74 | frame_data.depth = depth_data 75 | frame_data.rgb = rgb_data 76 | 77 | self.frame_id += 1 78 | return frame_data 79 | -------------------------------------------------------------------------------- /uni/dataset/matterport3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import torch 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from threading import Thread, Lock 10 | from dataset.production import * 11 | from utils import motion_util 12 | from pyquaternion import Quaternion 13 | 14 | import open3d as o3d 15 | 16 | import pdb 17 | 18 | 19 | 20 | 21 | class Matterport3DRGBDDataset(): 22 | ''' 23 | follow https://github.com/otakuxiang/circle/blob/master/torch/sample_matterport.py 24 | ''' 25 | 26 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, mesh_gt: str = None): 27 | path = os.path.expanduser(path) 28 | 29 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 30 | 31 | self.depth_path = os.path.join(path,"matterport_depth_images") 32 | self.rgb_path = os.path.join(path,"matterport_color_images") 33 | self.pose_path = os.path.join(path,"matterport_camera_poses") 34 | self.intri_path = os.path.join(path,"matterport_camera_intrinsics") 35 | tripod_numbers = [ins[:ins.find("_")] for ins in os.listdir(self.intri_path)] 36 | self.depthMapFactor = 4000 37 | 38 | 39 | self.frames = [] 40 | for tripod_number in tripod_numbers: 41 | for camera_id in range(3): 42 | for frame_id in range(6): 43 | self.frames.append([tripod_number,camera_id,frame_id]) 44 | self.frame_ids = list(range(len(self.frames))) 45 | 46 | 47 | if load_gt: 48 | self.gt_trajectory = self._parse_traj_file(self.pose_path) 49 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 50 | change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 51 | self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 52 | #assert len(self.gt_trajectory) == len(self.rgb) 53 | self.T_gt2uni = change_iso.matrix 54 | 55 | else: 56 | self.gt_trajectory = None 57 | self.T_gt2uni = self.first_iso.matrix 58 | 59 | 60 | 61 | 62 | 63 | 64 | self.frame_id = 0 65 | 66 | def __len__(self): 67 | return len(self.frames) 68 | 69 | def _parse_traj_file(self, traj_path): 70 | traj_data = [] 71 | for frame_id in range(len(self)): 72 | tripod_number,camera_id,frame_idx = self.frames[frame_id] 73 | f = open(os.path.join(self.pose_path,f"{tripod_number}_pose_{camera_id}_{frame_idx}.txt")) 74 | pose = np.zeros((4,4)) 75 | for idx,line in enumerate(f): 76 | ss = line.strip().split(" ") 77 | for k in range(0,4): 78 | pose[idx,k] = float(ss[k]) 79 | # pose = np.linalg.inv(pose) 80 | traj_data.append(pose) 81 | 82 | f.close() 83 | 84 | camera_ext = {} 85 | for id, cur_p in enumerate(traj_data): 86 | T = cur_p 87 | cur_q = T[:3,:3] 88 | cur_t = T[:3, 3] 
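# (Note) The loosened atol/rtol below let pyquaternion accept rotation blocks
# that are only approximately orthonormal, as is common for poses stored as text.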
89 | cur_iso = motion_util.Isometry(q=Quaternion(matrix=cur_q, atol=1e-5, rtol=1e-5), t=cur_t) 90 | camera_ext[id] = cur_iso 91 | camera_ext[len(camera_ext)] = camera_ext[len(camera_ext)-1] 92 | return [camera_ext[t] for t in range(len(camera_ext))] 93 | 94 | 95 | 96 | 97 | 98 | def __getitem__(self, frame_id): 99 | tripod_number,camera_id,frame_idx = self.frames[frame_id] 100 | ''' 101 | f = open(os.path.join(self.pose_path,f"{tripod_number}_pose_{camera_id}_{frame_idx}.txt")) 102 | pose = np.zeros((4,4)) 103 | for idx,line in enumerate(f): 104 | ss = line.strip().split(" ") 105 | for k in range(0,4): 106 | pose[idx,k] = float(ss[k]) 107 | # pose = np.linalg.inv(pose) 108 | pose = torch.from_numpy(pose).float() 109 | 110 | f.close() 111 | ''' 112 | K_depth = np.zeros((3,3)) 113 | f = open(os.path.join(self.intri_path,f"{tripod_number}_intrinsics_{camera_id}.txt")) 114 | p = np.zeros((4)) 115 | for idx,line in enumerate(f): 116 | ss = line.strip().split(" ") 117 | for j in range(4): 118 | p[j] = float(ss[j+2]) 119 | f.close() 120 | K_depth[0,0] = p[0] 121 | K_depth[1,1] = p[1] 122 | K_depth[2,2] = 1 123 | K_depth[0,2] = p[2] 124 | K_depth[1,2] = p[3] 125 | depth_path = os.path.join(self.depth_path,tripod_number+"_d"+str(camera_id)+"_"+str(frame_idx)+".png") 126 | depth =cv2.imread(depth_path,-1) 127 | rgb_path = os.path.join(self.rgb_path,tripod_number+"_i"+str(camera_id)+"_"+str(frame_idx)+".jpg") 128 | rgb = cv2.imread(rgb_path, -1) 129 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) 130 | 131 | ins = torch.from_numpy(K_depth).float() 132 | if depth is None: 133 | print("get None image!") 134 | print(depth_path) 135 | return None 136 | 137 | depth = depth.astype(np.float32) / self.depthMapFactor 138 | depth = torch.from_numpy(depth).float() 139 | 140 | rgb = rgb.astype(np.float32) / 255 141 | rgb = torch.from_numpy(rgb).float() 142 | 143 | assert depth.shape[:2] == rgb.shape[:2], 'depth shape should == rgb shape' 144 | 145 | return rgb,depth,ins 146 | def __next__(self): 147 | rgb, depth, K = self[self.frame_id] 148 | 149 | frame_data = FrameData() 150 | frame_data.calib = FrameIntrinsic(K[0,0],K[1,1],K[0,2],K[1,2],self.depthMapFactor) 151 | frame_data.depth = depth.cuda() 152 | frame_data.rgb = rgb.cuda() 153 | frame_data.gt_pose = self.gt_trajectory[self.frame_id] 154 | 155 | self.frame_id += 1 156 | return frame_data 157 | 158 | 159 | -------------------------------------------------------------------------------- /uni/dataset/scannet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import torch 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from threading import Thread, Lock 10 | from uni.dataset import * 11 | from uni.utils import motion_util 12 | from pyquaternion import Quaternion 13 | 14 | import open3d as o3d 15 | from tqdm import tqdm 16 | import glob 17 | 18 | import pdb 19 | 20 | 21 | class ImageReader(object): 22 | def __init__(self, ids, timestamps=None, cam=None, is_rgb=False, resize_shape=None): 23 | self.ids = ids 24 | self.timestamps = timestamps 25 | self.cam = cam 26 | self.cache = dict() 27 | self.idx = 0 28 | 29 | self.resize_shape = resize_shape 30 | self.is_rgb = is_rgb 31 | 32 | self.ahead = 10 # 10 images ahead of current index 33 | self.waiting = 1.5 # waiting time 34 | 35 | self.preload_thread = Thread(target=self.preload) 36 | self.thread_started = False 37 | 38 | def read(self, path): 39 | img = cv2.imread(path, -1) 40 | if self.is_rgb: 
41 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 42 | 43 | if self.resize_shape is not None: 44 | img = cv2.resize(img, self.resize_shape) 45 | 46 | 47 | if self.cam is None: 48 | return img 49 | else: 50 | return self.cam.rectify(img) 51 | 52 | def preload(self): 53 | idx = self.idx 54 | t = float('inf') 55 | while True: 56 | if time.time() - t > self.waiting: 57 | return 58 | if self.idx == idx: 59 | time.sleep(1e-2) 60 | continue 61 | 62 | for i in range(self.idx, self.idx + self.ahead): 63 | if i not in self.cache and i < len(self.ids): 64 | self.cache[i] = self.read(self.ids[i]) 65 | if self.idx + self.ahead > len(self.ids): 66 | return 67 | idx = self.idx 68 | t = time.time() 69 | 70 | def __len__(self): 71 | return len(self.ids) 72 | 73 | def __getitem__(self, idx): 74 | self.idx = idx 75 | # if not self.thread_started: 76 | # self.thread_started = True 77 | # self.preload_thread.start() 78 | 79 | if idx in self.cache: 80 | img = self.cache[idx] 81 | del self.cache[idx] 82 | else: 83 | img = self.read(self.ids[idx]) 84 | 85 | return img 86 | 87 | def __iter__(self): 88 | for i, timestamp in enumerate(self.timestamps): 89 | yield timestamp, self[i] 90 | 91 | @property 92 | def dtype(self): 93 | return self[0].dtype 94 | @property 95 | def shape(self): 96 | return self[0].shape 97 | 98 | 99 | 100 | 101 | 102 | 103 | def make_pair(matrix, threshold=1): 104 | assert (matrix >= 0).all() 105 | pairs = [] 106 | base = defaultdict(int) 107 | while True: 108 | i = matrix[:, 0].argmin() 109 | min0 = matrix[i, 0] 110 | j = matrix[0, :].argmin() 111 | min1 = matrix[0, j] 112 | 113 | if min0 < min1: 114 | i, j = i, 0 115 | else: 116 | i, j = 0, j 117 | if min(min1, min0) < threshold: 118 | pairs.append((i + base['i'], j + base['j'])) 119 | 120 | matrix = matrix[i + 1:, j + 1:] 121 | base['i'] += (i + 1) 122 | base['j'] += (j + 1) 123 | 124 | if min(matrix.shape) == 0: 125 | break 126 | return pairs 127 | 128 | 129 | 130 | class ScanNetRGBDDataset(object): 131 | ''' 132 | path example: 'path/to/your/TUM R-GBD Dataset/rgbd_dataset_freiburg1_xyz' 133 | ''' 134 | 135 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, register=True, mesh_gt: str = None): 136 | path = os.path.expanduser(path) 137 | 138 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 139 | 140 | if load_gt: 141 | gt_traj_path = sorted(glob.glob(path+'/pose/*.txt'), key=lambda x: int(os.path.basename(x)[:-4])) 142 | self.gt_trajectory = self._parse_traj_file(gt_traj_path) 143 | # some pose is None 144 | invalid_id = [i for i, pose in enumerate(self.gt_trajectory) if pose is None] 145 | for i in invalid_id[::-1]: 146 | del self.gt_trajectory[i] 147 | 148 | 149 | self.gt_trajectory = self.gt_trajectory[start_frame:end_frame] 150 | change_iso = self.first_iso.dot(self.gt_trajectory[0].inv()) 151 | self.gt_trajectory = [change_iso.dot(t) for t in self.gt_trajectory] 152 | #assert len(self.gt_trajectory) == len(self.rgb) 153 | self.T_gt2uni = change_iso.matrix 154 | 155 | else: 156 | self.gt_trajectory = None 157 | self.T_gt2uni = self.first_iso.matrix 158 | rgb_ids, rgb_timestamps = self.listdir(path, 'color', ext='.jpg') 159 | depth_ids, depth_timestamps = self.listdir(path, 'depth') 160 | if load_gt: 161 | for i in invalid_id[::-1]: 162 | del(rgb_ids[i]) 163 | del(depth_ids[i]) 164 | 165 | np.delete(rgb_timestamps,invalid_id) 166 | np.delete(depth_timestamps,invalid_id) 167 | 168 | rgb_ids = rgb_ids[start_frame:end_frame] 169 | depth_ids = 
depth_ids[start_frame:end_frame] 170 | rgb_timestamps = rgb_timestamps[start_frame:end_frame] 171 | 172 | self.rgb_ids = rgb_ids 173 | 174 | self.depth = ImageReader(depth_ids, depth_timestamps) 175 | H, W = self.depth[0].shape 176 | self.rgb = ImageReader(rgb_ids, rgb_timestamps, is_rgb=True, resize_shape=(W,H)) 177 | 178 | self.timestamps = rgb_timestamps 179 | 180 | self.frame_id = 0 181 | 182 | 183 | 184 | if mesh_gt is None: 185 | print("using reconstruction mesh") 186 | else: 187 | if mesh_gt != '' and load_gt: 188 | self.gt_mesh = self.get_ground_truth_mesh(mesh_gt)#, gt_traj_path) 189 | def get_ground_truth_mesh(self, mesh_path):#, gt_traj_path): 190 | import trimesh 191 | ''' 192 | traj_data = np.genfromtxt(gt_traj_path) 193 | T0 = traj_data[0].reshape((4,4)) 194 | change_mat = (self.first_iso.matrix.dot(np.linalg.inv(T0))) 195 | ''' 196 | change_mat = self.T_gt2uni 197 | 198 | 199 | 200 | mesh_gt = trimesh.load(mesh_path) 201 | mesh_gt.apply_transform(change_mat) 202 | return mesh_gt.as_open3d 203 | 204 | 205 | 206 | 207 | 208 | def _parse_traj_file(self,traj_path): 209 | camera_ext = {} 210 | traj_data = [np.genfromtxt(traj_file) for traj_file in traj_path] 211 | #cano_quat = motion_util.Isometry(q=Quaternion(axis=[0.0, 0.0, 1.0], degrees=180.0)) 212 | for id, cur_p in enumerate(traj_data): 213 | T = cur_p.reshape((4,4)) 214 | #cur_q = Quaternion(imaginary=cur_p[4:7], real=cur_p[-1]).rotation_matrix 215 | cur_q = T[:3,:3] 216 | cur_t = T[:3, 3] 217 | #cur_q[1] = -cur_q[1] 218 | #cur_q[:, 1] = -cur_q[:, 1] 219 | #cur_t[1] = -cur_t[1] 220 | try: 221 | cur_iso = motion_util.Isometry(q=Quaternion(matrix=cur_q, atol=1e-5, rtol=1e-5), t=cur_t) 222 | except Exception as e: 223 | cur_iso = None 224 | camera_ext[id] = cur_iso #cano_quat.dot(cur_iso) 225 | #camera_ext[len(camera_ext)] = camera_ext[len(camera_ext)-1] 226 | return [camera_ext[t] for t in range(len(camera_ext))] 227 | 228 | 229 | 230 | def __len__(self): 231 | return len(self.rgb) 232 | 233 | 234 | 235 | 236 | def sort(self, xs, st = 3): 237 | return sorted(xs, key=lambda x:float(x[st:-4])) 238 | 239 | def listdir(self, path, split='rgb', ext='.png'): 240 | imgs, timestamps = [], [] 241 | files = [x for x in os.listdir(os.path.join(path, split)) if x.endswith(ext)] 242 | st = 0 243 | for name in self.sort(files,st): 244 | imgs.append(os.path.join(path, split, name)) 245 | timestamp = float(name[st:-len(ext)].rstrip('.')) 246 | timestamps.append(timestamp) 247 | 248 | return imgs, np.array(timestamps) 249 | 250 | def __getitem__(self, idx): 251 | #return self.rgb[idx], self.depth[idx] 252 | frame_data = FrameData() 253 | if self.gt_trajectory is not None: 254 | frame_data.gt_pose = self.gt_trajectory[idx] 255 | else: 256 | frame_data.gt_pose = None 257 | frame_data.calib = FrameIntrinsic(600., 600., 599.5, 339.5, 6553.5) 258 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda().float() / 6553.5 259 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda().float() / 255. 
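# Usage sketch (the path below is an illustrative assumption):
#   ds = ScanNetRGBDDataset("/data/scannet/scene0000_00", load_gt=True)
#   frame = ds[0]   # FrameData with .rgb, .depth, .calib and .gt_pose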
260 | return frame_data 261 | 262 | 263 | 264 | -------------------------------------------------------------------------------- /uni/dataset/tum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import time 5 | import torch 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from threading import Thread, Lock 10 | from uni.dataset import * 11 | from uni.utils import motion_util 12 | from pyquaternion import Quaternion 13 | 14 | import open3d as o3d 15 | 16 | import pdb 17 | 18 | 19 | class ImageReader(object): 20 | def __init__(self, ids, timestamps=None, cam=None, is_rgb=False): 21 | self.ids = ids 22 | self.timestamps = timestamps 23 | self.cam = cam 24 | self.cache = dict() 25 | self.idx = 0 26 | 27 | self.is_rgb = is_rgb 28 | 29 | self.ahead = 10 # 10 images ahead of current index 30 | self.waiting = 1.5 # waiting time 31 | 32 | self.preload_thread = Thread(target=self.preload) 33 | self.thread_started = False 34 | 35 | def read(self, path): 36 | img = cv2.imread(path, -1) 37 | if self.is_rgb: 38 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 39 | 40 | if self.cam is None: 41 | return img 42 | else: 43 | return self.cam.rectify(img) 44 | 45 | def preload(self): 46 | idx = self.idx 47 | t = float('inf') 48 | while True: 49 | if time.time() - t > self.waiting: 50 | return 51 | if self.idx == idx: 52 | time.sleep(1e-2) 53 | continue 54 | 55 | for i in range(self.idx, self.idx + self.ahead): 56 | if i not in self.cache and i < len(self.ids): 57 | self.cache[i] = self.read(self.ids[i]) 58 | if self.idx + self.ahead > len(self.ids): 59 | return 60 | idx = self.idx 61 | t = time.time() 62 | 63 | def __len__(self): 64 | return len(self.ids) 65 | 66 | def __getitem__(self, idx): 67 | self.idx = idx 68 | # if not self.thread_started: 69 | # self.thread_started = True 70 | # self.preload_thread.start() 71 | 72 | if idx in self.cache: 73 | img = self.cache[idx] 74 | del self.cache[idx] 75 | else: 76 | img = self.read(self.ids[idx]) 77 | 78 | return img 79 | 80 | def __iter__(self): 81 | for i, timestamp in enumerate(self.timestamps): 82 | yield timestamp, self[i] 83 | 84 | @property 85 | def dtype(self): 86 | return self[0].dtype 87 | @property 88 | def shape(self): 89 | return self[0].shape 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | class TUMRGBDDataset(): 98 | def __init__(self, path, start_frame: int = 0, end_frame: int = -1, first_tq: list = None, load_gt: bool = False, register=True, mesh_gt: str = None): 99 | path = os.path.expanduser(path) 100 | self.first_iso = motion_util.Isometry(q=Quaternion(array=[0.0, -1.0, 0.0, 0.0])) 101 | rgb_ids = [] 102 | depth_ids = [] 103 | self.timestamps = [] 104 | with open(os.path.join(path,'asso.txt'),'r') as f: 105 | ls = f.readlines() 106 | for l in ls: 107 | elems = l.strip().split(' ') 108 | rgb_id = elems[1] 109 | depth_id = elems[3] 110 | timestamp = elems[0] 111 | rgb_ids.append(os.path.join(path,rgb_id)) 112 | depth_ids.append(os.path.join(path,depth_id)) 113 | self.timestamps.append(timestamp) 114 | 115 | 116 | 117 | self.rgb = ImageReader(rgb_ids) 118 | self.depth = ImageReader(depth_ids) 119 | 120 | self.frame_id = 0 121 | 122 | 123 | 124 | assert load_gt == False, "NO TUM GT TRAJECTORY" 125 | self.gt_trajectory = None 126 | self.T_gt2uni = self.first_iso.matrix 127 | 128 | 129 | 130 | 131 | def sort(self, xs): 132 | return sorted(xs, key=lambda x:float(x[:-4])) 133 | 134 | def __getitem__(self, idx): 135 | frame_data = FrameData() 136 | 
frame_data.gt_pose = None 137 | frame_data.calib = FrameIntrinsic(525., 525., 319.5, 239.5, 5000) 138 | frame_data.depth = torch.from_numpy(self.depth[idx].astype(np.float32)).cuda().float() / 5000 139 | frame_data.rgb = torch.from_numpy(self.rgb[idx]).cuda().float() / 255. 140 | return frame_data 141 | 142 | 143 | def __next__(self): 144 | frame_data = FrameData() 145 | frame_data.gt_pose = None 146 | frame_data.calib = FrameIntrinsic(525., 525., 319.5, 239.5, 5000) 147 | frame_data.depth = torch.from_numpy(self.depth[self.frame_id].astype(np.float32)).cuda().float() / 5000 148 | frame_data.rgb = torch.from_numpy(self.rgb[self.frame_id]).cuda().float() / 255. 149 | self.frame_id += 1 150 | return frame_data 151 | 152 | 153 | 154 | def __len__(self): 155 | return len(self.rgb) 156 | -------------------------------------------------------------------------------- /uni/encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/uni/encoder/__init__.py -------------------------------------------------------------------------------- /uni/encoder/position_encoder.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/uni/encoder/position_encoder.pth -------------------------------------------------------------------------------- /uni/encoder/utility.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | 7 | p_dir = os.path.dirname(os.path.dirname(__file__)) 8 | sys.path.append(p_dir) 9 | 10 | from utils import exp_util 11 | 12 | from pathlib import Path 13 | import importlib 14 | import pdb 15 | 16 | 17 | class Networks: 18 | def __init__(self): 19 | self.decoder = None 20 | self.encoder = None 21 | 22 | def eval(self): 23 | if self.encoder is not None: 24 | self.encoder.eval() 25 | if self.decoder is not None: 26 | self.decoder.eval() 27 | 28 | 29 | def load_model(training_hyper_path: str, use_epoch: int = -1): 30 | """ 31 | Load in the model and hypers used. 32 | :param training_hyper_path: 33 | :param use_epoch: if -1, will load the latest model. 34 | :return: Networks 35 | """ 36 | training_hyper_path = Path(training_hyper_path) 37 | 38 | if training_hyper_path.name.split(".")[-1] == "json": 39 | args = exp_util.parse_config_json(training_hyper_path) 40 | exp_dir = training_hyper_path.parent 41 | model_paths = exp_dir.glob('model_*.pth.tar') 42 | model_paths = {int(str(t).split("model_")[-1].split(".pth")[0]): t for t in model_paths} 43 | assert use_epoch in model_paths.keys(), f"{use_epoch} not found in {sorted(list(model_paths.keys()))}" 44 | args.checkpoint = model_paths[use_epoch] 45 | else: 46 | args = exp_util.parse_config_yaml(Path('configs/training_defaults.yaml')) 47 | args = exp_util.parse_config_yaml(training_hyper_path, args) 48 | logging.warning("Loaded a un-initialized model.") 49 | args.checkpoint = None 50 | 51 | model = Networks() 52 | net_module = importlib.import_module("network." + args.network_name) 53 | model.decoder = net_module.Model(args.code_length, **args.network_specs).cuda() 54 | if args.encoder_name is not None: 55 | encoder_module = importlib.import_module("network." 
+ args.encoder_name) 56 | model.encoder = encoder_module.Model(**args.encoder_specs).cuda() 57 | if args.checkpoint is not None: 58 | if model.decoder is not None: 59 | state_dict = torch.load(args.checkpoint)["model_state"] 60 | model.decoder.load_state_dict(state_dict) 61 | if model.encoder is not None: 62 | state_dict = torch.load(Path(args.checkpoint).parent / f"encoder_{use_epoch}.pth.tar")["model_state"] 63 | model.encoder.load_state_dict(state_dict) 64 | 65 | return model, args 66 | 67 | 68 | def forward_model(model: nn.Module, network_input: torch.Tensor = None, 69 | latent_input: torch.Tensor = None, 70 | xyz_input: torch.Tensor = None, 71 | loss_func=None, max_sample: int = 2 ** 32, 72 | no_detach: bool = False, 73 | verbose: bool = False): 74 | """ 75 | Forward the neural network model. (if loss_func is not None, will also compute the gradient w.r.t. the loss) 76 | Either network_input or (latent_input, xyz_input) tuple could be provided. 77 | :param model: MLP model. 78 | :param network_input: (N, 128) 79 | :param latent_input: (N, 125) 80 | :param xyz_input: (N, 3) 81 | :param loss_func: 82 | :param max_sample 83 | :return: [(N, X)] several values 84 | """ 85 | if latent_input is not None and xyz_input is not None: 86 | assert network_input is None 87 | network_input = torch.cat((latent_input, xyz_input), dim=1) 88 | 89 | assert network_input.ndimension() == 2 90 | 91 | n_chunks = math.ceil(network_input.size(0) / max_sample) 92 | assert not no_detach or n_chunks == 1 93 | 94 | network_input = torch.chunk(network_input, n_chunks) 95 | 96 | if verbose: 97 | logging.debug(f"Network input chunks = {n_chunks}, each chunk = {network_input[0].size()}") 98 | 99 | head = 0 100 | output_chunks = None 101 | for chunk_i, input_chunk in enumerate(network_input): 102 | # (N, 1) 103 | network_output = model.surface_decoding(input_chunk) 104 | if not isinstance(network_output, tuple): 105 | network_output = [network_output, ] 106 | 107 | if chunk_i == 0: 108 | output_chunks = [[] for _ in range(len(network_output))] 109 | 110 | if loss_func is not None: 111 | # The 'graph' in pytorch stores how the final variable is computed to its current form. 112 | # Under normal situations, we can delete this path right after the gradient is computed because the path 113 | # will be re-constructed on next forward call. 114 | # However, in our case, self.latent_vec is the leaf node requesting the gradient, the specific computation: 115 | # vec = self.latent_vec[inds] && cat(vec, xyz) 116 | # will be forgotten, too. if we delete the entire graph. 117 | # Indeed, the above computation is the ONLY part that we do not re-build during the next forwarding. 118 | # So, we set retain_graph to True. 119 | # According to https://github.com/pytorch/pytorch/issues/31185, if we delete the head loss immediately 120 | # after the backward(retain_graph=True), the un-referenced part graph will be deleted too, 121 | # hence keeping only the needed part (a sub-graph). 
Perfect :) 122 | loss_func(network_output, 123 | torch.arange(head, head + network_output[0].size(0), device=network_output[0].device) 124 | ).backward(retain_graph=(chunk_i != n_chunks - 1)) 125 | if not no_detach: 126 | network_output = [t.detach() for t in network_output] 127 | 128 | for payload_i, payload in enumerate(network_output): 129 | output_chunks[payload_i].append(payload) 130 | head += network_output[0].size(0) 131 | 132 | output_chunks = [torch.cat(t, dim=0) for t in output_chunks] 133 | return output_chunks 134 | 135 | 136 | def get_samples(r: int, device: torch.device, a: float = 0.0, b: float = None): 137 | """ 138 | Get samples within a cube, the voxel size is (b-a)/(r-1). range is from [a, b] 139 | :param r: num samples 140 | :param a: bound min 141 | :param b: bound max 142 | :return: (r*r*r, 3) 143 | """ 144 | overall_index = torch.arange(0, r ** 3, 1, device=device, dtype=torch.long) 145 | r = int(r) 146 | 147 | if b is None: 148 | b = 1. - 1. / r 149 | 150 | vsize = (b - a) / (r - 1) 151 | samples = torch.zeros(r ** 3, 3, device=device, dtype=torch.float32) 152 | samples[:, 0] = (overall_index // (r * r)) * vsize + a 153 | samples[:, 1] = ((overall_index // r) % r) * vsize + a 154 | samples[:, 2] = (overall_index % r) * vsize + a 155 | 156 | return samples 157 | 158 | 159 | def pack_samples(sample_indexer: torch.Tensor, count: int, 160 | sample_values: torch.Tensor = None): 161 | """ 162 | Pack a set of samples into batches. Each element in the batch is a random subsampling of the sample_values 163 | :param sample_indexer: (N, ) 164 | :param count: C 165 | :param sample_values: (N, L), if None, will return packed_inds instead of packed. 166 | :return: packed (B, C, L) or packed_inds (B, C), mapping: (B, ). 167 | """ 168 | from system.ext import pack_batch 169 | 170 | # First shuffle the samples to avoid biased samples. 171 | shuffle_inds = torch.randperm(sample_indexer.size(0), device=sample_indexer.device) 172 | sample_indexer = sample_indexer[shuffle_inds] 173 | 174 | mapping, pinds, pcount = torch.unique(sample_indexer, return_inverse=True, return_counts=True) 175 | 176 | n_batch = mapping.size(0) 177 | packed_inds = pack_batch(pinds, n_batch, count * 2) # (B, 2C) 178 | 179 | pcount.clamp_(max=count * 2 - 1) 180 | packed_inds_ind = torch.floor(torch.rand((n_batch, count), device=pcount.device) * pcount.unsqueeze(-1)).long() # (B, C) 181 | 182 | packed_inds = torch.gather(packed_inds, 1, packed_inds_ind) # (B, C) 183 | packed_inds = shuffle_inds[packed_inds] # (B, C) 184 | 185 | if sample_values is not None: 186 | assert sample_values.size(0) == sample_indexer.size(0) 187 | packed = torch.index_select(sample_values, 0, packed_inds.view(-1)).view(n_batch, count, sample_values.size(-1)) 188 | return packed, mapping 189 | else: 190 | return packed_inds, mapping 191 | 192 | 193 | def groupby_reduce(sample_indexer: torch.Tensor, sample_values: torch.Tensor, op: str = "max"): 194 | """ 195 | Group-By and Reduce sample_values according to their indices, the reduction operation is defined in `op`. 196 | :param sample_indexer: (N,). An index, must start from 0 and go to the (max-1), can be obtained using torch.unique. 197 | :param sample_values: (N, L) 198 | :param op: have to be in 'max', 'mean' 199 | :return: reduced values: (C, L) 200 | """ 201 | C = sample_indexer.max() + 1 202 | n_samples = sample_indexer.size(0) 203 | 204 | assert n_samples == sample_values.size(0), "Indexer and Values must agree on sample count!" 
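# Illustrative example: sample_indexer = [0, 0, 1], sample_values = [[2.], [4.], [6.]]
# gives [[3.], [6.]] for op='mean' and [[6.], [6.]] for op='sum'.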
205 | 206 | if op == 'mean': 207 | from system.ext import groupby_sum 208 | values_sum, values_count = groupby_sum(sample_values, sample_indexer, C) 209 | return values_sum / values_count.unsqueeze(-1) 210 | elif op == 'sum': 211 | from system.ext import groupby_sum 212 | values_sum, _ = groupby_sum(sample_values, sample_indexer, C) 213 | return values_sum 214 | else: 215 | raise NotImplementedError 216 | 217 | 218 | def fix_weight_norm_pickle(net: torch.nn.Module): 219 | from torch.nn.utils.weight_norm import WeightNorm 220 | for mdl in net.modules(): 221 | fix_name = None 222 | if isinstance(mdl, torch.nn.Linear): 223 | for k, hook in mdl._forward_pre_hooks.items(): 224 | if isinstance(hook, WeightNorm): 225 | fix_name = hook.name 226 | if fix_name is not None: 227 | delattr(mdl, fix_name) 228 | -------------------------------------------------------------------------------- /uni/ext/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from torch.utils.cpp_extension import load 3 | 4 | 5 | def p(rel_path): 6 | abs_path = Path(__file__).parent / rel_path 7 | return str(abs_path) 8 | 9 | 10 | __COMPILE_VERBOSE = False 11 | optimize_flags = {'extra_cflags': ['-O3'], 'extra_cuda_cflags': ['-O3']} 12 | #optimize_flags = {} 13 | 14 | # Load in Marching cubes. 15 | _marching_cubes_module = load(name='marching_cubes', 16 | sources=[p('marching_cubes/mc.cpp'), 17 | p('marching_cubes/mc_interp_kernel.cu')], 18 | verbose=__COMPILE_VERBOSE, **optimize_flags) 19 | marching_cubes_interp = _marching_cubes_module.marching_cubes_sparse_interp 20 | 21 | # Load in Image processing modules. 22 | _imgproc_module = load(name='imgproc', 23 | sources=[p('imgproc/imgproc.cu'), p('imgproc/imgproc.cpp'), p('imgproc/photometric.cu')], 24 | verbose=__COMPILE_VERBOSE, **optimize_flags) 25 | unproject_depth = _imgproc_module.unproject_depth 26 | compute_normal_weight = _imgproc_module.compute_normal_weight 27 | compute_normal_weight_robust = _imgproc_module.compute_normal_weight_robust 28 | filter_depth = _imgproc_module.filter_depth 29 | rgb_odometry = _imgproc_module.rgb_odometry 30 | gradient_xy = _imgproc_module.gradient_xy 31 | 32 | # Load in Indexing modules. (which deal with complicated indexing scheme) 33 | _indexing_module = load(name='indexing', 34 | sources=[p('indexing/indexing.cpp'), p('indexing/indexing.cu')], 35 | verbose=__COMPILE_VERBOSE, **optimize_flags) 36 | pack_batch = _indexing_module.pack_batch 37 | groupby_sum = _indexing_module.groupby_sum 38 | 39 | # We need point cloud processing module. 
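# Like the modules above it is JIT-compiled via torch.utils.cpp_extension.load on
# first import, and it exposes remove_radius_outlier and estimate_normals.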
40 | _pcproc_module = load(name='pcproc', 41 | sources=[p('pcproc/pcproc.cpp'), p('pcproc/pcproc.cu'), p('pcproc/cuda_kdtree.cu')], 42 | verbose=__COMPILE_VERBOSE, **optimize_flags) 43 | remove_radius_outlier = _pcproc_module.remove_radius_outlier 44 | estimate_normals = _pcproc_module.estimate_normals 45 | -------------------------------------------------------------------------------- /uni/ext/imgproc/common.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | using DepthAccessor = torch::PackedTensorAccessor32; 12 | using PCMapAccessor = torch::PackedTensorAccessor32; 13 | using IntensityAccessor = torch::PackedTensorAccessor32; 14 | using GradientAccessor = torch::PackedTensorAccessor32; 15 | //using MaskAccessor = torch::PackedTensorAccessor32; 16 | 17 | struct matrix3 18 | { 19 | float3 r1{0.0, 0.0, 0.0}; 20 | float3 r2{0.0, 0.0, 0.0}; 21 | float3 r3{0.0, 0.0, 0.0}; 22 | 23 | explicit matrix3(const std::vector& data) { 24 | r1.x = data[0]; r1.y = data[1]; r1.z = data[2]; 25 | r2.x = data[3]; r2.y = data[4]; r2.z = data[5]; 26 | r3.x = data[6]; r3.y = data[7]; r3.z = data[8]; 27 | } 28 | 29 | __host__ __device__ float3 operator*(const float3& rv) const { 30 | return make_float3( 31 | r1.x * rv.x + r1.y * rv.y + r1.z * rv.z, 32 | r2.x * rv.x + r2.y * rv.y + r2.z * rv.z, 33 | r3.x * rv.x + r3.y * rv.y + r3.z * rv.z); 34 | } 35 | }; 36 | 37 | static uint div_up(const uint a, const uint b) { 38 | return (a + b - 1) / b; 39 | } 40 | 41 | __device__ __forceinline__ static float3 get_vec3(const PCMapAccessor map, const uint i, const uint j) { 42 | return make_float3(map[i][j][0], map[i][j][1], map[i][j][2]); 43 | } 44 | 45 | inline __host__ __device__ float3 operator+(const float3& a, const float3& b) { 46 | return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); 47 | } 48 | 49 | inline __host__ __device__ float3 operator-(const float3& a, const float3& b) { 50 | return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); 51 | } 52 | 53 | inline __host__ __device__ void operator/=(float3 &a, float b) { 54 | a.x /= b; a.y /= b; a.z /= b; 55 | } 56 | 57 | inline __host__ __device__ void operator+=(float3 &a, const float3& b) { 58 | a.x += b.x; a.y += b.y; a.z += b.z; 59 | } 60 | 61 | inline __host__ __device__ void operator-=(float3 &a, const float3& b) { 62 | a.x -= b.x; a.y -= b.y; a.z -= b.z; 63 | } 64 | 65 | inline __host__ __device__ float3 cross(const float3& a, const float3& b) { 66 | return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); 67 | } 68 | 69 | inline __host__ __device__ float dot(const float3& a, const float3& b) { 70 | return a.x * b.x + a.y * b.y + a.z * b.z; 71 | } 72 | 73 | inline __host__ __device__ float length(const float3& v) { 74 | return sqrt(v.x * v.x + v.y * v.y + v.z * v.z); 75 | } 76 | 77 | inline __host__ __device__ float squared_length(const float3& v) { 78 | return v.x * v.x + v.y * v.y + v.z * v.z; 79 | } 80 | -------------------------------------------------------------------------------- /uni/ext/imgproc/imgproc.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor unproject_depth(torch::Tensor depth, float fx, float fy, float cx, float cy); 4 | 
//torch::Tensor unproject_depth(torch::Tensor depth, float fx, float fy, float cx, float cy); 5 | 6 | // We might do this several times, this interface enables re-use memory. 7 | void filter_depth(torch::Tensor depth_in, torch::Tensor depth_out); 8 | torch::Tensor compute_normal_weight(torch::Tensor pc_map); 9 | torch::Tensor compute_normal_weight_robust(torch::Tensor pc_map); 10 | 11 | // Compute rgb-image based odometry. Will return per-correspondence residual (M, ) and jacobian (M, 6) w.r.t. lie algebra. 12 | // ... based on current given estimate ( T(xi) * prev_xyz = cur_xyz ). 13 | // prev_intensity (H, W), float32, raning from 0 to 1. 14 | std::vector rgb_odometry( 15 | torch::Tensor prev_intensity, torch::Tensor prev_depth, 16 | torch::Tensor cur_intensity, torch::Tensor cur_depth, torch::Tensor cur_dIdxy, 17 | const std::vector& intr, 18 | const std::vector& krkinv_data, 19 | const std::vector& kt_data, 20 | float min_grad_scale, float max_depth_delta, bool compute_J); 21 | torch::Tensor gradient_xy(torch::Tensor cur_intensity); 22 | 23 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 24 | m.def("unproject_depth", &unproject_depth, "Unproject Depth (CUDA)"); 25 | m.def("filter_depth", &filter_depth, "Filter Depth (CUDA)"); 26 | m.def("compute_normal_weight", &compute_normal_weight, "Compute normal and weight (CUDA)"); 27 | m.def("compute_normal_weight_robust", &compute_normal_weight_robust, "Compute normal and weight (CUDA)"); 28 | m.def("rgb_odometry", &rgb_odometry, "Compute the function value and gradient for RGB Odometry (CUDA)"); 29 | m.def("gradient_xy", &gradient_xy, "Compute Gradient of an image for jacobian computation. (CUDA)"); 30 | } 31 | -------------------------------------------------------------------------------- /uni/ext/imgproc/photometric.cu: -------------------------------------------------------------------------------- 1 | #include "common.cuh" 2 | 3 | __global__ static void gradient_xy_kernel(const IntensityAccessor intensity, GradientAccessor out_grad) { 4 | const uint v = blockIdx.x * blockDim.x + threadIdx.x; 5 | const uint u = blockIdx.y * blockDim.y + threadIdx.y; 6 | if (v > intensity.size(0) - 1 || u > intensity.size(1) - 1) { return; } 7 | if (v < 1 || v > intensity.size(0) - 2 || u < 1 || u > intensity.size(1) - 2) { 8 | out_grad[v][u][0] = out_grad[v][u][1] = NAN; 9 | return; 10 | } 11 | 12 | // Sobel morph. 
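// The two stencils below are the standard 3x3 Sobel operators,
//        [-1  0 +1]              [-1 -2 -1]
//   Gx = [-2  0 +2] / 8 ,   Gy = [ 0  0  0] / 8 ,
//        [-1  0 +1]              [+1 +2 +1]
// so out_grad[v][u] approximates (dI/du, dI/dv) in intensity units per pixel;
// border pixels were already marked NAN above.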
13 | float u_d1 = intensity[v - 1][u + 1] - intensity[v - 1][u - 1]; 14 | float u_d2 = intensity[v][u + 1] - intensity[v][u - 1]; 15 | float u_d3 = intensity[v + 1][u + 1] - intensity[v + 1][u - 1]; 16 | out_grad[v][u][0] = (u_d1 + 2 * u_d2 + u_d3) / 8.0f; 17 | 18 | float v_d1 = intensity[v + 1][u - 1] - intensity[v - 1][u - 1]; 19 | float v_d2 = intensity[v + 1][u] - intensity[v - 1][u]; 20 | float v_d3 = intensity[v + 1][u + 1] - intensity[v - 1][u + 1]; 21 | out_grad[v][u][1] = (v_d1 + 2 * v_d2 + v_d3) / 8.0f; 22 | } 23 | 24 | __global__ static void evaluate_fJ(const IntensityAccessor prev_img, const DepthAccessor prev_depth, 25 | const IntensityAccessor cur_img, const DepthAccessor cur_depth, 26 | const GradientAccessor cur_dIdxy, const float min_grad_scale, const float max_depth_delta, 27 | matrix3 krkinv, float3 kt, float4 calib, 28 | IntensityAccessor f_val, GradientAccessor J_val, bool compute_J) { 29 | 30 | const uint v = blockIdx.x * blockDim.x + threadIdx.x; 31 | const uint u = blockIdx.y * blockDim.y + threadIdx.y; 32 | const uint img_h = prev_img.size(0); 33 | const uint img_w = prev_img.size(1); 34 | 35 | // The boundary will not be valid anyway. 36 | if (v > img_h - 1 || u > img_w - 1) { return; } 37 | 38 | f_val[v][u] = NAN; 39 | 40 | // Also prune if gradient is too small (which is useless for pose estimation) 41 | float dI_dx = cur_dIdxy[v][u][0], dI_dy = cur_dIdxy[v][u][1]; 42 | float mTwo = (dI_dx * dI_dx) + (dI_dy * dI_dy); 43 | if (mTwo < min_grad_scale || isnan(mTwo)) { 44 | return; 45 | } 46 | 47 | float d1 = cur_depth[v][u]; 48 | if (isnan(d1)) { 49 | return; 50 | } 51 | 52 | float warpped_d1 = d1 * (krkinv.r3.x * u + krkinv.r3.y * v + krkinv.r3.z) + kt.z; 53 | int u0 = __float2int_rn((d1 * (krkinv.r1.x * u + krkinv.r1.y * v + krkinv.r1.z) + kt.x) / warpped_d1); 54 | int v0 = __float2int_rn((d1 * (krkinv.r2.x * u + krkinv.r2.y * v + krkinv.r2.z) + kt.y) / warpped_d1); 55 | 56 | if (u0 >= 0 && u0 < img_w && v0 >= 0 && v0 < img_h) { 57 | float d0 = prev_depth[v0][u0]; 58 | // Make sure this pair of obs is not outlier and is really observed. 59 | if (!isnan(d0) && abs(warpped_d1 - d0) <= max_depth_delta && d0 > 0.0) { 60 | // Compute function value. 61 | f_val[v][u] = cur_img[v][u] - prev_img[v0][u0]; 62 | // Compute gradient w.r.t. xi. 
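// Chain rule through the pinhole projection: with G = (X, Y, Z) the back-projected
// point in the previous frame and calib = (fx, fy, cx, cy), the translational block
// of the Jacobian is p = (dI/dx * fx / Z, dI/dy * fy / Z, -(p0*X + p1*Y) / Z),
// and the rotational block is the cross product G x p (written to J[3..5] below).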
63 | if (compute_J) { 64 | float3 G = make_float3(d0 * (u0 - calib.z) / calib.x, d0 * (v0 - calib.w) / calib.y, d0); 65 | float p0 = dI_dx * calib.x / G.z; 66 | float p1 = dI_dy * calib.y / G.z; 67 | float p2 = -(p0 * G.x + p1 * G.y) / G.z; 68 | J_val[v][u][0] = p0; 69 | J_val[v][u][1] = p1; 70 | J_val[v][u][2] = p2; 71 | J_val[v][u][3] = -G.z * p1 + G.y * p2; 72 | J_val[v][u][4] = G.z * p0 - G.x * p2; 73 | J_val[v][u][5] = -G.y * p0 + G.x * p1; 74 | } 75 | } 76 | } 77 | } 78 | 79 | torch::Tensor gradient_xy(torch::Tensor cur_intensity) { 80 | CHECK_INPUT(cur_intensity); 81 | const uint img_h = cur_intensity.size(0); 82 | const uint img_w = cur_intensity.size(1); 83 | 84 | dim3 dimBlock = dim3(16, 16); 85 | dim3 dimGrid = dim3(div_up(img_h, dimBlock.x), div_up(img_w, dimBlock.y)); 86 | 87 | torch::Tensor cur_dIdxy = torch::empty({img_h, img_w, 2}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); 88 | gradient_xy_kernel<<>>( 89 | cur_intensity.packed_accessor32(), 90 | cur_dIdxy.packed_accessor32() 91 | ); 92 | return cur_dIdxy; 93 | } 94 | 95 | std::vector rgb_odometry( 96 | torch::Tensor prev_intensity, torch::Tensor prev_depth, 97 | torch::Tensor cur_intensity, torch::Tensor cur_depth, torch::Tensor cur_dIdxy, 98 | const std::vector& intr, 99 | const std::vector& krkinv_data, 100 | const std::vector& kt_data, 101 | float min_grad_scale, float max_depth_delta, bool compute_J) { 102 | 103 | CHECK_INPUT(prev_intensity); CHECK_INPUT(prev_depth); 104 | CHECK_INPUT(cur_intensity); CHECK_INPUT(cur_depth); CHECK_INPUT(cur_dIdxy); 105 | 106 | const uint img_h = cur_intensity.size(0); 107 | const uint img_w = cur_intensity.size(1); 108 | 109 | dim3 dimBlock = dim3(16, 16); 110 | dim3 dimGrid = dim3(div_up(img_h, dimBlock.x), div_up(img_w, dimBlock.y)); 111 | 112 | matrix3 krkinv(krkinv_data); 113 | float3 kt{kt_data[0], kt_data[1], kt_data[2]}; 114 | float4 calib{intr[0], intr[1], intr[2], intr[3]}; 115 | 116 | torch::Tensor f_img = torch::empty({img_h, img_w}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); 117 | torch::Tensor J_img = torch::empty({0, 0, 6}, torch::dtype(torch::kFloat32).device(torch::kCUDA));; 118 | if (compute_J) { 119 | J_img = torch::empty({img_h, img_w, 6}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); 120 | } 121 | 122 | evaluate_fJ<<>>( 123 | prev_intensity.packed_accessor32(), 124 | prev_depth.packed_accessor32(), 125 | cur_intensity.packed_accessor32(), 126 | cur_depth.packed_accessor32(), 127 | cur_dIdxy.packed_accessor32(), 128 | min_grad_scale, max_depth_delta, 129 | krkinv, kt, calib, 130 | f_img.packed_accessor32(), 131 | J_img.packed_accessor32(), compute_J); 132 | 133 | if (compute_J) { 134 | return {f_img, J_img}; 135 | } else { 136 | return {f_img}; 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /uni/ext/indexing/indexing.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor pack_batch(torch::Tensor indices, uint n_batch, uint n_point); 4 | std::vector groupby_sum(torch::Tensor values, torch::Tensor indices, uint C); 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("pack_batch", &pack_batch, "Pack Batch (CUDA)"); 8 | m.def("groupby_sum", &groupby_sum, "GroupBy Sum (CUDA)"); 9 | } 10 | -------------------------------------------------------------------------------- /uni/ext/indexing/indexing.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | using CountAccessor = torch::PackedTensorAccessor32; 9 | using IndexAccessor = torch::PackedTensorAccessor32; 10 | using PackedIndAccessor = torch::PackedTensorAccessor32; 11 | using ValueAccessor = torch::PackedTensorAccessor32; 12 | 13 | static uint div_up(const uint a, const uint b) { 14 | return (a + b - 1) / b; 15 | } 16 | 17 | __global__ static void pack_batch_kernel(const IndexAccessor indices, PackedIndAccessor packed_inds, const uint n_all, 18 | const uint n_batch, const uint n_point, int* __restrict__ filled_count) { 19 | const uint i_data = blockIdx.x * blockDim.x + threadIdx.x; 20 | if (i_data >= n_all) { 21 | return; 22 | } 23 | 24 | long i_group = indices[i_data]; 25 | if (i_group >= n_batch) { 26 | return; 27 | } 28 | 29 | // Get one valid id. 30 | int cur_count = atomicAdd(filled_count + i_group, 1); 31 | if (cur_count >= n_point) { 32 | return; 33 | } 34 | packed_inds[i_group][cur_count] = i_data; 35 | } 36 | 37 | __device__ static float atomicMax(float* __restrict__ address, float val) { 38 | int* address_as_i = (int*) address; 39 | int old = *address_as_i, assumed; 40 | do { 41 | assumed = old; 42 | old = ::atomicCAS(address_as_i, assumed, 43 | __float_as_int(::fmaxf(val, __int_as_float(assumed)))); 44 | } while (assumed != old); 45 | return __int_as_float(old); 46 | } 47 | 48 | __global__ static void groupby_max_kernel(const ValueAccessor values, const IndexAccessor indices, ValueAccessor reduced_values) { 49 | const uint i = blockIdx.x; 50 | const uint l = threadIdx.x; 51 | 52 | float value = values[i][l]; 53 | long index = indices[i]; 54 | 55 | float* rptr = reduced_values[index].data() + l; 56 | atomicMax(rptr, value); 57 | } 58 | 59 | __global__ static void groupby_sum_kernel(const ValueAccessor values, const IndexAccessor indices, ValueAccessor sum_values, CountAccessor sum_counts) { 60 | const uint i = blockIdx.x; 61 | const uint l = threadIdx.x; 62 | 63 | float value = values[i][l]; 64 | long index = indices[i]; 65 | 66 | float* rptr = sum_values[index].data() + l; 67 | int* iptr = &(sum_counts[index]); 68 | 69 | atomicAdd(rptr, value); 70 | atomicAdd(iptr, 1); 71 | } 72 | 73 | torch::Tensor pack_batch(torch::Tensor indices, uint n_batch, uint n_point) { 74 | CHECK_INPUT(indices); 75 | torch::Tensor packed_inds = torch::empty({n_batch, n_point}, torch::dtype(torch::kLong).device(torch::kCUDA)); 76 | thrust::device_vector filled_count(n_batch, 0); 77 | const uint n_all = indices.size(0); 78 | 79 | dim3 dimBlock = dim3(128); 80 | dim3 dimGrid = dim3(div_up(n_all, dimBlock.x)); 81 | pack_batch_kernel<<>>( 82 | indices.packed_accessor32(), 83 | packed_inds.packed_accessor32(), 84 | n_all, n_batch, n_point, filled_count.data().get()); 85 | return packed_inds; 86 | } 87 | 88 | 89 | std::vector groupby_sum(torch::Tensor values, torch::Tensor indices, uint C) { 90 | CHECK_INPUT(values); 91 | CHECK_INPUT(indices); 92 | 93 | const uint N = values.size(0); 94 | const uint L = values.size(1); 95 | 96 | torch::Tensor sum_values = torch::zeros({C, L}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); 97 | torch::Tensor sum_count = torch::zeros({C}, torch::dtype(torch::kInt32).device(torch::kCUDA)); 98 | 99 | dim3 dimBlock = dim3(L); 100 | dim3 dimGrid = dim3(N); 101 | groupby_sum_kernel<<>>( 102 | values.packed_accessor32(), 103 | 
indices.packed_accessor32(), 104 | sum_values.packed_accessor32(), 105 | sum_count.packed_accessor32() 106 | ); 107 | 108 | return {sum_values, sum_count}; 109 | } 110 | -------------------------------------------------------------------------------- /uni/ext/marching_cubes/mc.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | std::vector marching_cubes_sparse_interp_cuda( 4 | torch::Tensor indexer, // (nx, ny, nz) -> data_id 5 | torch::Tensor valid_blocks, // (K, ) 6 | torch::Tensor vec_batch_mapping, 7 | torch::Tensor cube_sdf, // (M, rx, ry, rz) 8 | torch::Tensor cube_std, // (M, rx, ry, rz) 9 | int max_n_triangles, // Maximum number of triangle buffer. 10 | const std::vector& n_xyz, // [nx, ny, nz] 11 | float max_std // Prune all vertices 12 | ); 13 | 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("marching_cubes_sparse_interp", &marching_cubes_sparse_interp_cuda, "Marching Cubes with Interpolation (CUDA)"); 16 | } 17 | -------------------------------------------------------------------------------- /uni/ext/pcproc/cuda_kdtree.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FLANN_CUDA_KD_TREE_BUILDER_H_ 2 | #define FLANN_CUDA_KD_TREE_BUILDER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "cutil_math.h" 11 | #include 12 | #include 13 | 14 | 15 | namespace tinyflann { 16 | 17 | // Distance types. 18 | struct CudaL1; 19 | struct CudaL2; 20 | 21 | // Parameters 22 | struct KDTreeCuda3dIndexParams { 23 | int leaf_max_size = 64; 24 | }; 25 | 26 | struct SearchParams 27 | { 28 | SearchParams(int checks_ = 32, float eps_ = 0.0, bool sorted_ = true ) : 29 | checks(checks_), eps(eps_), sorted(sorted_) 30 | { 31 | max_neighbors = -1; 32 | use_heap = true; 33 | } 34 | 35 | // how many leafs to visit when searching for neighbours (-1 for unlimited) 36 | int checks; 37 | // search for eps-approximate neighbours (default: 0) 38 | float eps; 39 | // only for radius search, require neighbours sorted by distance (default: true) 40 | bool sorted; 41 | // maximum number of neighbors radius search should return (-1 for unlimited) 42 | int max_neighbors; 43 | // use a heap to manage the result set (default: FLANN_Undefined) 44 | bool use_heap; 45 | }; 46 | 47 | template 48 | class KDTreeCuda3dIndex 49 | { 50 | public: 51 | int visited_leafs; 52 | KDTreeCuda3dIndex(const float* input_data, size_t input_count, const KDTreeCuda3dIndexParams& params = KDTreeCuda3dIndexParams()) 53 | : dataset_(input_data), leaf_count_(0), visited_leafs(0), node_count_(0), current_node_count_(0) { 54 | size_ = input_count; 55 | leaf_max_size_ = params.leaf_max_size; 56 | gpu_helper_=0; 57 | } 58 | 59 | /** 60 | * Standard destructor 61 | */ 62 | ~KDTreeCuda3dIndex() { 63 | clearGpuBuffers(); 64 | } 65 | 66 | /** 67 | * Builds the index 68 | */ 69 | void buildIndex() { 70 | leaf_count_ = 0; 71 | node_count_ = 0; 72 | uploadTreeToGpu(); 73 | } 74 | 75 | /** 76 | * queries: (N, p) flattened float cuda array where only first 3 elements are used. 
77 | * n_query: N 78 | * n_query_stride: p 79 | * indices: (N, knn) int cuda array 80 | * dists: (N, knn) float cuda array 81 | */ 82 | void knnSearch(const float* queries, size_t n_query, int n_query_stride, int* indices, float* dists, size_t knn, const SearchParams& params = SearchParams()) const; 83 | int radiusSearch(const float* queries, size_t n_query, int n_query_stride, int* indices, float* dists, float radius, const SearchParams& params = SearchParams()) const; 84 | 85 | private: 86 | 87 | void uploadTreeToGpu(); 88 | void clearGpuBuffers(); 89 | 90 | private: 91 | 92 | struct GpuHelper; 93 | GpuHelper* gpu_helper_; 94 | 95 | const float* dataset_; 96 | int leaf_max_size_; 97 | int leaf_count_; 98 | int node_count_; 99 | //! used by convertTreeToGpuFormat 100 | int current_node_count_; 101 | size_t size_; 102 | 103 | }; // class KDTreeCuda3dIndex 104 | 105 | 106 | } // namespace all 107 | 108 | 109 | #endif -------------------------------------------------------------------------------- /uni/ext/pcproc/pcproc.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor remove_radius_outlier( 4 | torch::Tensor input_pc, 5 | int nb_points, 6 | float radius 7 | ); 8 | 9 | torch::Tensor estimate_normals( 10 | torch::Tensor input_pc, 11 | int max_nn, 12 | float radius, 13 | const std::vector& cam_xyz 14 | ); 15 | 16 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 17 | m.def("remove_radius_outlier", &remove_radius_outlier, "Remove point outliers by radius (CUDA)"); 18 | m.def("estimate_normals", &estimate_normals, "Estimate point cloud normals (CUDA)"); 19 | } 20 | -------------------------------------------------------------------------------- /uni/ext/pcproc/pcproc.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_kdtree.cuh" 8 | #include "cutil_math.h" 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | using PCAccessor = torch::PackedTensorAccessor32; 15 | using MaskAccessor = torch::PackedTensorAccessor32; 16 | 17 | static uint div_up(const uint a, const uint b) { 18 | return (a + b - 1) / b; 19 | } 20 | 21 | __device__ float4 sym3eig(float3 x1, float3 x2, float3 x3) { 22 | float4 ret_val; 23 | 24 | const float p1 = x1.y * x1.y + x1.z * x1.z + x2.z * x2.z; 25 | const float q = (x1.x + x2.y + x3.z) / 3.0f; 26 | const float p2 = (x1.x - q) * (x1.x - q) + (x2.y - q) * (x2.y - q) + (x3.z - q) * (x3.z - q) + 2 * p1; 27 | const float p = sqrt(p2 / 6.0f); 28 | const float b11 = (1.0f / p) * (x1.x - q); const float b12 = (1.0f / p) * x1.y; 29 | const float b13 = (1.0f / p) * x1.z; const float b21 = (1.0f / p) * x2.x; 30 | const float b22 = (1.0f / p) * (x2.y - q); const float b23 = (1.0f / p) * x2.z; 31 | const float b31 = (1.0f / p) * x3.x; const float b32 = (1.0f / p) * x3.y; 32 | const float b33 = (1.0f / p) * (x3.z - q); 33 | 34 | float r = b11 * b22 * b33 + b12 * b23 * b31 + b13 * b21 * b32 - 35 | b13 * b22 * b31 - b12 * b21 * b33 - b11 * b23 * b32; 36 | r = r / 2.0f; 37 | 38 | float phi; 39 | if (r <= -1) { 40 | phi = M_PI / 3.0f; 41 | } else if (r >= 1) { 42 | phi = 0; 43 | } else { 44 | phi = acos(r) / 3.0f; 45 | } 46 | 47 | // float e0 = q + 2 * p * cos(phi); 48 | ret_val.w = q + 2 * p * cos(phi + (2 * M_PI / 3)); 49 | 
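// Closed-form eigenvalues of a symmetric 3x3 matrix: e = q + 2*p*cos(phi + 2*pi*k/3).
// The branch stored above, cos(phi + 2*pi/3), yields the smallest of the three
// eigenvalues; its eigenvector, recovered below from cross products of the rows of
// (A - lambda*I), serves as the surface normal direction.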
// ret_val.w = 3 * q - e0 - e1; 50 | 51 | x1.x -= ret_val.w; 52 | x2.y -= ret_val.w; 53 | x3.z -= ret_val.w; 54 | 55 | const float r12_1 = x1.y * x2.z - x1.z * x2.y; 56 | const float r12_2 = x1.z * x2.x - x1.x * x2.z; 57 | const float r12_3 = x1.x * x2.y - x1.y * x2.x; 58 | const float r13_1 = x1.y * x3.z - x1.z * x3.y; 59 | const float r13_2 = x1.z * x3.x - x1.x * x3.z; 60 | const float r13_3 = x1.x * x3.y - x1.y * x3.x; 61 | const float r23_1 = x2.y * x3.z - x2.z * x3.y; 62 | const float r23_2 = x2.z * x3.x - x2.x * x3.z; 63 | const float r23_3 = x2.x * x3.y - x2.y * x3.x; 64 | 65 | const float d1 = r12_1 * r12_1 + r12_2 * r12_2 + r12_3 * r12_3; 66 | const float d2 = r13_1 * r13_1 + r13_2 * r13_2 + r13_3 * r13_3; 67 | const float d3 = r23_1 * r23_1 + r23_2 * r23_2 + r23_3 * r23_3; 68 | 69 | float d_max = d1; 70 | int i_max = 0; 71 | 72 | if (d2 > d_max) { 73 | d_max = d2; 74 | i_max = 1; 75 | } 76 | 77 | if (d3 > d_max) { 78 | i_max = 2; 79 | } 80 | 81 | if (i_max == 0) { 82 | ret_val.x = r12_1 / sqrt(d1); 83 | ret_val.y = r12_2 / sqrt(d1); 84 | ret_val.z = r12_3 / sqrt(d1); 85 | } else if (i_max == 1) { 86 | ret_val.x = r13_1 / sqrt(d2); 87 | ret_val.y = r13_2 / sqrt(d2); 88 | ret_val.z = r13_3 / sqrt(d2); 89 | } else { 90 | ret_val.x = r23_1 / sqrt(d3); 91 | ret_val.y = r23_2 / sqrt(d3); 92 | ret_val.z = r23_3 / sqrt(d3); 93 | } 94 | 95 | return ret_val; 96 | } 97 | 98 | __global__ void remove_radius_outlier_kernel(const PCAccessor input_pc, const float* __restrict__ input_nn_dist, int nb_points, float radius, 99 | MaskAccessor output_mask) { 100 | const uint pc_id = blockIdx.x * blockDim.x + threadIdx.x; 101 | if (pc_id >= input_pc.size(0)) { 102 | return; 103 | } 104 | output_mask[pc_id] = input_nn_dist[pc_id * nb_points + (nb_points - 1)] < radius * radius; 105 | } 106 | 107 | __global__ void estimate_normal_kernel(const PCAccessor input_pc, const float* __restrict__ input_nn_dist, const int* __restrict__ input_nn_ind, 108 | int max_nn, float radius, float3 cam_pos, PCAccessor output_normal) { 109 | const uint pc_id = blockIdx.x * blockDim.x + threadIdx.x; 110 | if (pc_id >= input_pc.size(0)) { 111 | return; 112 | } 113 | 114 | const float* cur_dist = input_nn_dist + max_nn * pc_id; 115 | const int* cur_ind = input_nn_ind + max_nn * pc_id; 116 | 117 | float3 pc_mean{0.0f, 0.0f, 0.0f}; 118 | float valid_count = 0.0f; 119 | for (int nn_i = 1; nn_i < max_nn; ++nn_i) { 120 | if (cur_dist[nn_i] < radius * radius) { 121 | int nn_pc_id = cur_ind[nn_i]; 122 | pc_mean += make_float3(input_pc[nn_pc_id][0], input_pc[nn_pc_id][1], input_pc[nn_pc_id][2]); 123 | valid_count += 1.0f; 124 | } else break; 125 | } 126 | 127 | if (valid_count < 5.0f) { 128 | output_normal[pc_id][0] = output_normal[pc_id][1] = output_normal[pc_id][2] = NAN; 129 | return; 130 | } 131 | pc_mean /= valid_count; 132 | 133 | // Compute the covariance matrix. 
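// Accumulate the 3x3 scatter matrix sum_i (x_i - mean)(x_i - mean)^T over the
// neighbours that passed the radius test. KD-tree distances are squared, hence the
// comparison against radius * radius, and the neighbour list is distance-sorted,
// so the loop may stop at the first rejection. The smallest-eigenvalue eigenvector
// of this matrix, obtained via sym3eig below, is the estimated normal.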
134 | float3 cov_x1{0.0f, 0.0f, 0.0f}; 135 | float3 cov_x2{0.0f, 0.0f, 0.0f}; 136 | float3 cov_x3{0.0f, 0.0f, 0.0f}; 137 | for (int nn_i = 1; nn_i < max_nn; ++nn_i) { 138 | if (cur_dist[nn_i] < radius * radius) { 139 | int nn_pc_id = cur_ind[nn_i]; 140 | float3 pos = make_float3(input_pc[nn_pc_id][0], input_pc[nn_pc_id][1], input_pc[nn_pc_id][2]); 141 | pos -= pc_mean; 142 | cov_x1.x += pos.x * pos.x; cov_x1.y += pos.x * pos.y; cov_x1.z += pos.x * pos.z; 143 | cov_x2.x += pos.y * pos.x; cov_x2.y += pos.y * pos.y; cov_x2.z += pos.y * pos.z; 144 | cov_x3.x += pos.z * pos.x; cov_x3.y += pos.z * pos.y; cov_x3.z += pos.z * pos.z; 145 | } else break; 146 | } 147 | 148 | float4 eigv = sym3eig(cov_x1, cov_x2, cov_x3); 149 | float3 normal = make_float3(eigv.x, eigv.y, eigv.z); 150 | 151 | float3 cur_pos = make_float3(input_pc[pc_id][0], input_pc[pc_id][1], input_pc[pc_id][2]); 152 | if (dot(normal, cur_pos - cam_pos) > 0.0f) { 153 | normal = -normal; 154 | } 155 | output_normal[pc_id][0] = normal.x; 156 | output_normal[pc_id][1] = normal.y; 157 | output_normal[pc_id][2] = normal.z; 158 | } 159 | 160 | torch::Tensor remove_radius_outlier(torch::Tensor input_pc, int nb_points, float radius) { 161 | CHECK_INPUT(input_pc); 162 | 163 | size_t n_point = input_pc.size(0); 164 | 165 | // Build KDTree based on input point cloud 166 | tinyflann::KDTreeCuda3dIndex knn_index(input_pc.data_ptr(), n_point); 167 | knn_index.buildIndex(); 168 | 169 | // Compute for each point its nearest N neighbours . 170 | thrust::device_vector dist(n_point * nb_points); 171 | thrust::device_vector indices(n_point * nb_points); 172 | knn_index.knnSearch(input_pc.data_ptr(), n_point, 4, (int*) indices.data().get(), 173 | (float*) dist.data().get(), nb_points); 174 | 175 | // Test the validity of the points and remove bad ones. 
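// A point survives iff its nb_points-th nearest neighbour (the point itself is
// included in the count) lies inside the query ball, i.e. it has at least
// nb_points neighbours within `radius`. The KD-tree returns squared distances,
// so the kernel compares against radius^2. This is roughly a GPU counterpart of
// Open3D's remove_radius_outlier.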
176 | torch::Tensor output_mask = torch::empty({(long) n_point}, torch::dtype(torch::kBool).device(torch::kCUDA)); 177 | 178 | dim3 dimBlock = dim3(128); 179 | dim3 dimGrid = dim3(div_up(n_point, dimBlock.x)); 180 | 181 | remove_radius_outlier_kernel<<>>( 182 | input_pc.packed_accessor32(), 183 | dist.data().get(), nb_points, radius, 184 | output_mask.packed_accessor32()); 185 | 186 | return output_mask; 187 | } 188 | 189 | torch::Tensor estimate_normals(torch::Tensor input_pc, int max_nn, float radius, const std::vector& cam_xyz) { 190 | CHECK_INPUT(input_pc); 191 | size_t n_point = input_pc.size(0); 192 | tinyflann::KDTreeCuda3dIndex knn_index(input_pc.data_ptr(), n_point); 193 | knn_index.buildIndex(); 194 | thrust::device_vector dist(n_point * max_nn); 195 | thrust::device_vector indices(n_point * max_nn); 196 | knn_index.knnSearch(input_pc.data_ptr(), n_point, 4, (int*) indices.data().get(), 197 | (float*) dist.data().get(), max_nn); 198 | 199 | torch::Tensor output_normal = torch::empty({(long) n_point, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); 200 | 201 | dim3 dimBlock = dim3(128); 202 | dim3 dimGrid = dim3(div_up(n_point, dimBlock.x)); 203 | estimate_normal_kernel<<>>( 204 | input_pc.packed_accessor32(), 205 | dist.data().get(), indices.data().get(), 206 | max_nn, radius, make_float3(cam_xyz[0], cam_xyz[1], cam_xyz[2]), 207 | output_normal.packed_accessor32()); 208 | 209 | return output_normal; 210 | } 211 | -------------------------------------------------------------------------------- /uni/mapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/uni/mapper/__init__.py -------------------------------------------------------------------------------- /uni/mapper/base_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | import argparse 5 | import threading 6 | from pathlib import Path 7 | import functools 8 | 9 | 10 | import pdb 11 | 12 | class BaseMap: 13 | def __init__(self, uni_model, args: argparse.Namespace, 14 | latent_dim, device: torch.device, enable_async: bool = False): 15 | 16 | 17 | self.model = uni_model 18 | 19 | self.voxel_size = args.voxel_size 20 | self.n_xyz = np.ceil((np.asarray(args.bound_max) - np.asarray(args.bound_min)) / args.voxel_size).astype(int).tolist() 21 | logging.info(f"Map size Nx = {self.n_xyz[0]}, Ny = {self.n_xyz[1]}, Nz = {self.n_xyz[2]}") 22 | 23 | self.args = args 24 | self.bound_min = torch.tensor(args.bound_min, device=device).float() 25 | self.bound_max = self.bound_min + self.voxel_size * torch.tensor(self.n_xyz, device=device) 26 | self.latent_dim = latent_dim # could be int for surface and tuple for context 27 | self.device = device 28 | self.integration_offsets = [torch.tensor(t, device=self.device, dtype=torch.float32) for t in [ 29 | [-0.5, -0.5, -0.5], [-0.5, -0.5, 0.5], [-0.5, 0.5, -0.5], [-0.5, 0.5, 0.5], 30 | [0.5, -0.5, -0.5], [0.5, -0.5, 0.5], [0.5, 0.5, -0.5], [0.5, 0.5, 0.5] 31 | ]] 32 | # Directly modifiable from outside. 
33 | self.extract_mesh_std_range = None 34 | 35 | self.mesh_update_affected = [torch.tensor([t], device=self.device) 36 | for t in [[-1, 0, 0], [1, 0, 0], 37 | [0, -1, 0], [0, 1, 0], 38 | [0, 0, -1], [0, 0, 1]]] 39 | self.relative_network_offset = torch.tensor([[0.5, 0.5, 0.5]], device=self.device, dtype=torch.float32) 40 | 41 | self.cold_vars = { 42 | "n_occupied": 0, 43 | "indexer": torch.ones(np.product(self.n_xyz), device=device, dtype=torch.long) * -1, 44 | # -- Voxel Attributes -- 45 | # 1. Latent Vector (Geometry) 46 | "latent_vecs": torch.empty((1, self.latent_dim), dtype=torch.float32, device=device) if type(self.latent_dim) == int 47 | else torch.empty((1, *(self.latent_dim)), dtype=torch.float32, device=device), 48 | # 2. Position 49 | "latent_vecs_pos": torch.ones((1, ), dtype=torch.long, device=device) * -1, 50 | # 3. Confidence on its geometry 51 | "voxel_obs_count": torch.zeros((1, ), dtype=torch.float32, device=device), 52 | } 53 | self.backup_var_names = ["indexer", "latent_vecs", "latent_vecs_pos", "voxel_obs_count"] 54 | 55 | self.backup_vars = {} 56 | # Allow direct visit by variable 57 | for p in self.cold_vars.keys(): 58 | setattr(BaseMap, p, property( 59 | fget=functools.partial(BaseMap._get_var, name=p), 60 | fset=functools.partial(BaseMap._set_var, name=p) 61 | )) 62 | self.meshing_thread = None 63 | self.meshing_thread_id = -1 64 | self.meshing_stream = torch.cuda.Stream() 65 | self.mesh_cache = MeshExtractCache(self.device) 66 | self.latent_vecs.zero_() 67 | 68 | 69 | 70 | def save(self, path): 71 | if not isinstance(path, Path): 72 | path = Path(path) 73 | indexer_key = torch.where(self.indexer>-1)[0] 74 | indexer_value = self.indexer[indexer_key].clone() 75 | self.cold_vars['indexer_key'] = indexer_key 76 | self.cold_vars['indexer_value'] = indexer_value 77 | 78 | indexer = self.cold_vars['indexer'] 79 | del self.cold_vars['indexer'] 80 | 81 | with path.open('wb') as f: 82 | torch.save(self.cold_vars, f) 83 | 84 | self.cold_vars['indexer'] = indexer 85 | 86 | def load(self, path): 87 | if not isinstance(path, Path): 88 | path = Path(path) 89 | with path.open('rb') as f: 90 | self.cold_vars = torch.load(f) 91 | self.cold_vars['indexer'] = torch.ones(np.product(self.n_xyz), device=self.device, dtype=torch.long) * -1 92 | self.cold_vars['indexer'][self.cold_vars['indexer_key']] = self.cold_vars['indexer_value'] 93 | 94 | 95 | def _get_var(self, name): 96 | if threading.get_ident() == self.meshing_thread_id and name in self.backup_var_names: 97 | return self.backup_vars[name] 98 | else: 99 | return self.cold_vars[name] 100 | 101 | def _set_var(self, value, name): 102 | if threading.get_ident() == self.meshing_thread_id and name in self.backup_var_names: 103 | self.backup_vars[name] = value 104 | else: 105 | self.cold_vars[name] = value 106 | 107 | def _inflate_latent_buffer(self, count: int): 108 | target_n_occupied = self.n_occupied + count 109 | if self.latent_vecs.size(0) < target_n_occupied: 110 | new_size = self.latent_vecs.size(0) 111 | while new_size < target_n_occupied: 112 | new_size *= 2 113 | new_vec = torch.empty((new_size, self.latent_dim), dtype=torch.float32, device=self.device) if type(self.latent_dim) == int \ 114 | else torch.empty((new_size, *(self.latent_dim)), dtype=torch.float32, device=self.device) 115 | new_vec[:self.latent_vecs.size(0)] = self.latent_vecs 116 | 117 | new_vec_pos = torch.ones((new_size, ), dtype=torch.long, device=self.device) * -1 118 | new_vec_pos[:self.latent_vecs.size(0)] = self.latent_vecs_pos 119 | new_voxel_conf = 
torch.zeros((new_size, ), dtype=torch.float32, device=self.device) 120 | new_voxel_conf[:self.latent_vecs.size(0)] = self.voxel_obs_count 121 | 122 | new_vec[self.latent_vecs.size(0):].zero_() 123 | 124 | self.latent_vecs = new_vec 125 | self.latent_vecs_pos = new_vec_pos 126 | self.voxel_obs_count = new_voxel_conf 127 | 128 | 129 | new_inds = torch.arange(self.n_occupied, target_n_occupied, device=self.device, dtype=torch.long) 130 | self.n_occupied = target_n_occupied 131 | return new_inds 132 | 133 | def _linearize_id(self, xyz: torch.Tensor): 134 | """ 135 | :param xyz (N, 3) long id 136 | :return: (N, ) lineraized id to be accessed in self.indexer 137 | """ 138 | return xyz[:, 2] + self.n_xyz[-1] * xyz[:, 1] + (self.n_xyz[-1] * self.n_xyz[-2]) * xyz[:, 0] 139 | 140 | def _unlinearize_id(self, idx: torch.Tensor): 141 | """ 142 | :param idx: (N, ) linearized id for access in self.indexer 143 | :return: xyz (N, 3) id to be indexed in 3D 144 | """ 145 | return torch.stack([idx // (self.n_xyz[1] * self.n_xyz[2]), 146 | (idx // self.n_xyz[2]) % self.n_xyz[1], 147 | idx % self.n_xyz[2]], dim=-1) 148 | 149 | def _mark_updated_vec_id(self, new_vec_id: torch.Tensor): 150 | """ 151 | :param new_vec_id: (B,) updated id (indexed in latent vectors) 152 | """ 153 | self.mesh_cache.updated_vec_id = torch.cat([self.mesh_cache.updated_vec_id, new_vec_id]) 154 | self.mesh_cache.updated_vec_id = torch.unique(self.mesh_cache.updated_vec_id) 155 | 156 | def allocate_block(self, idx: torch.Tensor): 157 | """ 158 | :param idx: (N, 3) or (N, ), if the first one, will call linearize id. 159 | NOTE: this will not check index overflow! 160 | """ 161 | if idx.ndimension() == 2 and idx.size(1) == 3: 162 | idx = self._linearize_id(idx) 163 | new_id = self._inflate_latent_buffer(idx.size(0)) 164 | self.latent_vecs_pos[new_id] = idx 165 | self.indexer[idx] = new_id 166 | 167 | def integrate_keyframe(self, surface_xyz: torch.Tensor, surface_normal_or_context: torch.Tensor): 168 | pass 169 | 170 | 171 | def _expand_flatten_id(self, base_flatten_id: torch.Tensor, ensure_valid: bool = True): 172 | expanded_flatten_id = [base_flatten_id] 173 | updated_pos = self._unlinearize_id(base_flatten_id) 174 | for affected_offset in self.mesh_update_affected: 175 | rs_id = updated_pos + affected_offset 176 | for dim in range(3): 177 | rs_id[:, dim].clamp_(0, self.n_xyz[dim] - 1) 178 | rs_id = self._linearize_id(rs_id) 179 | if ensure_valid: 180 | rs_id = rs_id[self.indexer[rs_id] != -1] 181 | expanded_flatten_id.append(rs_id) 182 | expanded_flatten_id = torch.unique(torch.cat(expanded_flatten_id)) 183 | return expanded_flatten_id 184 | 185 | 186 | def _expand_flatten_id_orthogonal(self, base_flatten_id: torch.Tensor, main_direction: torch.tensor, ensure_valid: bool = True): 187 | expanded_flatten_id = [base_flatten_id] 188 | updated_pos = self._unlinearize_id(base_flatten_id) 189 | for affected_offset in self.mesh_update_affected: 190 | valid_mask = (affected_offset * main_direction).sum(1).abs() < .5 # 60 degree 191 | rs_id = updated_pos[valid_mask,:] + affected_offset 192 | for dim in range(3): 193 | rs_id[:, dim].clamp_(0, self.n_xyz[dim] - 1) 194 | rs_id = self._linearize_id(rs_id) 195 | if ensure_valid: 196 | rs_id = rs_id[self.indexer[rs_id] != -1] 197 | expanded_flatten_id.append(rs_id) 198 | expanded_flatten_id = torch.unique(torch.cat(expanded_flatten_id)) 199 | return expanded_flatten_id 200 | 201 | STATUS_CONF_BIT = 1 << 0 # 1 202 | STATUS_SURF_BIT = 1 << 1 # 2 203 | 204 | 205 | 206 | 207 | 208 | class MeshExtractCache: 
209 | def __init__(self, device): 210 | self.vertices = None 211 | self.vertices_flatten_id = None 212 | self.vertices_std = None 213 | self.updated_vec_id = None 214 | self.device = device 215 | self.clear_updated_vec() 216 | 217 | def clear_updated_vec(self): 218 | self.updated_vec_id = torch.empty((0, ), device=self.device, dtype=torch.long) 219 | 220 | def clear_all(self): 221 | self.vertices = None 222 | self.vertices_flatten_id = None 223 | self.vertices_std = None 224 | self.updated_vec_id = None 225 | self.clear_updated_vec() 226 | 227 | 228 | -------------------------------------------------------------------------------- /uni/mapper/latent_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | import argparse 5 | 6 | # 8 points 7 | from .context_map_v2 import ContextMap 8 | 9 | 10 | 11 | import pdb 12 | class LatentMap(ContextMap): 13 | def __init__(self, uni_model, args: argparse.Namespace, 14 | latent_dim: int, device: torch.device, enable_async: bool = False): 15 | super().__init__(uni_model, args, latent_dim, device, enable_async) 16 | 17 | def infer(self, X_test, F_tx, classify): 18 | ''' 19 | x: N,3 20 | 21 | F_tx: the text feature 22 | ''' 23 | # get vid 24 | surface_xyz_zeroed = X_test - self.bound_min.unsqueeze(0) 25 | surface_xyz_normalized = surface_xyz_zeroed / self.voxel_size 26 | 27 | vertex = torch.ceil(surface_xyz_normalized) - 1 28 | surface_grid_id = self._linearize_id(vertex.long()) 29 | d_xyz = surface_xyz_normalized - vertex - 0.5 30 | with torch.no_grad(): 31 | pinds = self.indexer[surface_grid_id] 32 | Fs = self.latent_vecs[pinds,:,:] 33 | latents = self.model.color_decoding(Fs.unsqueeze(0), d_xyz.unsqueeze(0)/2) 34 | seg_scores = classify(latents, F_tx) 35 | return seg_scores 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /uni/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/uni/tracker/__init__.py -------------------------------------------------------------------------------- /uni/tracker/cicp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import pdb 4 | 5 | 6 | def preprocess_point_cloud(pcd, voxel_size): 7 | print(":: Downsample with a voxel size %.3f." % voxel_size) 8 | pcd_down = pcd.voxel_down_sample(voxel_size) 9 | 10 | radius_normal = voxel_size * 2 11 | print(":: Estimate normal with search radius %.3f." % radius_normal) 12 | pcd_down.estimate_normals( 13 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30)) 14 | 15 | radius_feature = voxel_size * 5 16 | print(":: Compute FPFH feature with search radius %.3f." % radius_feature) 17 | pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature( 18 | pcd_down, 19 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100)) 20 | return pcd_down, pcd_fpfh 21 | def execute_global_registration(source_down, target_down, source_fpfh, 22 | target_fpfh, voxel_size): 23 | distance_threshold = voxel_size * 1.5 24 | print(":: RANSAC registration on downsampled point clouds.") 25 | print(" Since the downsampling voxel size is %.3f," % voxel_size) 26 | print(" we use a liberal distance threshold %.3f." 
% distance_threshold) 27 | result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( 28 | source_down, target_down, source_fpfh, target_fpfh, True, 29 | distance_threshold, 30 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 31 | 3, [ 32 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength( 33 | 0.9), 34 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance( 35 | distance_threshold) 36 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(100000, 0.999)) 37 | return result 38 | def execute_fast_global_registration(source_down, target_down, source_fpfh, 39 | target_fpfh, voxel_size): 40 | distance_threshold = voxel_size * 0.5 41 | print(":: Apply fast global registration with distance threshold %.3f" \ 42 | % distance_threshold) 43 | result = o3d.pipelines.registration.registration_fgr_based_on_feature_matching( 44 | source_down, target_down, source_fpfh, target_fpfh, 45 | o3d.pipelines.registration.FastGlobalRegistrationOption( 46 | maximum_correspondence_distance=distance_threshold)) 47 | return result 48 | 49 | def cicp(source, target, current_transformation=np.identity(4), scale_factor=2): 50 | voxel_radius = [0.04*scale_factor, 0.02*scale_factor, 0.01*scale_factor] 51 | 52 | max_iter = [50, 30, 14] 53 | #current_transformation = np.identity(4) 54 | #print("3. Colored point cloud registration") 55 | for scale in range(3): 56 | iter = max_iter[scale] 57 | radius = voxel_radius[scale] 58 | #print([iter, radius, scale]) 59 | 60 | #print("3-1. Downsample with a voxel size %.2f" % radius) 61 | source_down = source.voxel_down_sample(radius) 62 | target_down = target.voxel_down_sample(radius) 63 | 64 | #print("3-2. Estimate normal.") 65 | source_down.estimate_normals( 66 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30)) 67 | target_down.estimate_normals( 68 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30)) 69 | 70 | #print("3-3. Applying colored point cloud registration") 71 | result_icp = o3d.pipelines.registration.registration_colored_icp( 72 | source_down, target_down, radius, current_transformation, 73 | o3d.pipelines.registration.TransformationEstimationForColoredICP(), 74 | o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6, 75 | relative_rmse=1e-6, 76 | max_iteration=iter)) 77 | 78 | if result_icp.fitness > .9: 79 | return current_transformation.copy() 80 | 81 | ''' 82 | deprecated 83 | ''' 84 | if False: #result_icp.fitness < 0.5: 85 | pdb.set_trace() 86 | 87 | # 1. centering the point cloud 88 | source_ct = source.get_center() #3,1 89 | target_ct = target.get_center() 90 | source_ = source.translate(-source_ct) 91 | target_ = target.translate(-target_ct) 92 | 93 | iter = max_iter[scale] * 10 94 | radius = voxel_radius[scale] 95 | #print([iter, radius, scale]) 96 | 97 | #print("3-1. Downsample with a voxel size %.2f" % radius) 98 | source_down = source_.voxel_down_sample(radius) 99 | target_down = target_.voxel_down_sample(radius) 100 | 101 | #print("3-2. Estimate normal.") 102 | source_down.estimate_normals( 103 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30)) 104 | target_down.estimate_normals( 105 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30)) 106 | 107 | #print("3-3. 
Applying colored point cloud registration") 108 | result_icp = o3d.pipelines.registration.registration_colored_icp( 109 | source_down, target_down, radius, current_transformation, 110 | o3d.pipelines.registration.TransformationEstimationForColoredICP(), 111 | o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6, 112 | relative_rmse=1e-6, 113 | max_iteration=iter)) 114 | ''' 115 | o3d.io.write_point_cloud('tmp1.pcd', source) 116 | o3d.io.write_point_cloud('tmp2.pcd', target) 117 | 118 | source_down, source_fpfh = preprocess_point_cloud(source, voxel_size=radius/2) 119 | target_down, target_fpfh = preprocess_point_cloud(target, voxel_size=radius/2) 120 | result_ransac = execute_fast_global_registration(source_down, target_down, 121 | source_fpfh, target_fpfh, 122 | voxel_size=radius/2) 123 | 124 | 125 | print('refined fitness', result_ransac.fitness) 126 | current_transformation = result_ransac.transformation 127 | ''' 128 | current_transformation = result_icp.transformation 129 | current_transformation = current_transformation.copy() 130 | current_transformation[:3,3] += (target_ct - current_transformation[:3,:3].dot(source_ct)) 131 | print("reg result",result_icp.fitness, result_icp.inlier_rmse) 132 | else: 133 | current_transformation = result_icp.transformation 134 | return current_transformation.copy() 135 | 136 | 137 | def poseEstimate2(pre_depth_im, pre_rgb_im, cur_depth_im, cur_rgb_im, calib, current_transformation=np.identity(4)): 138 | ''' from Kinect fusion: https://github.com/chengkunli96/KinectFusion/blob/main/src/kinect_fusion/fusion.py#L190 139 | ''' 140 | """Colored Point Cloud Registration. Park's method""" 141 | fx, fy, cx, cy = calib 142 | 143 | depth_o3d_img = o3d.geometry.Image((pre_depth_im).astype(np.float32)) 144 | color_o3d_img = o3d.geometry.Image((pre_rgb_im).astype(np.float32)) 145 | pre_rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color_o3d_img, depth_o3d_img, 1, depth_trunc=10) 146 | 147 | depth_o3d_img = o3d.geometry.Image((cur_depth_im).astype(np.float32)) 148 | color_o3d_img = o3d.geometry.Image((cur_rgb_im).astype(np.float32)) 149 | curr_rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color_o3d_img, depth_o3d_img, 1, depth_trunc=10.) 
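    # depth_scale is passed as 1, i.e. the depth images are assumed to already be
    # in metres; depth_trunc discards measurements beyond 10 m before the RGB-D
    # odometry below estimates the current-to-previous transform.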
150 | 151 | pinhole_camera_intrinsic = o3d.camera.PinholeCameraIntrinsic( 152 | width=cur_depth_im.shape[1], 153 | height=cur_depth_im.shape[0], 154 | fx=fx, 155 | fy=fy, 156 | cx=cx, 157 | cy=cy, 158 | ) 159 | 160 | odo_init = current_transformation 161 | option = o3d.pipelines.odometry.OdometryOption() 162 | 163 | [success_hybrid_term, trans_hybrid_term,info] = o3d.pipelines.odometry.compute_rgbd_odometry( 164 | curr_rgbd_image, pre_rgbd_image, pinhole_camera_intrinsic, odo_init, 165 | o3d.pipelines.odometry.RGBDOdometryJacobianFromHybridTerm(), option) 166 | 167 | transform = np.array(trans_hybrid_term) 168 | return transform 169 | -------------------------------------------------------------------------------- /uni/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jarrome/Uni-Fusion/c49461d5aa53375380ca2eb7fbdf4ce8e14aeda0/uni/utils/__init__.py -------------------------------------------------------------------------------- /uni/utils/exp_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import numpy as np 4 | import sys 5 | import json 6 | import yaml 7 | import random 8 | import pickle 9 | from collections import defaultdict, OrderedDict 10 | 11 | 12 | def parse_config_json(json_path: Path, args: argparse.Namespace = None): 13 | """ 14 | Parse a json file and add key:value to args namespace. 15 | Json file format [ {attr}, {attr}, ... ] 16 | {attr} = { "_": COMMENT, VAR_NAME: VAR_VALUE } 17 | """ 18 | if args is None: 19 | args = argparse.Namespace() 20 | 21 | with json_path.open() as f: 22 | json_text = f.read() 23 | 24 | try: 25 | raw_configs = json.loads(json_text) 26 | except: 27 | # Do some fixing of the json text 28 | json_text = json_text.replace("\'", "\"") 29 | json_text = json_text.replace("None", "null") 30 | json_text = json_text.replace("False", "false") 31 | json_text = json_text.replace("True", "true") 32 | raw_configs = json.loads(json_text) 33 | 34 | if isinstance(raw_configs, dict): 35 | raw_configs = [raw_configs] 36 | configs = {} 37 | for raw_config in raw_configs: 38 | for rkey, rvalue in raw_config.items(): 39 | if rkey != "_": 40 | configs[rkey] = rvalue 41 | 42 | if configs is not None: 43 | for ckey, cvalue in configs.items(): 44 | args.__dict__[ckey] = cvalue 45 | return args 46 | 47 | 48 | def parse_config_yaml(yaml_path: Path, args: argparse.Namespace = None, override: bool = True): 49 | """ 50 | Parse a yaml file and add key:value to args namespace. 
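    A special `include_configs` key, if present, is interpreted as a path (relative
    to `yaml_path`) of a base yaml that is loaded first and then updated by the
    current file. When `override` is False, keys already present in `args` are kept
    instead of being overwritten.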
51 | """ 52 | if args is None: 53 | args = argparse.Namespace() 54 | with yaml_path.open() as f: 55 | configs = yaml.load(f, Loader=yaml.FullLoader) 56 | if configs is not None: 57 | if "include_configs" in configs.keys(): 58 | base_config = configs["include_configs"] 59 | del configs["include_configs"] 60 | base_config_path = yaml_path.parent / Path(base_config) 61 | with base_config_path.open() as f: 62 | base_config = yaml.load(f, Loader=yaml.FullLoader) 63 | base_config.update(configs) 64 | configs = base_config 65 | for ckey, cvalue in configs.items(): 66 | if override or ckey not in args.__dict__.keys(): 67 | args.__dict__[ckey] = cvalue 68 | return args 69 | 70 | 71 | def dict_to_args(data: dict): 72 | args = argparse.Namespace() 73 | for ckey, cvalue in data.items(): 74 | args.__dict__[ckey] = cvalue 75 | return args 76 | 77 | 78 | class ArgumentParserX(argparse.ArgumentParser): 79 | def __init__(self, base_config_path=None, add_hyper_arg=True, **kwargs): 80 | super().__init__(**kwargs) 81 | self.add_hyper_arg = add_hyper_arg 82 | self.base_config_path = base_config_path 83 | if self.add_hyper_arg: 84 | self.add_argument('hyper', type=str, help='Path to the yaml parameter') 85 | self.add_argument('--exec', type=str, help='Extract code to modify the args') 86 | 87 | def parse_args(self, args=None, namespace=None): 88 | # Parse arg for the first time to extract args defined in program. 89 | _args = self.parse_known_args(args, namespace)[0] 90 | # Add the types needed. 91 | file_args = argparse.Namespace() 92 | if self.base_config_path is not None: 93 | file_args = parse_config_yaml(Path(self.base_config_path), file_args) 94 | if self.add_hyper_arg: 95 | if _args.hyper.endswith("json"): 96 | file_args = parse_config_json(Path(_args.hyper), file_args) 97 | else: 98 | file_args = parse_config_yaml(Path(_args.hyper), file_args) 99 | for ckey, cvalue in file_args.__dict__.items(): 100 | try: 101 | self.add_argument('--' + ckey, type=type(cvalue), default=cvalue, required=False) 102 | except argparse.ArgumentError: 103 | continue 104 | # Parse args fully to extract all useful information 105 | _args = super().parse_args(args, namespace) 106 | # After that, execute exec part. 107 | exec_code = _args.exec 108 | if exec_code is not None: 109 | for exec_cmd in exec_code.split(";"): 110 | exec_cmd = "_args." 
+ exec_cmd.strip() 111 | exec(exec_cmd) 112 | return _args 113 | 114 | 115 | class AverageMeter: 116 | def __init__(self): 117 | self.loss_dict = OrderedDict() 118 | 119 | def export(self, f): 120 | if isinstance(f, str): 121 | f = open(f, 'wb') 122 | pickle.dump(self.loss_dict, f) 123 | 124 | def load(self, f): 125 | if isinstance(f, str): 126 | f = open(f, 'rb') 127 | self.loss_dict = pickle.load(f) 128 | return self 129 | 130 | def append_loss(self, losses): 131 | for loss_name, loss_val in losses.items(): 132 | if loss_val is None: 133 | continue 134 | loss_val = float(loss_val) 135 | if np.isnan(loss_val): 136 | continue 137 | if loss_name not in self.loss_dict.keys(): 138 | self.loss_dict.update({loss_name: [loss_val]}) 139 | else: 140 | self.loss_dict[loss_name].append(loss_val) 141 | 142 | def get_mean_loss_dict(self): 143 | loss_dict = {} 144 | for loss_name, loss_arr in self.loss_dict.items(): 145 | loss_dict[loss_name] = np.mean(loss_arr) 146 | return loss_dict 147 | 148 | def get_mean_loss(self): 149 | mean_loss_dict = self.get_mean_loss_dict() 150 | if len(mean_loss_dict) == 0: 151 | return 0.0 152 | else: 153 | return sum(mean_loss_dict.values()) / len(mean_loss_dict) 154 | 155 | def get_printable_mean(self): 156 | text = "" 157 | all_loss_sum = 0.0 158 | for loss_name, loss_mean in self.get_mean_loss_dict().items(): 159 | all_loss_sum += loss_mean 160 | text += "(%s:%.4f) " % (loss_name, loss_mean) 161 | text += " sum = %.4f" % all_loss_sum 162 | return text 163 | 164 | def get_newest_loss_dict(self, return_count=False): 165 | loss_dict = {} 166 | loss_count_dict = {} 167 | for loss_name, loss_arr in self.loss_dict.items(): 168 | if len(loss_arr) > 0: 169 | loss_dict[loss_name] = loss_arr[-1] 170 | loss_count_dict[loss_name] = len(loss_arr) 171 | if return_count: 172 | return loss_dict, loss_count_dict 173 | else: 174 | return loss_dict 175 | 176 | def get_printable_newest(self): 177 | nloss_val, nloss_count = self.get_newest_loss_dict(return_count=True) 178 | return ", ".join([f"{loss_name}[{nloss_count[loss_name] - 1}]: {nloss_val[loss_name]}" 179 | for loss_name in nloss_val.keys()]) 180 | 181 | def print_format_loss(self, color=None): 182 | if hasattr(sys.stdout, "terminal"): 183 | color_device = sys.stdout.terminal 184 | else: 185 | color_device = sys.stdout 186 | if color == "y": 187 | color_device.write('\033[93m') 188 | elif color == "g": 189 | color_device.write('\033[92m') 190 | elif color == "b": 191 | color_device.write('\033[94m') 192 | print(self.get_printable_mean(), flush=True) 193 | if color is not None: 194 | color_device.write('\033[0m') 195 | 196 | 197 | class RunningAverageMeter: 198 | def __init__(self, alpha=1.0): 199 | self.alpha = alpha 200 | self.loss_dict = OrderedDict() 201 | 202 | def append_loss(self, losses): 203 | for loss_name, loss_val in losses.items(): 204 | if loss_val is None: 205 | continue 206 | loss_val = float(loss_val) 207 | if np.isnan(loss_val): 208 | continue 209 | if loss_name not in self.loss_dict.keys(): 210 | self.loss_dict.update({loss_name: loss_val}) 211 | else: 212 | old_mean = self.loss_dict[loss_name] 213 | self.loss_dict[loss_name] = self.alpha * old_mean + (1 - self.alpha) * loss_val 214 | 215 | def get_loss_dict(self): 216 | return {k: v for k, v in self.loss_dict.items()} 217 | 218 | 219 | def init_seed(seed=0): 220 | random.seed(seed) 221 | np.random.seed(seed) 222 | # According to https://pytorch.org/docs/stable/notes/randomness.html, 223 | # As pytorch run-to-run reproducibility is not guaranteed, and perhaps 
will lead to performance degradation, 224 | # We do not use manual seed for training. 225 | # This would influence stochastic network layers but will not influence data generation and processing w/o pytorch. 226 | # torch.manual_seed(seed) 227 | # torch.backends.cudnn.deterministic = True 228 | # torch.backends.cudnn.benchmark = False 229 | 230 | 231 | class CombinedChunkLoss: 232 | def __init__(self): 233 | self.loss_dict = None 234 | self.loss_sum_dict = None 235 | self.clear() 236 | 237 | def add_loss(self, name, val): 238 | self.loss_dict[name] = val 239 | self.loss_sum_dict[name] += val.item() 240 | 241 | def update_loss_dict(self, loss_dict: dict): 242 | for l_name, l_val in loss_dict.items(): 243 | self.add_loss(l_name, l_val) 244 | 245 | def get_total_loss(self): 246 | # Note: to reduce memory, we need to clear the referenced graph. 247 | total_loss = sum(self.loss_dict.values()) 248 | self.loss_dict = {} 249 | return total_loss 250 | 251 | def get_accumulated_loss_dict(self): 252 | return self.loss_sum_dict 253 | 254 | def clear(self): 255 | self.loss_dict = {} 256 | self.loss_sum_dict = defaultdict(float) 257 | -------------------------------------------------------------------------------- /uni/utils/linalg_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def IRLS(y,X,maxiter=100, w=None, IRLS_p = 1, d=0.0001, tolerance=1e-3): 3 | n,p = X.shape 4 | delta = torch.ones((1,n),dtype=torch.float64).to(X) * d 5 | if w is None: 6 | w = torch.ones((1,n),dtype=torch.float64).to(X) 7 | #W = torch.diag(w[0,:]) # n,n 8 | #XTW = X.transpose(0,1).matmul(W) 9 | XTW = X.transpose(0,1)*w #p,n 10 | B = XTW.matmul(X).inverse().matmul(XTW.matmul(y)) 11 | for _ in range(maxiter): 12 | _B = B 13 | _w = torch.abs(y-X.matmul(B)).transpose(0,1) 14 | #w = 1./torch.max(delta,_w) 15 | w = torch.max(delta,_w) ** (IRLS_p-2) 16 | #W = torch.diag(w[0,:]) 17 | #XTW = X.transpose(0,1).matmul(W) 18 | XTW = X.transpose(0,1)*w 19 | B = XTW.matmul(X).inverse().matmul(XTW.matmul(y)) 20 | tol = torch.abs(B-_B).sum() 21 | if tol < tolerance: 22 | return B, w 23 | return B, w 24 | -------------------------------------------------------------------------------- /uni/utils/motion_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyquaternion import Quaternion 3 | 4 | 5 | def so3_vee(Phi): 6 | if Phi.ndim < 3: 7 | Phi = np.expand_dims(Phi, axis=0) 8 | 9 | if Phi.shape[1:3] != (3, 3): 10 | raise ValueError("Phi must have shape ({},{}) or (N,{},{})".format(3, 3, 3, 3)) 11 | 12 | phi = np.empty([Phi.shape[0], 3]) 13 | phi[:, 0] = Phi[:, 2, 1] 14 | phi[:, 1] = Phi[:, 0, 2] 15 | phi[:, 2] = Phi[:, 1, 0] 16 | return np.squeeze(phi) 17 | 18 | 19 | def so3_wedge(phi): 20 | phi = np.atleast_2d(phi) 21 | if phi.shape[1] != 3: 22 | raise ValueError( 23 | "phi must have shape ({},) or (N,{})".format(3, 3)) 24 | 25 | Phi = np.zeros([phi.shape[0], 3, 3]) 26 | Phi[:, 0, 1] = -phi[:, 2] 27 | Phi[:, 1, 0] = phi[:, 2] 28 | Phi[:, 0, 2] = phi[:, 1] 29 | Phi[:, 2, 0] = -phi[:, 1] 30 | Phi[:, 1, 2] = -phi[:, 0] 31 | Phi[:, 2, 1] = phi[:, 0] 32 | return np.squeeze(Phi) 33 | 34 | 35 | def so3_log(matrix): 36 | cos_angle = 0.5 * np.trace(matrix) - 0.5 37 | cos_angle = np.clip(cos_angle, -1., 1.) 
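    # The rotation angle follows from the trace identity cos(theta) = (tr(R) - 1) / 2;
    # the clip above guards against values drifting slightly outside [-1, 1] from
    # floating-point error. For non-negligible angles the log map is
    # theta / (2 * sin(theta)) * (R - R^T), turned into a 3-vector by so3_vee below.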
38 | angle = np.arccos(cos_angle) 39 | if np.isclose(angle, 0.): 40 | return so3_vee(matrix - np.identity(3)) 41 | else: 42 | return so3_vee((0.5 * angle / np.sin(angle)) * (matrix - matrix.T)) 43 | 44 | 45 | def so3_left_jacobian(phi): 46 | angle = np.linalg.norm(phi) 47 | 48 | if np.isclose(angle, 0.): 49 | return np.identity(3) + 0.5 * so3_wedge(phi) 50 | 51 | axis = phi / angle 52 | s = np.sin(angle) 53 | c = np.cos(angle) 54 | 55 | return (s / angle) * np.identity(3) + \ 56 | (1 - s / angle) * np.outer(axis, axis) + \ 57 | ((1 - c) / angle) * so3_wedge(axis) 58 | 59 | 60 | def se3_curlywedge(xi): 61 | xi = np.atleast_2d(xi) 62 | 63 | Psi = np.zeros([xi.shape[0], 6, 6]) 64 | Psi[:, 0:3, 0:3] = so3_wedge(xi[:, 3:6]) 65 | Psi[:, 0:3, 3:6] = so3_wedge(xi[:, 0:3]) 66 | Psi[:, 3:6, 3:6] = Psi[:, 0:3, 0:3] 67 | 68 | return np.squeeze(Psi) 69 | 70 | 71 | def se3_left_jacobian_Q_matrix(xi): 72 | rho = xi[0:3] # translation part 73 | phi = xi[3:6] # rotation part 74 | 75 | rx = so3_wedge(rho) 76 | px = so3_wedge(phi) 77 | 78 | ph = np.linalg.norm(phi) 79 | ph2 = ph * ph 80 | ph3 = ph2 * ph 81 | ph4 = ph3 * ph 82 | ph5 = ph4 * ph 83 | 84 | cph = np.cos(ph) 85 | sph = np.sin(ph) 86 | 87 | m1 = 0.5 88 | m2 = (ph - sph) / ph3 89 | m3 = (0.5 * ph2 + cph - 1.) / ph4 90 | m4 = (ph - 1.5 * sph + 0.5 * ph * cph) / ph5 91 | 92 | t1 = rx 93 | t2 = px.dot(rx) + rx.dot(px) + px.dot(rx).dot(px) 94 | t3 = px.dot(px).dot(rx) + rx.dot(px).dot(px) - 3. * px.dot(rx).dot(px) 95 | t4 = px.dot(rx).dot(px).dot(px) + px.dot(px).dot(rx).dot(px) 96 | 97 | return m1 * t1 + m2 * t2 + m3 * t3 + m4 * t4 98 | 99 | 100 | def se3_left_jacobian(xi): 101 | rho = xi[0:3] # translation part 102 | phi = xi[3:6] # rotation part 103 | 104 | # Near |phi|==0, use first order Taylor expansion 105 | if np.isclose(np.linalg.norm(phi), 0.): 106 | return np.identity(6) + 0.5 * se3_curlywedge(xi) 107 | 108 | so3_jac = so3_left_jacobian(phi) 109 | Q_mat = se3_left_jacobian_Q_matrix(xi) 110 | 111 | jac = np.zeros([6, 6]) 112 | jac[0:3, 0:3] = so3_jac 113 | jac[0:3, 3:6] = Q_mat 114 | jac[3:6, 3:6] = so3_jac 115 | 116 | return jac 117 | 118 | 119 | def se3_inv_left_jacobian(xi): 120 | rho = xi[0:3] # translation part 121 | phi = xi[3:6] # rotation part 122 | 123 | # Near |phi|==0, use first order Taylor expansion 124 | if np.isclose(np.linalg.norm(phi), 0.): 125 | return np.identity(6) - 0.5 * se3_curlywedge(xi) 126 | 127 | so3_inv_jac = so3_inv_left_jacobian(phi) 128 | Q_mat = se3_left_jacobian_Q_matrix(xi) 129 | 130 | jac = np.zeros([6, 6]) 131 | jac[0:3, 0:3] = so3_inv_jac 132 | jac[0:3, 3:6] = -so3_inv_jac.dot(Q_mat).dot(so3_inv_jac) 133 | jac[3:6, 3:6] = so3_inv_jac 134 | 135 | return jac 136 | 137 | 138 | def so3_inv_left_jacobian(phi): 139 | angle = np.linalg.norm(phi) 140 | 141 | if np.isclose(angle, 0.): 142 | return np.identity(3) - 0.5 * so3_wedge(phi) 143 | 144 | axis = phi / angle 145 | half_angle = 0.5 * angle 146 | cot_half_angle = 1. 
/ np.tan(half_angle) 147 | 148 | return half_angle * cot_half_angle * np.identity(3) + \ 149 | (1 - half_angle * cot_half_angle) * np.outer(axis, axis) - \ 150 | half_angle * so3_wedge(axis) 151 | 152 | 153 | def project_orthogonal(rot): 154 | u, s, vh = np.linalg.svd(rot, full_matrices=True, compute_uv=True) 155 | rot = u @ vh 156 | if np.linalg.det(rot) < 0: 157 | u[:, 2] = -u[:, 2] 158 | rot = u @ vh 159 | return rot 160 | 161 | 162 | class Isometry: 163 | GL_POST_MULT = Quaternion(degrees=180.0, axis=[1.0, 0.0, 0.0]) 164 | 165 | def __init__(self, q=None, t=None): 166 | if q is None: 167 | q = Quaternion() 168 | if t is None: 169 | t = np.zeros(3) 170 | if not isinstance(t, np.ndarray): 171 | t = np.asarray(t) 172 | assert t.shape[0] == 3 and t.ndim == 1 173 | self.q = q 174 | self.t = t 175 | 176 | def __repr__(self): 177 | return f"Isometry: t = {self.t}, q = {self.q}" 178 | 179 | @property 180 | def rotation(self): 181 | return Isometry(q=self.q) 182 | 183 | @property 184 | def matrix(self): 185 | mat = self.q.transformation_matrix 186 | mat[0:3, 3] = self.t 187 | return mat 188 | 189 | @staticmethod 190 | def from_matrix(mat, t_component=None, ortho=False): 191 | assert isinstance(mat, np.ndarray) 192 | if t_component is None: 193 | assert mat.shape == (4, 4) 194 | if ortho: 195 | mat[:3, :3] = project_orthogonal(mat[:3, :3]) 196 | return Isometry(q=Quaternion(matrix=mat), t=mat[:3, 3]) 197 | else: 198 | assert mat.shape == (3, 3) 199 | assert t_component.shape == (3,) 200 | if ortho: 201 | mat = project_orthogonal(mat) 202 | return Isometry(q=Quaternion(matrix=mat), t=t_component) 203 | 204 | @staticmethod 205 | def from_twist(xi: np.ndarray): 206 | rho = xi[:3] 207 | phi = xi[3:6] 208 | iso = Isometry.from_so3_exp(phi) 209 | iso.t = so3_left_jacobian(phi) @ rho 210 | return iso 211 | 212 | @staticmethod 213 | def from_so3_exp(phi: np.ndarray): 214 | angle = np.linalg.norm(phi) 215 | 216 | # Near phi==0, use first order Taylor expansion 217 | if np.isclose(angle, 0.): 218 | return Isometry(q=Quaternion(matrix=np.identity(3) + so3_wedge(phi))) 219 | 220 | axis = phi / angle 221 | s = np.sin(angle) 222 | c = np.cos(angle) 223 | 224 | rot_mat = (c * np.identity(3) + 225 | (1 - c) * np.outer(axis, axis) + 226 | s * so3_wedge(axis)) 227 | return Isometry(q=Quaternion(matrix=rot_mat, rtol=1e-05, atol=1e-05)) 228 | 229 | @property 230 | def continuous_repr(self): 231 | rot = self.q.rotation_matrix[:, 0:2].T.flatten() # (6,) 232 | return np.concatenate([rot, self.t]) # (9,) 233 | 234 | @staticmethod 235 | def from_continuous_repr(rep, gs=True): 236 | if isinstance(rep, list): 237 | rep = np.asarray(rep) 238 | assert isinstance(rep, np.ndarray) 239 | assert rep.shape == (9,) 240 | # For rotation, use Gram-Schmidt orthogonalization 241 | col1 = rep[0:3] 242 | col2 = rep[3:6] 243 | if gs: 244 | col1 /= np.linalg.norm(col1) 245 | col2 = col2 - np.dot(col1, col2) * col1 246 | col2 /= np.linalg.norm(col2) 247 | col3 = np.cross(col1, col2) 248 | return Isometry(q=Quaternion(matrix=np.column_stack([col1, col2, col3])), t=rep[6:9]) 249 | 250 | @property 251 | def full_repr(self): 252 | rot = self.q.rotation_matrix.T.flatten() 253 | return np.concatenate([rot, self.t]) 254 | 255 | @staticmethod 256 | def from_full_repr(rep, ortho=False): 257 | assert isinstance(rep, np.ndarray) 258 | assert rep.shape == (12,) 259 | rot = rep[0:9].reshape(3, 3).T 260 | if ortho: 261 | rot = project_orthogonal(rot) 262 | return Isometry(q=Quaternion(matrix=rot), t=rep[9:12]) 263 | 264 | def torch_matrices(self, 
device): 265 | import torch 266 | return torch.from_numpy(self.q.rotation_matrix).to(device).float(), \ 267 | torch.from_numpy(self.t).to(device).float() 268 | 269 | @staticmethod 270 | def random(): 271 | return Isometry(q=Quaternion.random(), t=np.random.random((3,))) 272 | 273 | def inv(self): 274 | qinv = self.q.inverse 275 | return Isometry(q=qinv, t=-(qinv.rotate(self.t))) 276 | 277 | def dot(self, right): 278 | return Isometry(q=(self.q * right.q), t=(self.q.rotate(right.t) + self.t)) 279 | 280 | def to_gl_camera(self): 281 | return Isometry(q=(self.q * self.GL_POST_MULT), t=self.t) 282 | 283 | @staticmethod 284 | def look_at(source: np.ndarray, target: np.ndarray, up: np.ndarray = None): 285 | z_dir = target - source 286 | z_dir /= np.linalg.norm(z_dir) 287 | if up is None: 288 | up = np.asarray([0.0, 1.0, 0.0]) 289 | if np.linalg.norm(np.cross(z_dir, up)) < 1e-6: 290 | up = np.asarray([1.0, 0.0, 0.0]) 291 | else: 292 | up /= np.linalg.norm(up) 293 | x_dir = np.cross(z_dir, up) 294 | x_dir /= np.linalg.norm(x_dir) 295 | y_dir = np.cross(z_dir, x_dir) 296 | R = np.column_stack([x_dir, y_dir, z_dir]) 297 | return Isometry(q=Quaternion(matrix=R), t=source) 298 | 299 | def adjoint_matrix(self): 300 | R = self.q.rotation_matrix 301 | twR = so3_wedge(self.t) @ R 302 | adjoint = np.zeros((6, 6)) 303 | adjoint[0:3, 0:3] = R 304 | adjoint[3:6, 3:6] = R 305 | adjoint[0:3, 3:6] = twR 306 | return adjoint 307 | 308 | def log(self): 309 | phi = so3_log(self.q.rotation_matrix) 310 | rho = so3_inv_left_jacobian(phi) @ self.t 311 | return np.hstack([rho, phi]) 312 | 313 | def tangent(self, prev_iso, next_iso): 314 | t = 0.5 * (next_iso.t - prev_iso.t) 315 | l1 = Quaternion.log((self.q.inverse * prev_iso.q).normalised) 316 | l2 = Quaternion.log((self.q.inverse * next_iso.q).normalised) 317 | e = Quaternion() 318 | e.q = -0.25 * (l1.q + l2.q) 319 | e = self.q * Quaternion.exp(e) 320 | return Isometry(t=t, q=e) 321 | 322 | def __matmul__(self, other): 323 | # "@" operator: other can be (N,3) or (3,). 324 | if hasattr(other, "device"): # Torch tensor 325 | assert other.ndim == 2 and other.size(1) == 3 # (N,3) 326 | th_R, th_t = self.torch_matrices(other.device) 327 | return other @ th_R.t() + th_t.unsqueeze(0) 328 | if isinstance(other, Isometry): 329 | return self.dot(other) 330 | if type(other) != np.ndarray or other.ndim == 1: 331 | return self.q.rotate(other) + self.t 332 | else: 333 | return other @ self.q.rotation_matrix.T + self.t[np.newaxis, :] 334 | 335 | @staticmethod 336 | def interpolate(source, target, alpha): 337 | iquat = Quaternion.slerp(source.q, target.q, alpha) 338 | it = source.t * (1 - alpha) + target.t * alpha 339 | return Isometry(q=iquat, t=it) 340 | -------------------------------------------------------------------------------- /uni/utils/ray_cast.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import pdb 4 | 5 | class RayCaster: 6 | def __init__(self, mesh, H, W, calib): 7 | ''' 8 | calib: 3x3 9 | ''' 10 | # 0. init the scene 11 | self.scene = o3d.t.geometry.RaycastingScene() 12 | self.H = H 13 | self.W = W 14 | #self.calib = calib 15 | self.mesh = mesh 16 | 17 | self.mesh_t = o3d.t.geometry.TriangleMesh.from_legacy(mesh) 18 | obj_id = self.scene.add_triangles(self.mesh_t) 19 | 20 | 21 | # 1. 
generate image's 3d point 22 | x = np.arange(W) 23 | y = np.arange(H) 24 | #xv, yv = np.meshgrid(x, y, indexing='ij') # WxH 25 | yv, xv = np.meshgrid(y, x, indexing='ij') # HxW 26 | 27 | 28 | img_xyz = np.ones((H*W,3)) 29 | img_xyz[:,0] = xv.reshape(-1) 30 | img_xyz[:,1] = yv.reshape(-1) 31 | self.img_xyz = np.linalg.inv(calib).dot(img_xyz.T).T # HWx3 32 | 33 | 34 | def ray_cast(self, pose): 35 | ''' 36 | pose: 4x4 37 | ''' 38 | # 2. direction 39 | source_pts = pose[:3,(3,)].T # 1,3 40 | direction = (pose[:3,:3]@(self.img_xyz.T)).T # HW, 3 41 | direction = direction / (np.linalg.norm(direction,axis=1, keepdims=True))#+1e-8) 42 | 43 | rays = o3d.core.Tensor(np.concatenate([np.repeat(source_pts,direction.shape[0],0),direction],axis=1), dtype=o3d.core.Dtype.Float32) 44 | 45 | # debug use 46 | ''' 47 | ray_pcd = o3d.geometry.PointCloud() 48 | ray_pcd.points = o3d.utility.Vector3dVector(source_pts + direction) 49 | 50 | o3d.visualization.draw_geometries([self.mesh, ray_pcd]) 51 | ''' 52 | 53 | # 3. the hit distance (depth) is in ans['t_hit'] 54 | ans = self.scene.cast_rays(rays) 55 | 56 | return ans, direction 57 | # the hit point on mesh 58 | back_proj_pts = source_pts + direction * ans['t_hit'].numpy()[:,np.newaxis] # HW,3 59 | if not return_depth: # return point cloud 60 | return back_proj_pts#.reshape((self.H,self.W,3)) 61 | else: 62 | # also get the color 63 | return back_proj_pts, ans['t_hit'].numpy() 64 | ''' 65 | pdb.set_trace() 66 | colors = np.asarray(self.mesh.vertex_colors) 67 | tri_vs = np.asarray(self.mesh.triangles) 68 | 69 | tri_id = ans['primitive_ids'].numpy() 70 | invalid_mask = tri_id == self.scene.INVALID_ID 71 | tri_id[invalid_mask] = 0 72 | 73 | wanted_color = colors[tri_vs[tri_id,0],:] 74 | wanted_color[invalid_mask,:] = 0 75 | 76 | return back_proj_pts, wanted_color 77 | ''' 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /uni/utils/torch_scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | 4 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 5 | if dim < 0: 6 | dim = other.dim() + dim 7 | if src.dim() == 1: 8 | for _ in range(0, dim): 9 | src = src.unsqueeze(0) 10 | for _ in range(src.dim(), other.dim()): 11 | src = src.unsqueeze(-1) 12 | src = src.expand(other.size()) 13 | return src 14 | 15 | def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 16 | out: Optional[torch.Tensor] = None, 17 | dim_size: Optional[int] = None) -> torch.Tensor: 18 | index = broadcast(index, src, dim) 19 | if out is None: 20 | size = list(src.size()) 21 | if dim_size is not None: 22 | size[dim] = dim_size 23 | elif index.numel() == 0: 24 | size[dim] = 0 25 | else: 26 | size[dim] = int(index.max()) + 1 27 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 28 | return out.scatter_add_(dim, index, src) 29 | else: 30 | return out.scatter_add_(dim, index, src) 31 | 32 | 33 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 34 | out: Optional[torch.Tensor] = None, 35 | dim_size: Optional[int] = None) -> torch.Tensor: 36 | return scatter_sum(src, index, dim, out, dim_size) 37 | 38 | 39 | def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 40 | out: Optional[torch.Tensor] = None, 41 | dim_size: Optional[int] = None) -> torch.Tensor: 42 | return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) 43 | 
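# Note on scatter_mean below: the mean is built from two scatter-sums, one over `src`
# and one over a matching tensor of ones (the per-bucket counts), followed by an
# element-wise division. For example, src=[1., 2., 3.] with index=[0, 0, 1] gives
# sums [3., 3.], counts [2., 1.] and means [1.5, 3.].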
44 | 45 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 46 | out: Optional[torch.Tensor] = None, 47 | dim_size: Optional[int] = None) -> torch.Tensor: 48 | out = scatter_sum(src, index, dim, out, dim_size) 49 | dim_size = out.size(dim) 50 | 51 | index_dim = dim 52 | if index_dim < 0: 53 | index_dim = index_dim + src.dim() 54 | if index.dim() <= index_dim: 55 | index_dim = index.dim() - 1 56 | 57 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 58 | count = scatter_sum(ones, index, index_dim, None, dim_size) 59 | count[count < 1] = 1 60 | count = broadcast(count, out, dim) 61 | if out.is_floating_point(): 62 | out.true_divide_(count) 63 | else: 64 | out.div_(count, rounding_mode='floor') 65 | return out 66 | 67 | 68 | def scatter_min( 69 | src: torch.Tensor, index: torch.Tensor, dim: int = -1, 70 | out: Optional[torch.Tensor] = None, 71 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 72 | return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) 73 | 74 | 75 | def scatter_max( 76 | src: torch.Tensor, index: torch.Tensor, dim: int = -1, 77 | out: Optional[torch.Tensor] = None, 78 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 79 | return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) 80 | 81 | 82 | def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 83 | out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, 84 | reduce: str = "sum") -> torch.Tensor: 85 | r""" 86 | | 87 | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ 88 | master/docs/source/_figures/add.svg?sanitize=true 89 | :align: center 90 | :width: 400px 91 | | 92 | Reduces all values from the :attr:`src` tensor into :attr:`out` at the 93 | indices specified in the :attr:`index` tensor along a given axis 94 | :attr:`dim`. 95 | For each value in :attr:`src`, its output index is specified by its index 96 | in :attr:`src` for dimensions outside of :attr:`dim` and by the 97 | corresponding value in :attr:`index` for dimension :attr:`dim`. 98 | The applied reduction is defined via the :attr:`reduce` argument. 99 | Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional 100 | tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` 101 | and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional 102 | tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. 103 | Moreover, the values of :attr:`index` must be between :math:`0` and 104 | :math:`y - 1`, although no specific ordering of indices is required. 105 | The :attr:`index` tensor supports broadcasting in case its dimensions do 106 | not match with :attr:`src`. 107 | For one-dimensional tensors with :obj:`reduce="sum"`, the operation 108 | computes 109 | .. math:: 110 | \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j 111 | where :math:`\sum_j` is over :math:`j` such that 112 | :math:`\mathrm{index}_j = i`. 113 | .. note:: 114 | This operation is implemented via atomic operations on the GPU and is 115 | therefore **non-deterministic** since the order of parallel operations 116 | to the same value is undetermined. 117 | For floating-point variables, this results in a source of variance in 118 | the result. 119 | :param src: The source tensor. 120 | :param index: The indices of elements to scatter. 121 | :param dim: The axis along which to index. (default: :obj:`-1`) 122 | :param out: The destination tensor. 
123 | :param dim_size: If :attr:`out` is not given, automatically create output 124 | with size :attr:`dim_size` at dimension :attr:`dim`. 125 | If :attr:`dim_size` is not given, a minimal sized output tensor 126 | according to :obj:`index.max() + 1` is returned. 127 | :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`, 128 | :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) 129 | :rtype: :class:`Tensor` 130 | .. code-block:: python 131 | from torch_scatter import scatter 132 | src = torch.randn(10, 6, 64) 133 | index = torch.tensor([0, 1, 0, 1, 2, 1]) 134 | # Broadcasting in the first and last dim. 135 | out = scatter(src, index, dim=1, reduce="sum") 136 | print(out.size()) 137 | .. code-block:: 138 | torch.Size([10, 3, 64]) 139 | """ 140 | if reduce == 'sum' or reduce == 'add': 141 | return scatter_sum(src, index, dim, out, dim_size) 142 | if reduce == 'mul': 143 | return scatter_mul(src, index, dim, out, dim_size) 144 | elif reduce == 'mean': 145 | return scatter_mean(src, index, dim, out, dim_size) 146 | elif reduce == 'min': 147 | return scatter_min(src, index, dim, out, dim_size)[0] 148 | elif reduce == 'max': 149 | return scatter_max(src, index, dim, out, dim_size)[0] 150 | else: 151 | raise ValueError 152 | -------------------------------------------------------------------------------- /uni/utils/vis_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import matplotlib.cm 4 | from uni.utils.motion_util import Isometry 5 | 6 | 7 | def pointcloud(pc, color: np.ndarray = None, normal: np.ndarray = None): 8 | if isinstance(pc, o3d.geometry.PointCloud): 9 | if pc.has_normals() and normal is None: 10 | normal = np.asarray(pc.normals) 11 | if pc.has_colors() and color is None: 12 | color = np.asarray(pc.colors) 13 | pc = np.asarray(pc.points) 14 | 15 | assert pc.shape[1] == 3 and len(pc.shape) == 2, f"Point cloud is of size {pc.shape} and cannot be displayed!" 
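    # At this point `pc` is a plain (N, 3) array; if an open3d PointCloud was passed in,
    # its colors/normals were pulled out above and are re-attached to the new cloud below.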
16 | point_cloud = o3d.geometry.PointCloud() 17 | point_cloud.points = o3d.utility.Vector3dVector(pc) 18 | if color is not None: 19 | assert color.shape[0] == pc.shape[0], f"Point and color must have same size {color.shape[0]}, {pc.shape[0]}" 20 | point_cloud.colors = o3d.utility.Vector3dVector(color) 21 | if normal is not None: 22 | point_cloud.normals = o3d.utility.Vector3dVector(normal) 23 | 24 | return point_cloud 25 | 26 | 27 | def frame(transform: Isometry = Isometry(), size=1.0): 28 | frame_obj = o3d.geometry.TriangleMesh.create_coordinate_frame(size=size) 29 | frame_obj.transform(transform.matrix) 30 | return frame_obj 31 | 32 | 33 | def merged_linesets(lineset_list: list): 34 | merged_points = [] 35 | merged_inds = [] 36 | merged_colors = [] 37 | point_acc_ind = 0 38 | for ls in lineset_list: 39 | merged_points.append(np.asarray(ls.points)) 40 | merged_inds.append(np.asarray(ls.lines) + point_acc_ind) 41 | if ls.has_colors(): 42 | merged_colors.append(np.asarray(ls.colors)) 43 | else: 44 | merged_colors.append(np.zeros((len(ls.lines), 3))) 45 | point_acc_ind += len(ls.points) 46 | 47 | geom = o3d.geometry.LineSet( 48 | points=o3d.utility.Vector3dVector(np.vstack(merged_points)), 49 | lines=o3d.utility.Vector2iVector(np.vstack(merged_inds)) 50 | ) 51 | geom.colors = o3d.utility.Vector3dVector(np.vstack(merged_colors)) 52 | return geom 53 | 54 | 55 | def trajectory(traj1: list, traj2: list = None, ucid: int = -1): 56 | if len(traj1) > 0 and isinstance(traj1[0], Isometry): 57 | traj1 = [t.t for t in traj1] 58 | if traj2 and isinstance(traj2[0], Isometry): 59 | traj2 = [t.t for t in traj2] 60 | 61 | traj1_lineset = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(np.asarray(traj1)), 62 | lines=o3d.utility.Vector2iVector(np.vstack((np.arange(0, len(traj1) - 1), 63 | np.arange(1, len(traj1)))).T)) 64 | if ucid != -1: 65 | color_map = np.asarray(matplotlib.cm.get_cmap('tab10').colors) 66 | traj1_lineset.paint_uniform_color(color_map[ucid % 10]) 67 | 68 | if traj2 is not None: 69 | assert len(traj1) == len(traj2) 70 | traj2_lineset = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(np.asarray(traj2)), 71 | lines=o3d.utility.Vector2iVector(np.vstack((np.arange(0, len(traj2) - 1), 72 | np.arange(1, len(traj2)))).T)) 73 | traj_diff = o3d.geometry.LineSet( 74 | points=o3d.utility.Vector3dVector(np.vstack((np.asarray(traj1), np.asarray(traj2)))), 75 | lines=o3d.utility.Vector2iVector(np.arange(2 * len(traj1)).reshape((2, len(traj1))).T)) 76 | traj_diff.colors = o3d.utility.Vector3dVector(np.array([[1.0, 0.0, 0.0]]).repeat(len(traj_diff.lines), axis=0)) 77 | 78 | traj1_lineset = merged_linesets([traj1_lineset, traj2_lineset, traj_diff]) 79 | return traj1_lineset 80 | 81 | 82 | def camera(transform: Isometry = Isometry(), wh_ratio: float = 4.0 / 3.0, scale: float = 1.0, fovx: float = 90.0, 83 | color_id: int = -1): 84 | pw = np.tan(np.deg2rad(fovx / 2.)) * scale 85 | ph = pw / wh_ratio 86 | all_points = np.asarray([ 87 | [0.0, 0.0, 0.0], 88 | [pw, ph, scale], 89 | [pw, -ph, scale], 90 | [-pw, ph, scale], 91 | [-pw, -ph, scale], 92 | ]) 93 | line_indices = np.asarray([ 94 | [0, 1], [0, 2], [0, 3], [0, 4], 95 | [1, 2], [1, 3], [3, 4], [2, 4] 96 | ]) 97 | geom = o3d.geometry.LineSet( 98 | points=o3d.utility.Vector3dVector(all_points), 99 | lines=o3d.utility.Vector2iVector(line_indices)) 100 | 101 | if color_id == -1: 102 | my_color = np.zeros((3,)) 103 | else: 104 | my_color = np.asarray(matplotlib.cm.get_cmap('tab10').colors)[color_id, :3] 105 | geom.colors = 
o3d.utility.Vector3dVector(np.repeat(np.expand_dims(my_color, 0), line_indices.shape[0], 0)) 106 | 107 | geom.transform(transform.matrix) 108 | return geom 109 | 110 | 111 | def wireframe_bbox(extent_min=None, extent_max=None, color_id=-1): 112 | if extent_min is None: 113 | extent_min = [0.0, 0.0, 0.0] 114 | if extent_max is None: 115 | extent_max = [1.0, 1.0, 1.0] 116 | 117 | if color_id == -1: 118 | my_color = np.zeros((3,)) 119 | else: 120 | my_color = np.asarray(matplotlib.cm.get_cmap('tab10').colors)[color_id, :3] 121 | 122 | all_points = np.asarray([ 123 | [extent_min[0], extent_min[1], extent_min[2]], 124 | [extent_min[0], extent_min[1], extent_max[2]], 125 | [extent_min[0], extent_max[1], extent_min[2]], 126 | [extent_min[0], extent_max[1], extent_max[2]], 127 | [extent_max[0], extent_min[1], extent_min[2]], 128 | [extent_max[0], extent_min[1], extent_max[2]], 129 | [extent_max[0], extent_max[1], extent_min[2]], 130 | [extent_max[0], extent_max[1], extent_max[2]], 131 | ]) 132 | line_indices = np.asarray([ 133 | [0, 1], [2, 3], [4, 5], [6, 7], 134 | [0, 4], [1, 5], [2, 6], [3, 7], 135 | [0, 2], [4, 6], [1, 3], [5, 7] 136 | ]) 137 | geom = o3d.geometry.LineSet( 138 | points=o3d.utility.Vector3dVector(all_points), 139 | lines=o3d.utility.Vector2iVector(line_indices)) 140 | geom.colors = o3d.utility.Vector3dVector(np.repeat(np.expand_dims(my_color, 0), line_indices.shape[0], 0)) 141 | 142 | return geom 143 | -------------------------------------------------------------------------------- /vis_LIMs.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pickle 3 | import importlib 4 | import open3d as o3d 5 | import cv2 6 | import argparse 7 | import logging 8 | import time 9 | import torch 10 | import torch.nn.functional as F 11 | import copy 12 | from plyfile import PlyData 13 | import matplotlib as mpl 14 | p = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 15 | sys.path.append(p) 16 | 17 | 18 | from uni.utils import exp_util, vis_util 19 | 20 | import asyncio 21 | 22 | from uni.encoder import utility 23 | from uni.encoder.uni_encoder_v2 import get_uni_model 24 | 25 | import numpy as np 26 | from uni.mapper.surface_map import SurfaceMap 27 | from uni.mapper.context_map_v2 import ContextMap # 8 points 28 | from uni.mapper.latent_map import LatentMap 29 | 30 | 31 | 32 | import pdb 33 | 34 | import pathlib 35 | 36 | vis_param = argparse.Namespace() 37 | vis_param.n_left_steps = 0 38 | vis_param.args = None 39 | vis_param.mesh_updated = True 40 | # color palette for nyu40 labels 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | 46 | parser = exp_util.ArgumentParserX() 47 | args = parser.parse_args() 48 | logging.basicConfig(level=logging.INFO) 49 | o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error) 50 | 51 | # Load in network. 
(args.model is the network specification) 52 | #model, args_model = utility.load_model(args.training_hypers, args.using_epoch) 53 | args.has_ir = hasattr(args, 'ir_mapping') 54 | args.has_saliency = hasattr(args, 'saliency_mapping') 55 | args.has_style = hasattr(args, 'style_mapping') 56 | args.has_latent = hasattr(args, 'latent_mapping') 57 | 58 | 59 | 60 | 61 | args.surface_mapping = exp_util.dict_to_args(args.surface_mapping) 62 | args.context_mapping = exp_util.dict_to_args(args.context_mapping) 63 | if args.has_ir: 64 | args.ir_mapping = exp_util.dict_to_args(args.ir_mapping) 65 | if args.has_saliency: 66 | args.saliency_mapping = exp_util.dict_to_args(args.saliency_mapping) 67 | if args.has_style: 68 | args.style_mapping = exp_util.dict_to_args(args.style_mapping) 69 | 70 | 71 | 72 | 73 | 74 | 75 | #if hasattr(args, "style_mapping"): 76 | import uni.tracker.tracker_custom as tracker 77 | #args.custom_mapping = exp_util.dict_to_args(args.custom_mapping) 78 | #else: 79 | #import uni.tracker.tracker as tracker 80 | 81 | 82 | args.tracking = exp_util.dict_to_args(args.tracking) 83 | 84 | # Load in sequence. 85 | seq_package, seq_class = args.sequence_type.split(".") 86 | sequence_module = importlib.import_module("uni.dataset." + seq_package) 87 | sequence_module = getattr(sequence_module, seq_class) 88 | vis_param.sequence = sequence_module(**args.sequence_kwargs) 89 | 90 | 91 | 92 | 93 | 94 | if torch.cuda.device_count() > 1: 95 | main_device, aux_device = torch.device("cuda", index=0), torch.device("cuda", index=1) 96 | elif torch.cuda.device_count() == 1: 97 | main_device, aux_device = torch.device("cuda", index=0), None 98 | else: 99 | assert False, "You must have one GPU." 100 | 101 | 102 | # Mapping model 103 | uni_model = get_uni_model(main_device) 104 | vis_param.context_map = ContextMap(uni_model, 105 | args.context_mapping, uni_model.color_code_length, device=main_device, 106 | enable_async=args.run_async) 107 | vis_param.surface_map = SurfaceMap(uni_model, vis_param.context_map, 108 | args.surface_mapping, uni_model.surface_code_length, device=main_device, 109 | enable_async=args.run_async) 110 | if args.has_ir: 111 | vis_param.ir_map = ContextMap(uni_model, 112 | args.ir_mapping, uni_model.ir_code_length, device=main_device, 113 | enable_async=args.run_async) 114 | if args.has_saliency: 115 | vis_param.saliency_map = ContextMap(uni_model, 116 | args.saliency_mapping, uni_model.saliency_code_length, device=main_device, 117 | enable_async=args.run_async) 118 | if args.has_style: 119 | vis_param.style_map = ContextMap(uni_model, 120 | args.style_mapping, uni_model.style_code_length, device=main_device, 121 | enable_async=args.run_async) 122 | 123 | 124 | vis_param.tracker = tracker.SDFTracker(vis_param.surface_map, args.tracking) 125 | vis_param.args = args 126 | 127 | 128 | 129 | # load 130 | maps = dict() 131 | vis_param.surface_map.load(args.outdir+'/surface.lim') 132 | vis_param.context_map.load(args.outdir+'/color.lim') 133 | 134 | if args.has_ir: 135 | vis_param.ir_map.load(args.outdir+'/ir.lim') 136 | maps['ir'] = vis_param.ir_map 137 | if args.has_saliency: 138 | vis_param.saliency_map.load(args.outdir+'/saliency.lim') 139 | maps['saliency'] = vis_param.saliency_map 140 | if args.has_style: 141 | vis_param.style_map.load(args.outdir+'/style.lim') 142 | maps['style'] = vis_param.style_map 143 | 144 | 145 | #vis_param.latent_map.load(args.outdir+'/surface.lim') 146 | 147 | color_mesh = vis_param.surface_map.extract_mesh(vis_param.args.resolution, int(4e7), max_std=0.15, 
148 | extract_async=False, interpolate=True, no_cache=True) 149 | color_mesh_transformed = copy.deepcopy(color_mesh).transform(np.linalg.inv(vis_param.sequence.T_gt2uni)) 150 | o3d.io.write_triangle_mesh(args.outdir+'/color_recons.ply', color_mesh_transformed) 151 | 152 | 153 | viridis_palette = mpl.colormaps['plasma'].resampled(8) 154 | cividis_palette = mpl.colormaps['cividis'].resampled(8) 155 | 156 | X_test = torch.from_numpy(np.asarray(color_mesh.vertices)).float().to(main_device) 157 | 158 | if True: #hasattr(args, "style_mapping"): 159 | meshes, LIMs = [], [] 160 | if args.has_ir: 161 | ir_mesh = o3d.geometry.TriangleMesh(color_mesh) 162 | saliency_mesh = o3d.geometry.TriangleMesh(color_mesh) 163 | style_mesh = o3d.geometry.TriangleMesh(color_mesh) 164 | 165 | if args.has_ir: 166 | meshes.append(ir_mesh)#[ir_mesh, saliency_mesh, style_mesh] 167 | LIMs.append(vis_param.ir_map)# = [vis_param.ir_map, vis_param.saliency_map, vis_param.style_map] 168 | if args.has_saliency: 169 | meshes.append(saliency_mesh)#, style_mesh] 170 | LIMs.append(vis_param.saliency_map)#, vis_param.style_map] 171 | if args.has_style: 172 | meshes.append(style_mesh) 173 | LIMs.append(vis_param.style_map) 174 | for name, mesh, LIM in zip(maps.keys(), meshes, LIMs): 175 | v, pinds = LIM.infer(X_test) 176 | if v.dim() == 1 or name == 'saliency': # ir 177 | v_np = v.cpu().numpy() 178 | v_np[v_np<0] = 0 179 | if name == 'ir': 180 | v_np /= v_np.max() 181 | v_np = v_np[:,0] if name == 'saliency' else v_np 182 | if name == 'saliency': 183 | # using platte 184 | v = viridis_palette(v_np)[:,:3] 185 | elif name == 'ir': 186 | v_eq = cv2.equalizeHist((v_np*255).astype(np.uint8)) / 255 187 | v = np.repeat(v_eq, 3, 1) 188 | 189 | else: 190 | v = v.detach().cpu().numpy() 191 | 192 | if name == 'style': 193 | v = v[:,::-1] 194 | 195 | mesh.vertex_colors = o3d.utility.Vector3dVector(v) 196 | mesh.remove_vertices_by_index(np.where(pinds.cpu().numpy()==-1)[0]) 197 | 198 | # transform from LIM coordinate to original coordinate 199 | mesh_transformed = mesh.transform(np.linalg.inv(vis_param.sequence.T_gt2uni)) 200 | o3d.io.write_triangle_mesh(args.outdir+'/%s_recons.ply'%name, mesh_transformed) 201 | 202 | #o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True) 203 | if args.has_latent: 204 | args.latent_mapping = exp_util.dict_to_args(args.latent_mapping) 205 | from external.openseg import openseg_api 206 | print('Loading openseg model...') 207 | f_im, f_tx, f_classify, lang_latent_length = openseg_api.get_api() 208 | print('Loaded!') 209 | del f_im 210 | torch.cuda.empty_cache() 211 | f_im = None 212 | 213 | lang_latent_length = (uni_model.color_code_length[0], lang_latent_length) # 20,512 214 | vis_param.latent_map = LatentMap(uni_model, 215 | args.latent_mapping, lang_latent_length, device=main_device, 216 | enable_async=args.run_async) 217 | vis_param.latent_map.load(args.outdir+'/latent.lim') 218 | 219 | 220 | 221 | 222 | 223 | text_options = ['sofa','desk','sit','work','wood','eat'] 224 | for text in text_options: 225 | test_t = 'other,%s'%text 226 | 227 | 228 | F_tx = f_tx(text) 229 | F_tx = torch.from_numpy(F_tx).cuda(0) 230 | preds = [] 231 | step = int(1e4) 232 | for i in range(0,X_test.shape[0],step): 233 | pred = vis_param.latent_map.infer(X_test[i:min(i+step,X_test.shape[0]),:], F_tx, f_classify).detach().cpu().numpy() 234 | preds.append(pred) 235 | pred = np.concatenate(preds, axis=0).reshape(-1) 236 | # prob to color 237 | v = viridis_palette(pred)[:,:3] 238 | 239 | latent_mesh = 
o3d.geometry.TriangleMesh(color_mesh) 240 | latent_mesh.vertex_colors = o3d.utility.Vector3dVector(v.astype(np.float64)) 241 | #latent_mesh.remove_vertices_by_index(np.where(pinds.cpu().numpy()==-1)[0]) 242 | 243 | # transform from LIM coordinate to original coordinate 244 | latent_mesh.transform(np.linalg.inv(vis_param.sequence.T_gt2uni)) 245 | o3d.io.write_triangle_mesh(args.outdir+'/%s_recons.ply'%('lt_'+text), latent_mesh) 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | #color_mesh = color_mesh.transform(np.linalg.inv(vis_param.sequence.T_gt2uni)) 254 | 255 | 256 | 257 | 258 | 259 | --------------------------------------------------------------------------------
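A minimal usage sketch of two utilities listed above: `Isometry` (uni/utils/motion_util.py) for rigid-transform composition and the local `scatter` helper (uni/utils/torch_scatter.py) for per-voxel reductions. The frame names, point count and voxel count are arbitrary illustration values, and the import paths are assumed from the package layout shown in the tree.

```python
import numpy as np
import torch

# Import paths assumed from the package layout above.
from uni.utils.motion_util import Isometry
from uni.utils.torch_scatter import scatter

# Rigid-transform composition and point transform (frame names are illustrative only).
T_wc = Isometry.random()          # "camera to world"
T_cl = Isometry.random()          # "sensor to camera"
T_wl = T_wc.dot(T_cl)             # composition; equivalent to T_wc @ T_cl
pts_sensor = np.random.rand(100, 3)
pts_world = T_wl @ pts_sensor     # (100, 3): rotate, then translate
assert np.allclose(T_wl.inv() @ pts_world, pts_sensor, atol=1e-6)

# Per-voxel mean of point features via the scatter helper.
feat = torch.randn(100, 8)                 # one 8-d feature per point
voxel_id = torch.randint(0, 10, (100,))    # voxel index of each point
voxel_mean = scatter(feat, voxel_id, dim=0, dim_size=10, reduce="mean")  # (10, 8)
```

As the `scatter` docstring above notes, the GPU path relies on atomic adds and can therefore show small run-to-run variance for floating-point inputs.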