├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── neus.yaml ├── neus_nomask.yaml ├── neus_nomask_blended.yaml ├── unisurf.yaml ├── volsdf.yaml ├── volsdf_nerfpp.yaml ├── volsdf_nerfpp_blended.yaml └── volsdf_siren.yaml ├── dataio ├── BlendedMVS.py ├── DTU.py ├── __init__.py ├── blendedmvs_normalized.txt └── custom.py ├── debug_tools ├── __init__.py ├── plot_neus_bias.py └── test_volsdf_algo.py ├── docs ├── neus.md ├── trained_models_results.md ├── usage.md └── volsdf.md ├── media ├── 00000000.png ├── 00003000.png ├── 00010000.png ├── 00200000.png ├── DTU │ ├── neus │ │ ├── neus_105_nomask_new_rgb&normal&mesh_390x400_60_small_circle_256.gif │ │ ├── neus_106_nomask_new_rgb&normal&mesh_390x400_60_small_circle_256.gif │ │ ├── neus_110_nomask_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── neus_114_nomask_new_rgb&normal&mesh_390x400_60_small_circle_256.gif │ │ ├── neus_118_nomask_new_rgb&normal&mesh_450x400_60_small_circle_256.gif │ │ ├── neus_122_nomask_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── neus_24_nomask_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── neus_37_nomask_new_rgb&normal&mesh_300x400_60_small_circle_256.gif │ │ ├── neus_55_nomask_new_rgb&normal&mesh_300x400_60_small_circle_256.gif │ │ ├── neus_63_nomask_new_rgb&normal&mesh_300x400_60_small_circle_256.gif │ │ ├── neus_65_nomask_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── neus_69_nomask_new_rgb&normal&mesh_450x400_60_small_circle_256.gif │ │ ├── neus_83_nomask_new_rgb&normal&mesh_330x400_60_small_circle_256.gif │ │ └── neus_97_nomask_new_rgb&normal&mesh_300x400_60_small_circle_256.gif │ ├── unisurf │ │ └── unisurf_65_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ └── volsdf │ │ ├── volsdf_105_new_rgb&normal&mesh_390x400_60_small_circle_256.gif │ │ ├── volsdf_106_new_rgb&normal&mesh_390x400_60_small_circle_256.gif │ │ ├── volsdf_110_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── volsdf_114_new_rgb&normal&mesh_390x400_60_small_circle_256.gif │ │ ├── volsdf_122_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── volsdf_24_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── volsdf_37_new_rgb&normal&mesh_300x400_60_small_circle_256.gif │ │ ├── volsdf_40_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── volsdf_55_new_dbg_sp10_1c_up8_rgb&normal&mesh_300x400_60_small_circle_256.gif │ │ ├── volsdf_63_new_rgb&normal&mesh_300x400_60_small_circle_256.gif │ │ ├── volsdf_65_new_rgb&normal&mesh_360x400_60_small_circle_256.gif │ │ ├── volsdf_69_new_rgb&normal&mesh_450x400_60_small_circle_256.gif │ │ ├── volsdf_83_new_rgb&normal&mesh_330x400_60_small_circle_256.gif │ │ └── volsdf_97_new_rgb&normal&mesh_300x400_60_small_circle_256.gif ├── DTU_105.png ├── DTU_106.png ├── DTU_110.png ├── DTU_114.png ├── DTU_118.png ├── DTU_122.png ├── DTU_24.png ├── DTU_37.png ├── DTU_40.png ├── DTU_55.png ├── DTU_63.png ├── DTU_65.png ├── DTU_69.png ├── DTU_83.png ├── DTU_97.png ├── cam_great_circle.gif ├── cam_small_circle.gif ├── cam_spherical_spiral.gif ├── framework.png ├── image-20210809041923683.png ├── image-20210809041939732.png ├── image-20210809042000880.png ├── image-20210809042044843.png ├── image-20210809042204675.png ├── image-20210809042455605.png ├── limit.png ├── mesh_0k.png ├── mesh_10k.png ├── mesh_200k.png ├── mesh_3k.png ├── neus_65_nomask_new_rgb&normal_360x400_60_small_circle_None.gif ├── neus_65_nomask_new_rgb&normal_360x400_60_small_circle_sphere_tracing_None.gif ├── neus_unbiased.mp4 ├── 
sdf2sigma.gif ├── unisurf_65_new_rgb&normal_360x400_60_small_circle_None.gif ├── volsdf_beta_00000000.png ├── volsdf_beta_00004000.png ├── volsdf_beta_00010000.png ├── volsdf_beta_00200000.png ├── volsdf_nerf++_blended_norm_5c0d13_rgb&mesh_576x768_450_archimedean_spiral_256.gif ├── volsdf_up_iter_00000000.png ├── volsdf_up_iter_00004000.png ├── volsdf_up_iter_00010000.png └── volsdf_up_iter_00200000.png ├── models ├── __init__.py ├── base.py ├── frameworks │ ├── __init__.py │ ├── neus.py │ ├── unisurf.py │ └── volsdf.py └── ray_casting.py ├── set_env.sh ├── tools ├── __init__.py ├── extract_surface.py ├── render_view.py ├── vis_camera.py ├── vis_ray.py └── vis_surface_and_cam.py ├── train.py └── utils ├── __init__.py ├── checkpoints.py ├── dist_util.py ├── io_util.py ├── logger.py ├── mesh_util.py ├── print_fn.py ├── rend_util.py └── train_util.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gif filter=lfs diff=lfs merge=lfs -text 2 | *.mp4 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .spyproject/ 2 | .vscode/ 3 | logs/ 4 | data/ 5 | data 6 | dev_test/ 7 | out/ 8 | **/__pycache__ 9 | *.pyc 10 | trained_models/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jianfei Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ventusff/neurecon/972e810ec252cfd16f630b1de6d2802d1b8de59a/__init__.py
-------------------------------------------------------------------------------- /configs/neus.yaml: --------------------------------------------------------------------------------
1 | expname: neus_37
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   batch_size: 1 # one batch, one image
9 |   data_dir: ./data/neus/dtu_scan37
10 |   cam_file: 'cameras_sphere.npz'
11 |   downscale: 1 # downscale image for training
12 |   pin_memory: True
13 | 
14 |   N_rays: 512 # N_rays for training
15 |   val_rayschunk: 256 # N_rays for validation
16 |   val_downscale: 8 # downscale image for validation
17 | 
18 | model:
19 |   framework: NeuS
20 |   obj_bounding_radius: 1.0
21 | 
22 |   variance_init: 0.05
23 | 
24 |   # upsampling related
25 |   upsample_algo: official_solution # [direct_use, direct_more, official_solution]
26 |   N_nograd_samples: 2048 # config for upsampling using 'direct_more'
27 |   N_upsample_iters: 4 # config for upsampling using 'official_solution'
28 | 
29 |   surface:
30 |     D: 8
31 |     W: 256
32 |     skips: [4]
33 |     radius_init: 0.5
34 |     embed_multires: 6
35 | 
36 |   radiance:
37 |     D: 4
38 |     W: 256
39 |     skips: []
40 |     embed_multires: -1
41 |     embed_multires_view: 4 # as in the NeuS official implementation
42 | 
43 | training:
44 |   lr: 5.0e-4
45 |   speed_factor: 10.0 # NOTE: unexpectedly, this is very important. Setting it to 1.0 causes some of the DTU instances to not converge correctly.
46 | 
47 |   # neus
48 |   with_mask: True
49 |   w_eikonal: 0.1
50 |   w_mask: 1.0
51 | 
52 |   log_root_dir: "logs"
53 | 
54 |   # lr decay
55 |   scheduler:
56 |     type: warmupcosine
57 |     warmup_steps: 5000 # unit: iteration steps
58 | 
59 |   # num_epochs: 50000
60 |   num_iters: 300000 # 300k
61 | 
62 |   ckpt_file: null # will be read by python as None
63 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
64 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
65 | 
66 |   monitoring: tensorboard
67 | 
68 |   i_save: 900 # unit: seconds
69 |   i_backup: 50000 # unit: iteration steps
70 | 
71 |   i_val: 500
72 |   i_val_mesh: 10000
-------------------------------------------------------------------------------- /configs/neus_nomask.yaml: --------------------------------------------------------------------------------
1 | expname: neus_37_nomask
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   batch_size: 1 # one batch, one image
9 |   data_dir: ./data/neus/dtu_scan37
10 |   cam_file: 'cameras_sphere.npz'
11 |   downscale: 1 # downscale image for training
12 |   pin_memory: True
13 | 
14 |   N_rays: 512 # N_rays for training
15 |   val_rayschunk: 256 # N_rays for validation
16 |   val_downscale: 8 # downscale image for validation
17 | 
18 | model:
19 |   framework: NeuS
20 |   obj_bounding_radius: 1.0
21 | 
22 |   variance_init: 0.05
23 |   N_outside: 32 # number of outside NeRF++ points
24 | 
25 |   # upsampling related
26 |   upsample_algo: official_solution # [direct_use, direct_more, official_solution]
27 |   N_nograd_samples: 2048 # config for upsampling using 'direct_more'
28 |   N_upsample_iters: 4 # config for upsampling using 'official_solution'
29 | 
30 |   surface:
31 |     D: 8
32 |     W: 256
33 |     skips: [4]
34 |     radius_init: 0.5
35 |     embed_multires: 6
36 | 
37 |   radiance:
38 |     D: 4
39 |     W: 256
40 |     skips: []
41 |     embed_multires: -1
42 |     embed_multires_view: 4 # as in the NeuS official implementation
43 | 
44 | training:
45 |   lr: 5.0e-4
46 |   speed_factor: 10.0 # NOTE: unexpectedly, this is very important. Setting it to 1.0 causes some of the DTU instances to not converge correctly.
47 | 
48 |   # neus
49 |   with_mask: False
50 |   w_eikonal: 0.1
51 | 
52 |   log_root_dir: "logs"
53 | 
54 |   # lr decay
55 |   scheduler:
56 |     type: warmupcosine
57 |     warmup_steps: 5000 # unit: iteration steps
58 | 
59 |   # num_epochs: 50000
60 |   num_iters: 300000 # 300k
61 | 
62 |   ckpt_file: null # will be read by python as None
63 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
64 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
65 | 
66 |   monitoring: tensorboard
67 | 
68 |   i_save: 900 # unit: seconds
69 |   i_backup: 50000 # unit: iteration steps
70 | 
71 |   i_val: 500
72 |   i_val_mesh: 10000
-------------------------------------------------------------------------------- /configs/neus_nomask_blended.yaml: --------------------------------------------------------------------------------
1 | expname: neus_nomask_blended_5c0d13
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   type: BlendedMVS
9 |   batch_size: 1 # one batch, one image
10 |   data_dir: ./data/BlendedMVS/BlendedMVS/5c0d13b795da9479e12e2ee9
11 |   downscale: 1 # downscale image for training
12 |   scale_radius: 2.0 # scale all cameras of the dataset to be within this radius
13 |   pin_memory: True
14 | 
15 |   N_rays: 512 # N_rays for training
16 |   val_rayschunk: 256 # N_rays for validation
17 |   val_downscale: 4 # downscale image for validation
18 | 
19 | model:
20 |   framework: NeuS
21 |   obj_bounding_radius: 1.0
22 | 
23 |   variance_init: 0.05
24 |   N_outside: 32 # number of outside NeRF++ points
25 | 
26 |   # upsampling related
27 |   upsample_algo: official_solution # [direct_use, direct_more, official_solution]
28 |   N_nograd_samples: 2048 # config for upsampling using 'direct_more'
29 |   N_upsample_iters: 4 # config for upsampling using 'official_solution'
30 | 
31 |   surface:
32 |     D: 8
33 |     W: 256
34 |     skips: [4]
35 |     radius_init: 0.5
36 |     embed_multires: 6
37 | 
38 |   radiance:
39 |     D: 4
40 |     W: 256
41 |     skips: []
42 |     embed_multires: -1
43 |     embed_multires_view: 4 # as in the NeuS official implementation
44 | 
45 | training:
46 |   lr: 5.0e-4
47 |   speed_factor: 10.0 # NOTE: unexpectedly, this is very important. Setting it to 1.0 causes some of the DTU instances to not converge correctly.
48 | 
49 |   # neus
50 |   with_mask: False
51 |   w_eikonal: 0.1
52 | 
53 |   log_root_dir: "logs"
54 | 
55 |   # lr decay
56 |   scheduler:
57 |     type: warmupcosine
58 |     warmup_steps: 5000 # unit: iteration steps
59 | 
60 |   # num_epochs: 50000
61 |   num_iters: 300000 # 300k
62 | 
63 |   ckpt_file: null # will be read by python as None
64 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
65 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
66 | 
67 |   monitoring: tensorboard
68 | 
69 |   i_save: 900 # unit: seconds
70 |   i_backup: 50000 # unit: iteration steps
71 | 
72 |   i_val: 500
73 |   i_val_mesh: 10000
-------------------------------------------------------------------------------- /configs/unisurf.yaml: --------------------------------------------------------------------------------
1 | expname: unisurf_65
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   batch_size: 1 # one batch, one image
9 |   data_dir: ./data/DTU/scan65
10 |   downscale: 1 # downscale image for training
11 |   pin_memory: True
12 | 
13 |   N_rays: 1024 # N_rays for training
14 |   val_rayschunk: 256 # N_rays for validation
15 |   val_downscale: 8 # downscale image for validation
16 | 
17 | model:
18 |   framework: UNISURF
19 |   tau: 0.5 # level surface
20 |   W_geometry_feature: 256
21 |   obj_bounding_radius: 4.0 # as in UNISURF supp II.1, 'four-times' larger region of interest
22 | 
23 |   surface:
24 |     radius_init: 1.0
25 |     D: 8
26 |     skips: [4]
27 |     embed_multires: 6
28 | 
29 |   radiance:
30 |     D: 4
31 |     skips: []
32 |     embed_multires: -1
33 |     embed_multires_view: -1
34 | 
35 | training:
36 |   lr: 1.0e-4
37 |   w_reg: 0.01
38 |   perturb_surface_pts: 0.01 # for smoothing normals
39 | 
40 |   delta_max: 1.0
41 |   delta_min: 0.05
42 |   delta_beta: 1.5e-5
43 | 
44 |   log_root_dir: "./logs" # the final expdir would be log_root_dir/expname
45 | 
46 |   # lr decay
47 |   scheduler:
48 |     type: multistep
49 |     milestones: [200000, 400000] # [200k, 400k] # unit: iteration steps
50 |     gamma: 0.5
51 | 
52 |   # num_epochs: 50000
53 |   num_iters: 450000 # 450k
54 | 
55 |   ckpt_file: null # will be read by python as None
56 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
57 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
58 | 
59 |   monitoring: tensorboard
60 | 
61 |   i_save: 900 # unit: seconds
62 |   i_backup: 50000 # unit: iteration steps
63 | 
64 |   i_val: 1000
65 |   i_val_mesh: 20000
66 | 
-------------------------------------------------------------------------------- /configs/volsdf.yaml: --------------------------------------------------------------------------------
1 | expname: volsdf_65
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   batch_size: 1 # one batch, one image
9 |   data_dir: ./data/DTU/scan65
10 |   downscale: 1 # downscale image for training
11 |   scale_radius: 3.0 # scale all cameras of the dataset to be within this radius
12 |   pin_memory: True
13 | 
14 |   near: 0.0
15 |   far: 6.0 # NOTE: in VolSDF, far = 2r = 2*3 = 6.0
16 | 
17 |   N_rays: 1024 # N_rays for training
18 |   val_rayschunk: 256 # N_rays for validation
19 |   val_downscale: 8 # downscale image for validation
20 | 
21 | model:
22 |   framework: VolSDF
23 |   obj_bounding_radius: 3.0 # scene sphere, as in the VolSDF paper
24 | 
25 |   outside_scene: "builtin" # [builtin, nerf++]
26 |   max_upsample_iter: 6 # upsample iterations, as in the VolSDF paper
27 | 
28 |   W_geometry_feature: 256
29 | 
30 |   surface:
31 |     radius_init: 1.0 # as in VolSDF supp B.3, unit sphere
32 |     D: 8
33 |     skips: [4]
34 |     embed_multires: 6
35 | 
36 |   radiance:
37 |     D: 4
38 |     skips: []
39 |     embed_multires: -1
40 |     embed_multires_view: -1
41 | 
42 | training:
43 |   speed_factor: 10.0
44 | 
45 |   lr: 5.0e-4
46 |   w_eikonal: 0.1
47 | 
48 |   log_root_dir: "logs"
49 | 
50 |   num_iters: 100000
51 | 
52 |   # lr decay
53 |   scheduler:
54 |     type: exponential_step
55 |     min_factor: 0.1
56 | 
57 |   ckpt_file: null # will be read by python as None
58 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
59 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
60 | 
61 |   monitoring: tensorboard
62 | 
63 |   i_save: 900 # unit: seconds
64 |   i_backup: 50000 # unit: iteration steps
65 | 
66 |   i_val: 500
67 |   i_val_mesh: 10000
68 | 
-------------------------------------------------------------------------------- /configs/volsdf_nerfpp.yaml: --------------------------------------------------------------------------------
1 | expname: volsdf_nerf++_40
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   batch_size: 1 # one batch, one image
9 |   data_dir: ./data/DTU/scan40
10 |   downscale: 1 # downscale image for training
11 |   scale_radius: 3.0 # scale all cameras of the dataset to be within this radius
12 |   pin_memory: True
13 | 
14 |   near: 0.0
15 |   far: 6.0 # NOTE: in VolSDF, far = 2r = 2*3 = 6.0
16 | 
17 |   N_rays: 1024 # N_rays for training
18 |   val_rayschunk: 256 # N_rays for validation
19 |   val_downscale: 8 # downscale image for validation
20 | 
21 | model:
22 |   framework: VolSDF
23 |   obj_bounding_radius: 3.0 # scene sphere, as in the VolSDF paper
24 | 
25 |   outside_scene: "nerf++" # [builtin, nerf++]
26 |   max_upsample_iter: 5 # upsample iterations, as in the VolSDF paper
27 | 
28 |   W_geometry_feature: 256
29 | 
30 |   surface:
31 |     radius_init: 1.0 # in VolSDF supp B.3, unit sphere
32 |     D: 8
33 |     skips: [4]
34 |     embed_multires: 6
35 | 
36 |   radiance:
37 |     D: 4
38 |     skips: []
39 |     embed_multires: -1
40 |     embed_multires_view: -1
41 | 
42 | training:
43 |   speed_factor: 10.0
44 |   lr: 5.0e-4
45 |   w_eikonal: 0.1
46 | 
47 |   log_root_dir: "logs"
48 | 
49 |   num_iters: 100000 # 100k
50 | 
51 |   # lr decay
52 |   scheduler:
53 |     type: warmupcosine
54 |     warmup_steps: 0 # unit: iteration steps
55 | 
56 |   ckpt_file: null # will be read by python as None
57 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
58 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
59 | 
60 |   monitoring: tensorboard
61 | 
62 |   i_save: 900 # unit: seconds
63 |   i_backup: 50000 # unit: iteration steps
64 | 
65 |   i_val: 500
66 |   i_val_mesh: 10000
67 | 
-------------------------------------------------------------------------------- /configs/volsdf_nerfpp_blended.yaml: --------------------------------------------------------------------------------
1 | expname: volsdf_nerf++_blended_5c0d13
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   type: BlendedMVS
9 |   batch_size: 1 # one batch, one image
10 |   data_dir: ./data/BlendedMVS/BlendedMVS/5c0d13b795da9479e12e2ee9
11 |   downscale: 1 # downscale image for training
12 |   scale_radius: 3.0 # scale all cameras of the dataset to be within this radius
13 |   volume_size: 5.0 # volume size for extracting mesh using marching cubes
14 |   pin_memory: True
15 | 
16 |   near: 0.0
17 |   far: 6.0 # NOTE: in VolSDF, far = 2r = 2*3 = 6.0
18 | 
19 |   N_rays: 1024 # N_rays for training
20 |   val_rayschunk: 256 # N_rays for validation
21 |   val_downscale: 4 # downscale image for validation
22 | 
23 | model:
24 |   framework: VolSDF
25 |   obj_bounding_radius: 3.0 # scene sphere, as in the VolSDF paper
26 | 
27 |   outside_scene: "nerf++" # [builtin, nerf++]
28 |   max_upsample_iter: 5 # upsample iterations, as in the VolSDF paper
29 | 
30 |   W_geometry_feature: 256
31 | 
32 |   surface:
33 |     radius_init: 1.0 # as in VolSDF supp B.3, unit sphere
34 |     D: 8
35 |     skips: [4]
36 |     embed_multires: 6
37 | 
38 |   radiance:
39 |     D: 4
40 |     skips: []
41 |     embed_multires: -1
42 |     embed_multires_view: -1
43 | 
44 | training:
45 |   speed_factor: 10.0
46 |   lr: 5.0e-4
47 |   w_eikonal: 0.1
48 | 
49 |   log_root_dir: "logs"
50 | 
51 |   # lr decay
52 |   scheduler:
53 |     type: warmupcosine
54 |     warmup_steps: 0 # unit: iteration steps
55 | 
56 |   num_iters: 200000
57 | 
58 |   ckpt_file: null # will be read by python as None
59 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
60 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
61 | 
62 |   monitoring: tensorboard
63 | 
64 |   i_save: 900 # unit: seconds
65 |   i_backup: 50000 # unit: iteration steps
66 | 
67 |   i_val: 500
68 |   i_val_mesh: 10000
69 | 
-------------------------------------------------------------------------------- /configs/volsdf_siren.yaml: --------------------------------------------------------------------------------
1 | expname: volsdf_siren_65_fix_lr1e-4_gamma
2 | 
3 | # device_ids: [0]     # single gpu ; run on specified GPU
4 | # device_ids: [1, 0]  # DP ; run on specified GPUs
5 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs
6 | 
7 | data:
8 |   batch_size: 1 # one batch, one image
9 |   data_dir: ./data/DTU/scan65
10 |   downscale: 1 # downscale image for training
11 |   scale_radius: 3.0 # scale all cameras of the dataset to be within this radius
12 |   pin_memory: True
13 | 
14 |   near: 0.0
15 |   far: 6.0 # NOTE: in VolSDF, far = 2r = 2*3 = 6.0
16 | 
17 |   N_rays: 1024 # N_rays for training
18 |   val_rayschunk: 256 # N_rays for validation
19 |   val_downscale: 8 # downscale image for validation
20 | 
21 | model:
22 |   framework: VolSDF
23 |   obj_bounding_radius: 3.0 # scene sphere, as in the VolSDF paper
24 | 
25 |   outside_scene: "builtin" # [builtin, nerf++]
26 |   max_upsample_iter: 5 # upsample iterations, as in the VolSDF paper
27 | 
28 |   W_geometry_feature: 256
29 | 
30 |   surface:
31 |     radius_init: 1.0 # as in VolSDF supp B.3, unit sphere
32 |     use_siren: true
33 |     D: 5
34 |     skips: []
35 |     embed_multires: -1
36 | 
37 |   radiance:
38 |     use_siren: true
39 |     D: 5
40 |     skips: []
41 |     embed_multires: -1
42 |     embed_multires_view: 4
43 | 
44 | 
45 | training:
46 |   lr_pretrain: 1.5e-4
47 | 
48 |   lr: 1.0e-4
49 |   w_eikonal: 0.1
50 | 
51 |   log_root_dir: "logs"
52 | 
53 |   num_iters: 150000
54 | 
55 |   # lr decay
56 |   scheduler:
57 |     type: multistep
58 |     milestones: [40000, 80000, 120000] # unit: iteration steps
59 |     gamma: 0.5
60 | 
61 |   # scheduler:
62 |   #   type: warmupcosine
63 |   #   warmup_steps: 0 # unit: iteration steps
64 | 
65 |   ckpt_file: null # will be read by python as None
66 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
67 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
68 | 
69 |   monitoring: tensorboard
70 | 
71 |   i_save: 900 # unit: seconds
72 |   i_backup: 50000 # unit: iteration steps
73 | 
74 |   i_val: 500
75 |   i_val_mesh: 10000
76 | 
-------------------------------------------------------------------------------- /dataio/BlendedMVS.py: --------------------------------------------------------------------------------
1 | from utils.io_util import glob_imgs, load_rgb
2 | 
3 | import os
4 | import numpy as np
5 | from tqdm import tqdm
6 | 
7 | import torch
8 | 
9 | 
10 | class SceneDataset(torch.utils.data.Dataset):
11 |     def __init__(self,
12 |                  train_cameras,
13 |                  data_dir,
14 |                  downscale=1.,   # [H, W]
15 |                  scale_radius=-1):
16 |         super().__init__()
17 | 
18 |         self.instance_dir = data_dir
19 |         assert os.path.exists(self.instance_dir), "Data directory does not exist"
20 | 
21 |         self.train_cameras = train_cameras
22 |         self.downscale = downscale
23 | 
24 |         image_dir = '{0}/blended_images'.format(self.instance_dir)
25 |         # cam_dir = '{0}/cams'.format(self.instance_dir)
26 |         cam_dir = '{0}/cams_normalized'.format(self.instance_dir)
27 | 
28 |         self.intrinsics_all = []
29 |         self.c2w_all = []
30 |         self.rgb_images = []
31 |         self.basenames = []
32 |         cam_center_norms = []
33 |         for imgpath in tqdm(sorted(glob_imgs(image_dir)), desc='loading data...'):
34 |             if 'masked' in imgpath:
35 |                 pass
36 |             else:
37 |                 basename = os.path.splitext(os.path.split(imgpath)[-1])[0]
38 |                 self.basenames.append(basename)
39 | 
40 |                 camfilepath = os.path.join(cam_dir, "{}_cam.txt".format(basename))
41 |                 assert os.path.exists(camfilepath)
42 |                 extrinsics, intrinsics = load_cam(camfilepath)
43 |                 c2w = np.linalg.inv(extrinsics)
44 |                 cam_center_norms.append(np.linalg.norm(c2w[:3,3]))
45 | 
46 |                 # downscale intrinsics
47 |                 intrinsics[0, 2] /= downscale
48 |                 intrinsics[1, 2] /= downscale
49 |                 intrinsics[0, 0] /= downscale
50 |                 intrinsics[1, 1] /= downscale
51 | 
52 |                 self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
53 |                 self.c2w_all.append(torch.from_numpy(c2w).float())
54 | 
55 |                 rgb = load_rgb(imgpath, downscale)
56 |                 _, self.H, self.W = rgb.shape
57 |                 rgb = rgb.reshape(3, -1).transpose(1, 0)
58 |                 self.rgb_images.append(torch.from_numpy(rgb).float())
59 | 
60 |         max_cam_norm = max(cam_center_norms)
61 |         if scale_radius > 0:
62 |             for i in range(len(self.c2w_all)):
63 |                 self.c2w_all[i][:3, 3] *= (scale_radius / max_cam_norm / 1.1)
64 | 
65 |         self.n_images = len(self.rgb_images)
66 | 
67 | 
68 |     def __len__(self):
69 |         return self.n_images
70 | 
71 |     def __getitem__(self, idx):
72 |         # uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
73 |         # uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
74 |         # uv = uv.reshape(2, -1).transpose(1, 0)
75 | 
76 |         sample = {
77 |             "intrinsics": self.intrinsics_all[idx],
78 |         }
79 | 
80 |         ground_truth = {
81 |             "rgb": self.rgb_images[idx]
82 |         }
83 | 
84 |         ground_truth["rgb"] = self.rgb_images[idx]
85 | 
86 |         if not self.train_cameras:
87 |             sample["c2w"] = self.c2w_all[idx]
88 | 
89 |         return idx, sample, ground_truth
90 | 
91 |     def collate_fn(self, batch_list):
92 |         # gets a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
93 |         batch_list = zip(*batch_list)
94 | 
95 |         all_parsed = []
96 |         for entry in batch_list:
97 |             if type(entry[0]) is dict:
98 |                 # make them all into a new dict
99 |                 ret = {}
100 |                 for k in entry[0].keys():
101 |                     ret[k] = torch.stack([obj[k] for obj in entry])
102 |                 all_parsed.append(ret)
103 |             else:
104 |                 all_parsed.append(torch.LongTensor(entry))
105 | 
106 |         return tuple(all_parsed)
107 | 
108 |     def 
get_gt_pose(self): 109 | return torch.stack(self.c2w_all, dim=0) 110 | 111 | # modified from https://github.com/YoYo000/MVSNet/blob/master/mvsnet/preprocess.py 112 | def load_cam(filepath, interval_scale=1, original_blendedmvs=False): 113 | """ read camera txt file """ 114 | cam = np.repeat(np.eye(4)[None, ...], repeats=2, axis=0) 115 | words = open(filepath).read().split() 116 | # read extrinsic 117 | for i in range(0, 4): 118 | for j in range(0, 4): 119 | extrinsic_index = 4 * i + j + 1 120 | cam[0][i][j] = words[extrinsic_index] 121 | 122 | # read intrinsic 123 | for i in range(0, 3): 124 | for j in range(0, 3): 125 | intrinsic_index = 3 * i + j + 18 126 | cam[1][i][j] = words[intrinsic_index] 127 | 128 | if original_blendedmvs: 129 | if len(words) == 29: 130 | cam[1][3][0] = words[27] 131 | cam[1][3][1] = float(words[28]) * interval_scale 132 | # cam[1][3][2] = FLAGS.max_d 133 | cam[1][3][2] = 128 # NOTE: manually fixed here. 134 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2] 135 | elif len(words) == 30: 136 | cam[1][3][0] = words[27] 137 | cam[1][3][1] = float(words[28]) * interval_scale 138 | cam[1][3][2] = words[29] 139 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2] 140 | elif len(words) == 31: 141 | cam[1][3][0] = words[27] 142 | cam[1][3][1] = float(words[28]) * interval_scale 143 | cam[1][3][2] = words[29] 144 | cam[1][3][3] = words[30] 145 | else: 146 | cam[1][3][0] = 0 147 | cam[1][3][1] = 0 148 | cam[1][3][2] = 0 149 | cam[1][3][3] = 0 150 | 151 | return cam 152 | 153 | 154 | def write_cam(filepath, cam): 155 | f = open(filepath, "w") 156 | 157 | f.write('extrinsic\n') 158 | for i in range(0, 4): 159 | for j in range(0, 4): 160 | f.write(str(cam[0][i][j]) + ' ') 161 | f.write('\n') 162 | f.write('\n') 163 | 164 | f.write('intrinsic\n') 165 | for i in range(0, 3): 166 | for j in range(0, 3): 167 | f.write(str(cam[1][i][j]) + ' ') 168 | f.write('\n') 169 | 170 | f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n') 171 | 172 | f.close() 173 | 174 | 175 | if __name__ == '__main__': 176 | def test(): 177 | dataset = SceneDataset(False, './data/BlendedMVS/BlendedMVS/5c0d13b795da9479e12e2ee9', scale_radius=3.0) 178 | c2w = dataset.get_gt_pose().data.cpu().numpy() 179 | extrinsics = np.linalg.inv(c2w) 180 | camera_matrix = next(iter(dataset))[1]['intrinsics'].data.cpu().numpy() 181 | from tools.vis_camera import visualize 182 | visualize(camera_matrix, extrinsics) 183 | test() -------------------------------------------------------------------------------- /dataio/DTU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from utils.io_util import load_mask, load_rgb, glob_imgs 7 | from utils.rend_util import rot_to_quat, load_K_Rt_from_P 8 | 9 | class SceneDataset(torch.utils.data.Dataset): 10 | # NOTE: jianfei: modified from IDR. 
https://github.com/lioryariv/idr/blob/main/code/datasets/scene_dataset.py
11 |     """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset."""
12 | 
13 |     def __init__(self,
14 |                  train_cameras,
15 |                  data_dir,
16 |                  downscale=1.,   # [H, W]
17 |                  cam_file=None,
18 |                  scale_radius=-1):
19 | 
20 |         assert os.path.exists(data_dir), "Data directory does not exist"
21 | 
22 |         self.instance_dir = data_dir
23 |         self.train_cameras = train_cameras
24 | 
25 |         image_dir = '{0}/image'.format(self.instance_dir)
26 |         image_paths = sorted(glob_imgs(image_dir))
27 |         mask_dir = '{0}/mask'.format(self.instance_dir)
28 |         mask_paths = sorted(glob_imgs(mask_dir))
29 | 
30 |         self.n_images = len(image_paths)
31 | 
32 |         # determine width, height
33 |         self.downscale = downscale
34 |         tmp_rgb = load_rgb(image_paths[0], downscale)
35 |         _, self.H, self.W = tmp_rgb.shape
36 | 
37 | 
38 |         self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
39 |         if cam_file is not None:
40 |             self.cam_file = '{0}/{1}'.format(self.instance_dir, cam_file)
41 | 
42 |         camera_dict = np.load(self.cam_file)
43 |         scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
44 |         world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
45 | 
46 |         self.intrinsics_all = []
47 |         self.c2w_all = []
48 |         cam_center_norms = []
49 |         for scale_mat, world_mat in zip(scale_mats, world_mats):
50 |             P = world_mat @ scale_mat
51 |             P = P[:3, :4]
52 |             intrinsics, pose = load_K_Rt_from_P(P)
53 |             cam_center_norms.append(np.linalg.norm(pose[:3,3]))
54 | 
55 |             # downscale intrinsics
56 |             intrinsics[0, 2] /= downscale
57 |             intrinsics[1, 2] /= downscale
58 |             intrinsics[0, 0] /= downscale
59 |             intrinsics[1, 1] /= downscale
60 |             # intrinsics[0, 1] /= downscale  # skew is a ratio, do not scale
61 | 
62 |             self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
63 |             self.c2w_all.append(torch.from_numpy(pose).float())
64 |         max_cam_norm = max(cam_center_norms)
65 |         if scale_radius > 0:
66 |             for i in range(len(self.c2w_all)):
67 |                 self.c2w_all[i][:3, 3] *= (scale_radius / max_cam_norm / 1.1)
68 | 
69 |         self.rgb_images = []
70 |         for path in tqdm(image_paths, desc='loading images...'):
71 |             rgb = load_rgb(path, downscale)
72 |             rgb = rgb.reshape(3, -1).transpose(1, 0)
73 |             self.rgb_images.append(torch.from_numpy(rgb).float())
74 | 
75 |         self.object_masks = []
76 |         for path in mask_paths:
77 |             object_mask = load_mask(path, downscale)
78 |             object_mask = object_mask.reshape(-1)
79 |             self.object_masks.append(torch.from_numpy(object_mask).to(dtype=torch.bool))
80 | 
81 |     def __len__(self):
82 |         return self.n_images
83 | 
84 |     def __getitem__(self, idx):
85 |         # uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
86 |         # uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
87 |         # uv = uv.reshape(2, -1).transpose(1, 0)
88 | 
89 |         sample = {
90 |             "object_mask": self.object_masks[idx],
91 |             "intrinsics": self.intrinsics_all[idx],
92 |         }
93 | 
94 |         ground_truth = {
95 |             "rgb": self.rgb_images[idx]
96 |         }
97 | 
98 |         ground_truth["rgb"] = self.rgb_images[idx]
99 |         sample["object_mask"] = self.object_masks[idx]
100 | 
101 |         if not self.train_cameras:
102 |             sample["c2w"] = self.c2w_all[idx]
103 | 
104 |         return idx, sample, ground_truth
105 | 
106 |     def collate_fn(self, batch_list):
107 |         # gets a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
108 |         batch_list = zip(*batch_list)
109 | 
110 |         all_parsed = []
111 |         for entry in batch_list:
112 |             if 
type(entry[0]) is dict: 113 | # make them all into a new dict 114 | ret = {} 115 | for k in entry[0].keys(): 116 | ret[k] = torch.stack([obj[k] for obj in entry]) 117 | all_parsed.append(ret) 118 | else: 119 | all_parsed.append(torch.LongTensor(entry)) 120 | 121 | return tuple(all_parsed) 122 | 123 | def get_scale_mat(self): 124 | return np.load(self.cam_file)['scale_mat_0'] 125 | 126 | def get_gt_pose(self, scaled=True): 127 | # Load gt pose without normalization to unit sphere 128 | camera_dict = np.load(self.cam_file) 129 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 130 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 131 | 132 | c2w_all = [] 133 | for scale_mat, world_mat in zip(scale_mats, world_mats): 134 | P = world_mat 135 | if scaled: 136 | P = world_mat @ scale_mat 137 | P = P[:3, :4] 138 | _, pose = load_K_Rt_from_P(P) 139 | c2w_all.append(torch.from_numpy(pose).float()) 140 | 141 | return torch.cat([p.float().unsqueeze(0) for p in c2w_all], 0) 142 | 143 | def get_pose_init(self): 144 | # get noisy initializations obtained with the linear method 145 | cam_file = '{0}/cameras_linear_init.npz'.format(self.instance_dir) 146 | camera_dict = np.load(cam_file) 147 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 148 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 149 | 150 | init_pose = [] 151 | for scale_mat, world_mat in zip(scale_mats, world_mats): 152 | P = world_mat @ scale_mat 153 | P = P[:3, :4] 154 | _, pose = load_K_Rt_from_P(P) 155 | init_pose.append(pose) 156 | init_pose = torch.cat([torch.Tensor(pose).float().unsqueeze(0) for pose in init_pose], 0).cuda() 157 | init_quat = rot_to_quat(init_pose[:, :3, :3]) 158 | init_quat = torch.cat([init_quat, init_pose[:, :3, 3]], 1) 159 | 160 | return init_quat 161 | 162 | if __name__ == "__main__": 163 | dataset = SceneDataset(False, './data/DTU/scan40') 164 | c2w = dataset.get_gt_pose(scaled=True).data.cpu().numpy() 165 | extrinsics = np.linalg.inv(c2w) # camera extrinsics are w2c matrix 166 | camera_matrix = next(iter(dataset))[1]['intrinsics'].data.cpu().numpy() 167 | from tools.vis_camera import visualize 168 | visualize(camera_matrix, extrinsics) -------------------------------------------------------------------------------- /dataio/__init__.py: -------------------------------------------------------------------------------- 1 | def get_data(args, return_val=False, val_downscale=4.0, **overwrite_cfgs): 2 | dataset_type = args.data.get('type', 'DTU') 3 | cfgs = { 4 | 'scale_radius': args.data.get('scale_radius', -1), 5 | 'downscale': args.data.downscale, 6 | 'data_dir': args.data.data_dir, 7 | 'train_cameras': False 8 | } 9 | 10 | if dataset_type == 'DTU': 11 | from .DTU import SceneDataset 12 | cfgs['cam_file'] = args.data.get('cam_file', None) 13 | elif dataset_type == 'custom': 14 | from .custom import SceneDataset 15 | elif dataset_type == 'BlendedMVS': 16 | from .BlendedMVS import SceneDataset 17 | else: 18 | raise NotImplementedError 19 | 20 | cfgs.update(overwrite_cfgs) 21 | dataset = SceneDataset(**cfgs) 22 | if return_val: 23 | cfgs['downscale'] = val_downscale 24 | val_dataset = SceneDataset(**cfgs) 25 | return dataset, val_dataset 26 | else: 27 | return dataset -------------------------------------------------------------------------------- /dataio/blendedmvs_normalized.txt: 
-------------------------------------------------------------------------------- 1 | 58c4bb4f4a69c55606122be4 2 | 58cf4771d0f5fb221defe6da 3 | 58d36897f387231e6c929903 4 | 58eaf1513353456af3a1682a 5 | 58f7f7299f5b5647873cb110 6 | 59056e6760bb961de55f3501 7 | 59338e76772c3e6384afbb15 8 | 59350ca084b7f26bf5ce6eb8 9 | 5947b62af1b45630bd0c2a02 10 | 59817e4a1bd4b175e7038d19 11 | 599aa591d5b41f366fed0d58 12 | 59d2657f82ca7774b1ec081d 13 | 59e75a2ca9e91f2c5526005d 14 | 59e864b2a9e91f2c5529325f 15 | 59ecfd02e225f6492d20fcc9 16 | 59f363a8b45be22330016cad 17 | 59f87d0bfa6280566fb38c9a 18 | 5a0271884e62597cdee0d0eb 19 | 5a3ca9cb270f0e3f14d0eddb 20 | 5a3cb4e4270f0e3f14d12f43 21 | 5a3f4aba5889373fbbc5d3b5 22 | 5a489fb1c7dab83a7d7b1070 23 | 5a48ba95c7dab83a7d7b44ed 24 | 5a48c4e9c7dab83a7d7b5cc7 25 | 5a48d4b2c7dab83a7d7b9851 26 | 5a4a38dad38c8a075495b5d2 27 | 5a563183425d0f5186314855 28 | 5a572fd9fc597b0478a81d14 29 | 5a57542f333d180827dfc132 30 | 5a588a8193ac3d233f77fbca 31 | 5a618c72784780334bc1972d 32 | 5a6400933d809f1d8200af15 33 | 5a6464143d809f1d8208c43c 34 | 5a69c47d0d5d0a7f3b2e9752 35 | 5a7d3db14989e929563eb153 36 | 5a8aa0fab18050187cbe060e 37 | 5a969eea91dfc339a9a3ad2c 38 | 5aa0f9d7a9efce63548c69a1 39 | 5aa235f64a17b335eeaf9609 40 | 5ab85f1dac4291329b17cb50 41 | 5ab8713ba3799a1d138bd69a 42 | 5ab8b8e029f5351f7f2ccf59 43 | 5adc6bd52430a05ecb2ffb85 44 | 5ae2e9c5fe405c5076abc6b2 45 | 5af28cea59bc705737003253 46 | 5b192eb2170cf166458ff886 47 | 5b21e18c58e2823a67a10dd8 48 | 5b22269758e2823a67a3bd03 49 | 5b2c67b5e0878c381608b8d8 50 | 5b3b353d8d46a939f93524b9 51 | 5b4933abf2b5f44e95de482a 52 | 5b62647143840965efc0dbde 53 | 5b6e716d67b396324c2d77cb 54 | 5b6eff8b67b396324c5b2672 55 | 5b78e57afc8fcf6781d0c3ba 56 | 5b908d3dc6ab78485f3d24a9 57 | 5b950c71608de421b1e7318f 58 | 5ba19a8a360c7c30c1c169df 59 | 5bb7a08aea1cfa39f1a947ab 60 | 5bc5f0e896b66a2cd8f9bd36 61 | 5bccd6beca24970bce448134 62 | 5bce7ac9ca24970bce4934b6 63 | 5bcf979a6d5f586b95c258cd 64 | 5bd43b4ba6b28b1ee86b92dd 65 | 5be47bf9b18881428d8fbc1d 66 | 5be883a4f98cee15019d5b83 67 | 5beb6e66abd34c35e18e66b9 68 | 5bf18642c50e6f7f8bdbd492 69 | 5bf21799d43923194842c001 70 | 5bf3a82cd439231948877aed 71 | 5bf7d63575c26f32dbf7413b 72 | 5bfd0f32ec61ca1dd69dc77b 73 | 5c062d84a96e33018ff6f0a6 74 | 5c0d13b795da9479e12e2ee9 75 | 5c1892f726173c3a09ea9aeb 76 | 5c189f2326173c3a09ed7ef3 77 | 5c1af2e2bee9a723c963d019 78 | 5c1b1500bee9a723c96c3e78 79 | 5c1dbf200843bc542d8ef8c4 80 | 5c20ca3a0843bc542d94e3e2 81 | 5c2b3ed5e611832e8aed46bf 82 | 5c34300a73a8df509add216d 83 | 5c34529873a8df509ae57b58 84 | -------------------------------------------------------------------------------- /dataio/custom.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lioryariv/idr/blob/main/code/datasets/scene_dataset.py 2 | 3 | import os 4 | import json 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from utils.io_util import load_mask, load_rgb, glob_imgs 10 | from utils.rend_util import rot_to_quat, load_K_Rt_from_P 11 | 12 | class SceneDataset(torch.utils.data.Dataset): 13 | """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset.""" 14 | def __init__(self, 15 | train_cameras, 16 | data_dir, 17 | downscale=1., # [H, W] 18 | cam_file=None, 19 | scale_radius=-1, 20 | ): 21 | 22 | self.instance_dir = data_dir 23 | assert os.path.exists(self.instance_dir), "Data directory is empty" 24 | 25 | self.train_cameras = train_cameras 26 | self.downscale = downscale 27 | 28 | 
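        # (added note) Expected on-disk layout for a custom scene, inferred from the
        # paths used just below -- treat this as an assumption, not a spec:
        #   data_dir/
        #   ├── images/     # RGB inputs
        #   ├── mask/       # optional foreground object masks (png, same basenames as images)
        #   ├── mask_out/   # optional masks marking pixels to be ignored
        #   └── cam.json    # per-image 4x4 projection 'P' (and optional 'SCALE') matrices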
image_dir = '{0}/images'.format(self.instance_dir)
29 |         # image_paths = sorted(glob_imgs(image_dir))
30 |         mask_dir = '{0}/mask'.format(self.instance_dir)
31 |         # mask_paths = sorted(glob_imgs(mask_dir))
32 |         mask_ignore_dir = '{0}/mask_out'.format(self.instance_dir)
33 | 
34 |         self.has_mask = os.path.exists(mask_dir) and len(os.listdir(mask_dir)) > 0
35 |         self.has_mask_out = os.path.exists(mask_ignore_dir) and len(os.listdir(mask_ignore_dir)) > 0
36 | 
37 |         self.cam_file = '{0}/cam.json'.format(self.instance_dir)
38 |         if cam_file is not None:
39 |             self.cam_file = '{0}/{1}'.format(self.instance_dir, cam_file)
40 | 
41 |         camera_dict = json.load(open(self.cam_file))
42 | 
43 |         self.n_images = len(camera_dict)
44 | 
45 |         cam_center_norms = []
46 |         self.intrinsics_all = []
47 |         self.c2w_all = []
48 |         self.rgb_images = []
49 |         self.object_masks = []
50 |         self.masks_ignore = []
51 |         for imgname, v in tqdm(camera_dict.items(), desc='loading dataset...'):
52 |             world_mat = np.array(v['P'], dtype=np.float32).reshape(4,4)
53 |             if 'SCALE' in v:
54 |                 scale_mat = np.array(v['SCALE'], dtype=np.float32).reshape(4,4)
55 |                 P = world_mat @ scale_mat
56 |             else:
57 |                 P = world_mat
58 |             intrinsics, c2w = load_K_Rt_from_P(P[:3, :4])
59 |             cam_center_norms.append(np.linalg.norm(c2w[:3,3]))
60 | 
61 |             # downscale intrinsics
62 |             intrinsics[0, 2] /= downscale
63 |             intrinsics[1, 2] /= downscale
64 |             intrinsics[0, 0] /= downscale
65 |             intrinsics[1, 1] /= downscale
66 | 
67 |             self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
68 |             self.c2w_all.append(torch.from_numpy(c2w).float())
69 | 
70 |             rgb = load_rgb(os.path.join(image_dir, imgname), downscale)
71 |             _, self.H, self.W = rgb.shape
72 |             rgb = rgb.reshape(3, -1).transpose(1, 0)
73 |             self.rgb_images.append(torch.from_numpy(rgb).float())
74 |             fname_base = os.path.splitext(imgname)[0]
75 | 
76 |             if self.has_mask:
77 |                 object_mask = load_mask(os.path.join(mask_dir, "{}.png".format(fname_base)), downscale).reshape(-1)
78 |                 self.object_masks.append(torch.from_numpy(object_mask).to(dtype=torch.bool))
79 | 
80 |             if self.has_mask_out:
81 |                 mask_ignore = load_mask(os.path.join(mask_ignore_dir, "{}.png".format(fname_base)), downscale).reshape(-1)
82 |                 self.masks_ignore.append(torch.from_numpy(mask_ignore).to(dtype=torch.bool))
83 | 
84 |         max_cam_norm = max(cam_center_norms)
85 |         if scale_radius > 0:
86 |             for i in range(len(self.c2w_all)):
87 |                 self.c2w_all[i][:3, 3] *= (scale_radius / max_cam_norm / 1.1)
88 | 
89 |     def __len__(self):
90 |         return self.n_images
91 | 
92 |     def __getitem__(self, idx):
93 |         # uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
94 |         # uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
95 |         # uv = uv.reshape(2, -1).transpose(1, 0)
96 | 
97 |         sample = {
98 |             "intrinsics": self.intrinsics_all[idx],
99 |         }
100 |         if self.has_mask:
101 |             sample["object_mask"] = self.object_masks[idx]
102 |         if self.has_mask_out:
103 |             sample["mask_ignore"] = self.masks_ignore[idx]
104 | 
105 |         ground_truth = {
106 |             "rgb": self.rgb_images[idx]
107 |         }
108 | 
109 |         ground_truth["rgb"] = self.rgb_images[idx]
110 | 
111 |         if not self.train_cameras:
112 |             sample["c2w"] = self.c2w_all[idx]
113 | 
114 |         return idx, sample, ground_truth
115 | 
116 |     def collate_fn(self, batch_list):
117 |         # gets a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
118 |         batch_list = zip(*batch_list)
119 | 
120 |         all_parsed = []
121 |         for entry in batch_list:
122 |             if type(entry[0]) is dict:
123 |                 # make them all into a new dict
124 |                 ret = {}
125 |                 for k in entry[0].keys():
126 |                     ret[k] = torch.stack([obj[k] for obj in entry])
127 |                 all_parsed.append(ret)
128 |             else:
129 |                 all_parsed.append(torch.LongTensor(entry))
130 | 
131 |         return tuple(all_parsed)
132 | 
133 |     def get_gt_pose(self, scaled=True):
134 |         # Load gt pose without normalization to unit sphere
135 |         camera_dict = json.load(open(self.cam_file))
136 | 
137 |         c2w_all = []
138 |         for imgname, v in camera_dict.items():
139 |             world_mat = np.array(v['P'], dtype=np.float32).reshape(4,4)
140 |             if scaled and 'SCALE' in v:
141 |                 scale_mat = np.array(v['SCALE'], dtype=np.float32).reshape(4,4)
142 |                 P = world_mat @ scale_mat
143 |             else:
144 |                 P = world_mat
145 |             _, c2w = load_K_Rt_from_P(P[:3, :4])
146 |             c2w_all.append(torch.from_numpy(c2w).float())
147 | 
148 |         return torch.cat([p.float().unsqueeze(0) for p in c2w_all], 0)
149 | 
150 | if __name__ == "__main__":
151 |     # dataset = SceneDataset(False, './data/taxi/black')
152 |     dataset = SceneDataset(False, './data/taxi/blue')
153 |     c2w = dataset.get_gt_pose(scaled=True).data.cpu().numpy()
154 |     extrinsics = np.linalg.inv(c2w)  # camera extrinsics are w2c matrix
155 |     camera_matrix = next(iter(dataset))[1]['intrinsics'].data.cpu().numpy()
156 | 
157 |     from tools.vis_camera import visualize
158 |     visualize(camera_matrix, extrinsics)
-------------------------------------------------------------------------------- /debug_tools/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ventusff/neurecon/972e810ec252cfd16f630b1de6d2802d1b8de59a/debug_tools/__init__.py
-------------------------------------------------------------------------------- /debug_tools/plot_neus_bias.py: --------------------------------------------------------------------------------
1 | from models.frameworks.neus import sdf_to_w, pdf_phi_s, cdf_Phi_s
2 | 
3 | import math
4 | import torch
5 | import numpy as np
6 | import matplotlib
7 | matplotlib.rcParams.update({'font.size': 20})
8 | import matplotlib.pyplot as plt
9 | import matplotlib.gridspec as gridspec
10 | from matplotlib.widgets import Slider, Button
11 | 
12 | BORDER0 = 2.13333333
13 | BORDER1 = 3.13333333
14 | BORDER_CENTER = 0.5 * (BORDER0 + BORDER1)
15 | 
16 | 
17 | class Plotter(object):
18 |     def __init__(self, init_num=20, near=0., far=6., min_num=2, max_num=1024, init_s=64.) -> None:
19 |         super().__init__()
20 | 
21 |         assert far > BORDER0 + 0.1 and near < far and near < BORDER0 - 0.1
22 | 
23 |         self.near = near
24 |         self.far = far
25 |         scatter_size = 20.
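        # (added note) What this debug figure shows, per docs/neus.md: a hand-crafted
        # 1-D SDF with surface borders at BORDER0/BORDER1 is converted into visibility
        # weights in two ways -- a naive sdf-to-weight conversion (naive_sdf2w) on the
        # left panel, and presumably the NeuS formulation (sdf_to_w, imported above) on
        # the right -- so one can inspect whether the peak of the weights is biased
        # away from the exact surface.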
26 |         #-------------------
27 |         # prepare init data
28 |         #-------------------
29 |         fake_dvals = np.linspace(near, far, init_num)
30 |         fake_dvals_mid = (fake_dvals[..., 1:] + fake_dvals[..., :-1]) * 0.5
31 |         fake_sdf = fake_1d_sdf(fake_dvals)
32 | 
33 |         #-------------------
34 |         # prepare figure
35 |         #-------------------
36 |         fig = plt.figure(figsize=(30,12))
37 |         #--------------------------- total
38 |         gs = gridspec.GridSpec(50, 1)
39 | 
40 |         #----------------------top
41 |         gs_top = gridspec.GridSpecFromSubplotSpec(
42 |             nrows=1,
43 |             ncols=2,
44 |             subplot_spec=gs[0:40, 0])
45 |         #---------------- left: naive w function
46 |         pdf_coarse, cdf_coarse, alpha_coarse, w_coarse = naive_sdf2w(fake_dvals, fake_sdf, s=init_s)
47 |         # d_pred = np.sum(w_coarse * fake_dvals)
48 |         d_pred = np.sum((w_coarse[fake_dvals < ...  # [...] original lines 48-104 are lost in this dump (tag-stripping consumed everything between this '<' and the next '>'): the rest of this statement and the axes/plot/slider setup are missing
105 |         ax_down_s.set_xlabel('s = {:.1f}. fixed_s=64.(for sampling), learned_s>=1000.(for rendering)'.format(init_s))  # (line reconstructed from the identical label set in refresh() below)
106 | 
107 |         self.fig = fig
108 |         self.ax_top_left = ax_top_left
109 |         self.ax_top_right = ax_top_right
110 |         self.ax_down_num = ax_down_num
111 |         self.ax_down_s = ax_down_s
112 |         self.s = init_s
113 |         self.num = init_num
114 | 
115 |     def on_update_slider_num(self, val):
116 |         self.num = int(2 ** val)
117 |         self.refresh()
118 | 
119 |     def on_update_slider_s(self, val):
120 |         self.s = 2 ** val
121 |         self.refresh()
122 | 
123 |     def refresh(self):
124 |         self.ax_down_num.set_xlabel('num = {} samples, {:.1f} points/m'.format(self.num, self.num/(self.far-self.near)))
125 |         self.ax_down_s.set_xlabel('s = {:.1f}. fixed_s=64.(for sampling), learned_s>=1000.(for rendering)'.format(self.s))
126 | 
127 |         fake_dvals = np.linspace(self.near, self.far, self.num)
128 |         fake_dvals_mid = (fake_dvals[..., 1:] + fake_dvals[..., :-1]) * 0.5
129 |         fake_sdf = fake_1d_sdf(fake_dvals)
130 | 
131 |         pdf_coarse, cdf_coarse, alpha_coarse, w_coarse = naive_sdf2w(fake_dvals, fake_sdf, s=self.s)
132 |         d_pred = np.sum((w_coarse[fake_dvals < ...  # [...] the remainder of line 132 and the rest of plot_neus_bias.py were likewise lost to tag-stripping in this dump
-------------------------------------------------------------------------------- /debug_tools/test_volsdf_algo.py: --------------------------------------------------------------------------------
# [...] the head of this file (imports and the active sdf_to_sigma / error_bound / helper definitions) is also missing from this dump; the recovered content resumes mid-file:
... > 0, torch.zeros_like(dense_sdf), torch.ones_like(dense_sdf))
163 | 
164 | # init
165 | beta = np.sqrt((M**2) / (4 * (init_num-1) * np.log(1+eps)))
166 | # beta = alpha_net * (M**2) / (4 * (init_num-1) * np.log(1+eps))
167 | 
168 | # algorithm
169 | alpha = 1./beta
170 | # alpha = alpha_net
171 | 
172 | # ------------- calculating
173 | sdf = sdf1d(x)
174 | sigma = sdf_to_sigma(sdf, alpha, beta)
175 | bounds = error_bound(x, sdf, alpha, beta)
176 | bounds_net = error_bound(x, sdf, alpha_net, beta_net)
177 | print("init beta+ = {:.3f}".format(beta))
178 | is_end_with_matching = False
179 | it_algo = 0
180 | while it_algo < max_iter and (net_bound_max := bounds_net.max()) > eps:
181 |     print("it =", it_algo)
182 |     print("net_bound_max = {:.6f}".format(net_bound_max.item()))
183 | 
184 |     it_algo += 1
185 |     #------------- update: upsample
186 |     upsampled_x = rend_util.sample_pdf(x, bounds, init_num, det=True)
187 |     plot(x, sdf, sigma, bounds, alpha, beta, upsampled_x=upsampled_x)
188 |     x = torch.cat([x, upsampled_x], dim=-1)
189 |     # x, _ = torch.sort(x, dim=-1)
190 |     # sdf = sdf1d(x)
191 |     x, sort_indices = torch.sort(x, dim=-1)
192 |     sdf = torch.cat([sdf, sdf1d(upsampled_x)], dim=-1)
193 |     sdf = torch.gather(sdf, dim=-1, index=sort_indices)
194 |     print("more samples:", x.shape[-1])
195 | 
196 |     bounds_net = error_bound(x, sdf, alpha_net, beta_net)
197 |     if bounds_net.max() > eps:
198 |         #-------------- find beta using bisection methods
199 |         # left: > eps
200 |         # right: < eps
201 |         beta_left = beta_net
202 |         beta_right = beta
203 |         for _ in range(10):
204 |             beta_tmp = 0.5 * (beta_left + beta_right)
205 |             alpha_tmp = 1./beta_tmp
206 |             # alpha_tmp = alpha_net
207 |             bounds_tmp = error_bound(x, sdf, alpha_tmp, beta_tmp)
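            # (added note) bisection invariant, per the 'left/right' comments above:
            # the max error bound exceeds eps at beta_left and is <= eps at beta_right,
            # so halving the interval homes in on the smallest beta that satisfies the
            # eps error bound.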
208 |             bounds_max_tmp = bounds_tmp.max()
209 |             if bounds_max_tmp < eps:
210 |                 beta_right = beta_tmp
211 |             elif bounds_max_tmp == eps:
212 |                 beta_right = beta_tmp
213 |                 break
214 |             else:
215 |                 beta_left = beta_tmp
216 |         beta = beta_right
217 |         alpha = 1./beta
218 |         # alpha = alpha_net
219 |         sigma = sdf_to_sigma(sdf, alpha, beta)
220 |         bounds = error_bound(x, sdf, alpha, beta)
221 |     else:
222 |         is_end_with_matching = True
223 |         break
224 |     print("new beta+ = {:.3f}".format(beta))
225 | if (not is_end_with_matching) and (it_algo != 0):
226 |     beta_net = beta_right
227 |     alpha_net = 1./beta_net
228 | print("it=", it_algo)
229 | print("final beta:", beta_net)
230 | sigma = sdf_to_sigma(sdf, alpha_net, beta_net)
231 | bounds = error_bound(x, sdf, alpha_net, beta_net)
232 | print("final error bound max:", bounds.max())
233 | plot(x, sdf, sigma, bounds, alpha_net, beta_net)
234 | 
235 | ## ---------------------- backup
236 | # def sdf_to_sigma(sdf: torch.Tensor, alpha, beta):
237 | #     # sdf *= -1 # NOTE: this will cause inplace opt!
238 | #     sdf = -sdf
239 | #     expsbeta = torch.exp(sdf / beta)
240 | #     psi = torch.where(sdf <= 0, 0.5 * expsbeta, 1 - 0.5 / expsbeta)
241 | #     return alpha * psi
242 | 
243 | 
244 | # def error_bound(d_vals, sdf, alpha, beta):
245 | #     """
246 | #     d_vals: [(B), N_rays, N_pts]
247 | #     sdf:    [(B), N_rays, N_pts]
248 | #     """
249 | #     device = sdf.device
250 | #     sigma = sdf_to_sigma(sdf, alpha, beta)
251 | #     # [(B), N_rays, N_pts]
252 | #     sdf_abs_i = torch.abs(sdf)
253 | #     # [(B), N_rays, N_pts-1]
254 | #     delta_i = d_vals[..., 1:] - d_vals[..., :-1]
255 | #     # [(B), N_rays, N_pts]
256 | #     R_t = torch.cat(
257 | #         [
258 | #             torch.zeros([*sdf.shape[:-1], 1], device=device),
259 | #             torch.cumsum(sigma[..., :-1] * delta_i, dim=-1)
260 | #         ], dim=-1)
261 | #     # [(B), N_rays, N_pts-1]
262 | #     d_i_star = torch.clamp_min(0.5 * (sdf_abs_i[..., :-1] + sdf_abs_i[..., 1:] - delta_i), 0.)
263 | #     # [(B), N_rays, N_pts-1]
264 | #     errors = alpha/(4*beta) * (delta_i**2) * torch.exp(-d_i_star / beta)
265 | #     # [(B), N_rays, N_pts-1]
266 | #     errors_t = torch.cumsum(errors, dim=-1)
267 | #     # [(B), N_rays, N_pts-1]
268 | #     bounds = torch.exp(-R_t[..., :-1]) * (torch.exp(errors_t) - 1.)
269 | #     # TODO: better solution
270 | #     # NOTE: nan comes from 0 * inf
271 | #     # NOTE: every situation where nan appears will also produce c * inf = "true" inf, so the solution below is acceptable
272 | #     bounds[torch.isnan(bounds)] = np.inf
273 | #     return bounds
274 | 
-------------------------------------------------------------------------------- /docs/neus.md: --------------------------------------------------------------------------------
1 | # Notes on the unbiased property of NeuS
2 | 
3 | In NeuS's solution, the point of maximum color contribution (the peak of the visibility weights) has no first-order bias with respect to the exact surface.
4 | 
5 | - See this [[video]](https://longtimenohack.com/hosted/nerf-surface/neus_unbiased.mp4) (1.3MiB).
6 | - To try it yourself, run:
7 | 
8 | ```shell
9 | python -m debug_tools.plot_neus_bias
10 | ```
11 | 
12 | 
-------------------------------------------------------------------------------- /docs/trained_models_results.md: --------------------------------------------------------------------------------
1 | ## Results
2 | 
3 | ### DTU dataset
4 | 
[The original file renders an HTML table whose images were lost in this dump. Recoverable structure: one row per scan, with columns SCAN_ID | reference image | rendered views & learned shape (above: novel view synthesis & below: 3D reconstruction), the latter shown for UNISURF, NeuS (NeRF++ background), and VolSDF. Rows: scans 24, 37, 40, 55, 63, 65, 69, 83, 97, 105, 106, 110, 114, 118, 122.]

### BlendedMVS dataset

[Same table layout, images lost: ID | reference image | learned shape & rendered views (3D reconstruction & novel view synthesis), shown for UNISURF, NeuS, and VolSDF. Single row: scene 5c0d13b795da9479e12e2ee9.]
-------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | - [Environment](#environment) 2 | - [hardware](#hardware) 3 | - [software](#software) 4 | - [(optional)](#optional) 5 | - [Dataset preparation](#dataset-preparation) 6 | - [DTU](#dtu) 7 | - [BlendedMVS](#blendedmvs) 8 | - [Training](#training) 9 | - [new training](#new-training) 10 | - [resume training](#resume-training) 11 | - [monitoring & logging](#monitoring--logging) 12 | - [configs](#configs) 13 | - [training on multi-gpu or clusters](#training-on-multi-gpu-or-clusters) 14 | - [Evaluation: mesh extraction](#evaluation-mesh-extraction) 15 | - [Evaluation: free viewport rendering](#evaluation-free-viewport-rendering) 16 | - [Before rendering, debug camera trajectory by visualization](#before-rendering-debug-camera-trajectory-by-visualization) 17 | - [Only render RGB & depth & normal images](#only-render-rgb--depth--normal-images) 18 | - [Only render mesh](#only-render-mesh) 19 | - [Render RGB, depth image, normal image, and mesh](#render-rgb-depth-image-normal-image-and-mesh) 20 | - [:pushpin: Use surface rendering, instead of volume rendering](#pushpin-use-surface-rendering-instead-of-volume-rendering) 21 | - [[WIP] to run on your own datasets](#wip-to-run-on-your-own-datasets) 22 | - [prerequisites](#prerequisites) 23 | 24 | ## Environment 25 | 26 | ### hardware 27 | 28 | - Currently tested with RTX3090 with 24GiB GPU memotry, and tested with clusters. 29 | - :pushpin: setting larger `data:val_downscale` and smaller `data:val_rayschunk` in the configs will reduce GPU memory usage. 30 | - model size is quite small: `~10MiB`, since the model is just several MLPs. The rendering consumes a lot of GPU. 31 | 32 | | | GPU Memory required
32 | | | GPU Memory required
@ val_downscale=8 & val_rayschunk=256 |
33 | | -------------- | ------------------------------------------------------------ |
34 | | UNISURF | >= 6 GiB |
35 | | NeuS @w/o mask | >= 9 GiB |
36 | | VolSDF | >= 11 GiB |
37 | 
38 | ### software
39 | 
40 | - python>=3.6 (tested on python=3.8.5 using conda)
41 | - pytorch>=1.6.0
42 | - `pip install tqdm scikit-image opencv-python pandas tensorboard addict imageio imageio-ffmpeg pyquaternion scikit-learn pyyaml seaborn PyMCubes trimesh plyfile`
43 | 
44 | #### (optional)
45 | - visualization of meshes: `pip install open3d`
46 | 
47 | 
48 | 
49 | ## Dataset preparation
50 | 
51 | ### DTU
52 | 
53 | Follow the download_data.sh script from the IDR repo: [[click here]](https://github.com/lioryariv/idr/blob/main/data/download_data.sh).
54 | 
55 | > NOTE: For NeuS experiments, you can also use their version of the DTU data, see [[here]](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing).
56 | >
57 | > The camera normalization/scaling is similar, except that they seem to additionally adjust the rotations so that objects are in the same canonical frame.
58 | >
59 | > Then, add `data: cam_file: cameras_sphere.npz` to the configs (this is already the default for the neus_xxx.yaml configs in this repo).
60 | 
61 | ### BlendedMVS
62 | 
63 | From the BlendedMVS repo: [[click here]](https://github.com/YoYo000/BlendedMVS), download the low-res set `BlendedMVS.zip` (27.5GB) of the BlendedMVS dataset.
64 | 
65 | Download `BlendedMVS_norm_cam.tar.gz` from [[GoogleDrive]](https://drive.google.com/drive/folders/1B7y-nMFO9noVI0byU34yPTRtqqzMdMIQ?usp=sharing) or [[Baidu, code: `reco`]](https://pan.baidu.com/s/10g1IWwrGrpE--VJ5XLuRFw), and extract it into the same folder as the extracted `BlendedMVS.zip`.
66 | 
67 | The final file structure would be:
68 | 
69 | ```python
70 | BlendedMVS
71 | ├── 5c0d13b795da9479e12e2ee9
72 | │   ├── blended_images        # original. not changed.
73 | │   ├── cams                  # original. not changed.
74 | │   ├── rendered_depth_maps   # original. not changed.
75 | │   ├── cams_normalized       # newly appended normalized cameras.
76 | ├── ...
77 | ```
78 | 
79 | > NOTE: normalization failed for some of the BlendedMVS instances. Refer to [[dataio/blendedmvs_normalized.txt]](dataio/blendedmvs_normalized.txt) for the list of successfully normalized instances.
80 | 
81 | > :warning: WARNING: the normalization method is currently not fully tested on all BlendedMVS instances, and the normalized camera file may be updated.
82 | 
83 | ## Training
84 | 
85 | ### new training
86 | 
87 | ```shell
88 | python -m train --config configs/volsdf.yaml
89 | ```
90 | 
91 | or you can use any of the configs in the [configs](../configs) folder;
92 | 
93 | or you can create new configs on your own (a sketch of the config structure follows below).
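A new config only needs to follow the structure that `train.py` and the existing configs read. The following is a minimal illustrative sketch, with field names taken from this documentation and from `train.py`; the values and the exact set of required fields are assumptions for illustration, so in practice it is safer to copy an existing file from `configs/` and modify it:

```yaml
expname: my_experiment     # logs are written to logs/<expname>/
data:
  type: DTU                # which dataio loader to use (assumed field name)
  batch_size: 1
  val_downscale: 8         # larger -> less GPU memory during validation
  val_rayschunk: 256       # smaller -> less GPU memory during validation
  # cam_file: cameras_sphere.npz   # e.g. for NeuS-style DTU data
training:
  num_iters: 200000
  i_val: 2000              # run image validation every i_val iterations
  monitoring: tensorboard
```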
94 | 
95 | For training on multi-GPU or clusters, see the section [[training on multi-gpu or clusters]](#training-on-multi-gpu-or-clusters).
96 | 
97 | ### resume training
98 | 
99 | ```shell
100 | # replace xxx with specific expname
101 | python -m train --resume_dir ./logs/xxx/
102 | ```
103 | 
104 | **Or**, simply re-use the original config file:
105 | 
106 | ```shell
107 | # replace xxx with specific filename
108 | python -m train --config configs/xxx.yaml
109 | ```
110 | 
111 | ### monitoring & logging
112 | 
113 | ```shell
114 | # replace xxx with specific expname
115 | tensorboard --logdir logs/xxx/events/
116 | ```
117 | 
118 | The whole logging directory is structured as follows:
119 | 
120 | ```python
121 | logs
122 | ├── exp1
123 | │   ├── backup       # backup codes
124 | │   ├── ckpts        # saved checkpoints
125 | │   ├── config.yaml  # the training config
126 | │   ├── events       # tensorboard events
127 | │   ├── imgs         # validation image output  # NOTE: validation images are 8x down-sampled by default.
128 | │   └── stats.p      # saved scalar stats (lr, losses, value max/mins, etc.)
129 | ├── exp2
130 | └── ...
131 | ```
132 | 
133 | ### configs
134 | 
135 | You can run different experiments by using different config files. All of the config files for the implemented papers (NeuS, VolSDF and UNISURF) can be found in the [[configs]](../configs) folder.
136 | 
137 | ```python
138 | configs
139 | ├── neus.yaml                  # NeuS, training with mask
140 | ├── neus_nomask.yaml           # NeuS, training without mask, using NeRF++ as background
141 | ├── neus_nomask_blended.yaml   # NeuS, training without mask, using NeRF++ as background, for training with the BlendedMVS dataset.
142 | ├── unisurf.yaml               # UNISURF
143 | ├── volsdf.yaml                # VolSDF
144 | ├── volsdf_nerfpp.yaml         # VolSDF, with NeRF++ as background
145 | ├── volsdf_nerfpp_blended.yaml # VolSDF, with NeRF++ as background, for training with the BlendedMVS dataset.
146 | └── volsdf_siren.yaml          # VolSDF, with SIREN replacing the ReLU activations.
147 | ```
148 | 
149 | ### training on multi-gpu or clusters
150 | 
151 | This repo has fully tested support for the following training conditions:
152 | 
153 | - single GPU
154 | - `nn.DataParallel`
155 | - `DistributedDataParallel` with `torch.distributed.launch`
156 | - `DistributedDataParallel` with `SLURM`.
157 | 
158 | #### single process with single GPU
159 | 
160 | ```shell
161 | python -m train --config xxx
162 | ```
163 | 
164 | Or force it to run on one GPU if you have multiple GPUs:
165 | 
166 | ```shell
167 | python -m train --config xxx --device_ids 0
168 | ```
169 | 
170 | #### single process with multiple GPUs (`nn.DataParallel`)
171 | 
172 | Automatically supported, since the default `device_ids` is `-1`, which means using all available GPUs.
173 | 
174 | Or, you can specify the GPUs to use manually, for example:
175 | 
176 | ```shell
177 | python -m train --config xxx --device_ids 1,0,5,3
178 | ```
179 | 
180 | #### multi-process with multiple local GPUs (`DistributedDataParallel` with `torch.distributed.launch`)
181 | 
182 | Add `--ddp` when calling `train.py`.
183 | 
184 | ```shell
185 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 train.py --ddp --config xxx
186 | ```
187 | 
188 | #### multi-process with clusters (`DistributedDataParallel` with `SLURM`)
189 | 
190 | Add `--ddp` when calling `train.py`.
191 | 192 | ```shell 193 | srun --partition your_partition --mpi pmi2 --gres=gpu:4 --cpus-per-task=7 --ntasks-per-node=4 -n4 \ 194 | --kill-on-bad-exit=1 python -m train --ddp --config xxx --expname cluster_test --training:monitoring none --training:i_val 2000 195 | ``` 196 | 197 | 198 | 199 | ## Evaluation: mesh extraction 200 | 201 | ```shell 202 | python -m tools.extract_surface --load_pt /path/to/xxx.pt --N 512 --volume_size 2.0 --out /path/to/surface.ply 203 | ``` 204 | 205 | 206 | 207 | ## Evaluation: free viewport rendering 208 | 209 | ### Before rendering, debug camera trajectory by visualization 210 | 211 | | camera trajectory type | example | explanation and command line options | 212 | | ---------------------- | ---------------------------------------------------------- | ------------------------------------------------------------ | 213 | | `small_circle` | ![cam_small_circle](../media/cam_small_circle.gif) | select 3 view ids, in CCW order (from above);
when rendering, the camera path is interpolated along the **small circle** that passes through the 3 selected camera center locations
`--camera_path small_circle --camera_inds 11,14,17` | 214 | | `great_circle` | ![cam_great_circle](../media/cam_great_circle.gif) | select 2 view ids, in CCW order (from above);
when rendering, the camera path is interpolated along the **great circle** that passes through the 2 selected camera center locations
`--camera_path great_circle --camera_inds 11,17` | 215 | | `spherical_spiral` | ![cam_spherical_spiral](../media/cam_spherical_spiral.gif) | select 3 view ids, in CCW order (from above);
when rendering, the camera path is interpolated along the **spherical spiral path** that starts from the **small circle** passing through the 3 selected camera centers
`--camera_path spherical_spiral --camera_inds 11,14,17` |
216 | 
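A small or great circle through the selected camera centers is fully determined by those points: three non-collinear points define a unique circle in 3D, and the renderer samples camera positions along it. Below is a minimal numpy sketch of such a parameterization (illustrative only; the function and variable names are assumptions, not the actual implementation in `tools/render_view.py`):

```python
import numpy as np

def circle_through_3_points(p1, p2, p3, num_views=60):
    # the circle through 3 non-collinear points lies in their common plane;
    # its center is the point in that plane equidistant from all three.
    v1, v2 = p2 - p1, p3 - p1
    n = np.cross(v1, v2)
    n = n / np.linalg.norm(n)          # plane normal
    # write center = p1 + a*v1 + b*v2 and solve the two perpendicular-bisector equations
    A = np.array([[v1 @ v1, v1 @ v2],
                  [v1 @ v2, v2 @ v2]])
    rhs = 0.5 * np.array([v1 @ v1, v2 @ v2])
    a, b = np.linalg.solve(A, rhs)
    center = p1 + a * v1 + b * v2
    r = np.linalg.norm(p1 - center)
    u = (p1 - center) / r              # in-plane basis vector, starting at p1
    w = np.cross(n, u)
    thetas = np.linspace(0., 2. * np.pi, num_views)
    # [num_views, 3] camera centers on the circle
    return center + r * (np.cos(thetas)[:, None] * u + np.sin(thetas)[:, None] * w)
```

Camera rotations would then be chosen to keep the normalized scene origin in view; `--camera_inds` simply picks which training-camera centers play the role of `p1, p2, p3`.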
217 | Add the `--debug` option.
218 | 
219 | ```shell
220 | python -m tools.render_view --config trained_models/volsdf_114/config.yaml --load_pt trained_models/volsdf_114/final_00100000.pt --camera_path spherical_spiral --camera_inds 48,43,38 --debug --num_views 20
221 | ```
222 | 
223 | You can replace the `camera_path` and `camera_inds` with any of the camera path configurations you like.
224 | 
225 | > NOTE: remember to remove the `--debug` option after debugging.
226 | 
227 | ### Only render RGB & depth & normal images
228 | 
229 | For GPUs with less memory, use a smaller `rayschunk` and a larger `downscale`.
230 | 
231 | ```shell
232 | python -m tools.render_view --num_views 60 --downscale 4 --config trained_models/volsdf_24/config.yaml \
233 | --load_pt trained_models/volsdf_24/final_00100000.pt --camera_path small_circle --rayschunk 1024 \
234 | --camera_inds 27,24,21 --H_scale 1.2
235 | ```
236 | 
237 | ### Only render mesh
238 | 
239 | Add the `--disable_rgb` option.
240 | 
241 | ### Render RGB, depth image, normal image, and mesh
242 | 
243 | Add `--render_mesh /path/to/xxx.ply`. Example:
244 | 
245 | ```shell
246 | python -m tools.render_view --num_views 60 --downscale 4 --config trained_models/volsdf_24/config.yaml \
247 | --load_pt trained_models/volsdf_24/final_00100000.pt --camera_path small_circle --rayschunk 1024 \
248 | --render_mesh trained_models/volsdf_24/surface.ply --camera_inds 27,24,21 --H_scale 1.2
249 | ```
250 | 
251 | ### :pushpin: Use surface rendering, instead of volume rendering
252 | 
253 | Since the underlying shape representation is an implicit surface, one can also use surface rendering techniques to render the image. For each ray, only the point where the ray intersects the surface contributes to the pixel color, instead of accumulating contributions from neighboring points along the ray as in volume rendering.
254 | 
255 | This makes rendering about **100x** faster when using `sphere_tracing`.
256 | 
257 | - Specifically, for NeuS/VolSDF, which use an SDF representation, you can render with `sphere_tracing` or `root_finding` along the ray.
258 | 
259 | Just add `--use_surface_render sphere_tracing`.
260 | 
261 | ```shell
262 | python -m tools.render_view --device_ids 0 --num_views 60 --downscale 4 \
263 | --config trained_models/neus_65_nomask/config.yaml --load_pt trained_models/neus_65_nomask/final_00300000.pt \
264 | --camera_path small_circle --camera_inds 11,13,15 --H_scale 1.2 --outbase neus_st_65 \
265 | --use_surface_render sphere_tracing --rayschunk 1000000
266 | ```
267 | 
268 | > NOTE: in this case, `rayschunk` can be very large (1000000, for example), since only ONE point per ray is queried.
269 | 
270 | Example: NeuS @ DTU-65 @ [360 x 400] resolution @ 60 rendered frames.
271 | 
272 | | | Original volume rendering
& Integrated normals of volume rendering | Surface rendering using `sphere_tracing`
& Normals from the ray-traced surface points |
273 | | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
274 | | rendering time | 28 minutes | 18 seconds |
275 | | rendered result | ![neus_65_nomask_new_rgb&normal_360x400_60_small_circle_None](../media/neus_65_nomask_new_rgb&normal_360x400_60_small_circle_None.gif) | ![neus_65_nomask_new_rgb&normal_360x400_60_small_circle_sphere_tracing_None](../media/neus_65_nomask_new_rgb&normal_360x400_60_small_circle_sphere_tracing_None.gif) |
276 | 
277 | > NOTE: the NeRF++ background is removed when sphere tracing.
278 | 
279 | 
280 | 
281 | - For UNISURF, which uses an occupancy representation, you can render with `--use_surface_render root_finding`.
282 | 
283 | 
284 | 
285 | ## [WIP] to run on your own datasets
286 | 
287 | ### prerequisites
288 | 
289 | - [COLMAP](https://github.com/colmap/colmap) for extracting camera extrinsics
290 | - To run on your own masks:
291 | - annotation tool: [CVAT](https://github.com/openvinotoolkit/cvat) or their online annotation site: [cvat.org](https://cvat.org/)
292 | - load coco mask: `pip install pycocotools`
293 | 
-------------------------------------------------------------------------------- /docs/volsdf.md: --------------------------------------------------------------------------------
1 | 
2 | 
3 | # Notes on the up-sampling algorithm and error bound of VolSDF
4 | 
5 | In VolSDF, the authors prove an error bound on the discrete Riemann-sum approximation used in the opacity computation, and derive a ray-point up-sampling algorithm that keeps this error bound below a manually set `epsilon`, which is set to `0.1`.
6 | 
7 | ## 1. up-sampling algorithm's visualization in tensorboard when training
8 | 
9 | | | @0k | @4k | @10k | @200k |
10 | | ------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
11 | | up sample iterations until converged | ![volsdf_up_iter_00000000](../media/volsdf_up_iter_00000000.png) | ![volsdf_up_iter_00004000](../media/volsdf_up_iter_00004000.png) | ![volsdf_up_iter_00010000](../media/volsdf_up_iter_00010000.png) | ![volsdf_up_iter_00200000](../media/volsdf_up_iter_00200000.png) |
12 | | beta heat map | ![volsdf_beta_00000000](../media/volsdf_beta_00000000.png) | ![volsdf_beta_00004000](../media/volsdf_beta_00004000.png) | ![volsdf_beta_00010000](../media/volsdf_beta_00010000.png) | ![volsdf_beta_00200000](../media/volsdf_beta_00200000.png) |
13 | 
14 | 
15 | 
16 | ## 2. up-sampling algorithm single-ray testing
17 | 
18 | - network's beta = 0.001
19 | - eps = 0.1
20 | - To try it yourself, run:
21 | 
22 | ```shell
23 | python -m debug_tools.test_volsdf_algo
24 | ```
25 | 
26 | 
27 | 
28 | ### 0-th iteration
29 | 
30 | - 128 uniform sample points.
31 | - If using the network's `beta`, then `error_bound.max` = inf, which does not satisfy `<eps` (the bound itself is written out below).
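For reference, the bound that the algorithm drives below `eps` is the one computed by `error_bound` (kept in the backup section of `debug_tools/test_volsdf_algo.py`). Written out as a sketch in that implementation's notation (this transcription is mine; symbols follow the code, not necessarily the paper): with sample depths $t_i$, intervals $\delta_i = t_{i+1} - t_i$, per-sample densities $\sigma_i$ from `sdf_to_sigma(sdf, alpha, beta)`, and $d_i^\star = \max\!\big(0, \tfrac{1}{2}(|d_i| + |d_{i+1}| - \delta_i)\big)$ built from the SDF values $d_i$:

$$
\hat{B}_k = e^{-\hat{R}_k}\left(e^{\hat{E}_k} - 1\right), \qquad
\hat{R}_k = \sum_{i < k} \sigma_i \delta_i, \qquad
\hat{E}_k = \sum_{i \le k} \frac{\alpha}{4\beta}\,\delta_i^2\, e^{-d_i^\star/\beta}.
$$

The up-sampling loop keeps adding samples (and bisects `beta` towards the network's value) until $\max_k \hat{B}_k < \text{eps}$.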
-------------------------------------------------------------------------------- /models/ray_casting.py: --------------------------------------------------------------------------------
1 | from models.base import ImplicitSurface
2 | 
3 | import numpy as np
4 | from tqdm import tqdm
5 | from typing import Union
6 | from collections import OrderedDict
7 | 
8 | import torch
9 | import torch.nn.functional as F
10 | 
11 | 
12 | def run_secant_method(f_low, f_high, d_low, d_high, rays_o_masked, rays_d_masked, surface_query_fn, n_secant_steps, logit_tau):
13 | # f_high > 0 is the last query outside the surface, f_low < 0 the first one inside;
14 | # d_high / d_low are the corresponding depths along the ray.
15 | d_pred = - f_low * (d_high - d_low) / (f_high - f_low) + d_low
16 | for i in range(n_secant_steps):
17 | # evaluate the implicit field at the current secant estimate (logits centered at zero)
18 | p_mid = rays_o_masked + d_pred.unsqueeze(-1) * rays_d_masked
19 | f_mid = surface_query_fn(p_mid) - logit_tau
20 | # estimates still inside the surface (f_mid < 0) tighten the inside end of the bracket
21 | ind_low = f_mid < 0
22 | if ind_low.sum() > 0:
23 | d_low[ind_low] = d_pred[ind_low]
24 | f_low[ind_low] = f_mid[ind_low]
25 | if (ind_low == 0).sum() > 0:
26 | d_high[ind_low == 0] = d_pred[ind_low == 0]
27 | f_high[ind_low == 0] = f_mid[ind_low == 0]
28 | 
29 | d_pred = - f_low * (d_high - d_low) / (f_high - f_low) + d_low
30 | return d_pred
31 | 
32 | def run_bisection_method():
33 | pass
34 | 
35 | def root_finding_surface_points(
36 | surface_query_fn,
37 | rays_o: torch.Tensor, rays_d: torch.Tensor,
38 | near: Union[float, torch.Tensor]=0.0,
39 | far: Union[float, torch.Tensor]=6.0,
40 | # function config
41 | batched = True,
42 | batched_info = {},
43 | # algorithm config
44 | N_steps = 256,
45 | logit_tau=0.0,
46 | method='secant',
47 | N_secant_steps = 8,
48 | fill_inf=True,
49 | ):
50 | """
51 | rays_o: [(B), N_rays, 3]
52 | rays_d: [(B), N_rays, 3]
53 | near: float or [(B), N_rays]
54 | far: float or [(B), N_rays]
55 | """
56 | # NOTE: jianfei: modified from DVR. https://github.com/autonomousvision/differentiable_volumetric_rendering
57 | # NOTE: DVR's logits: (+)inside (-)outside; logits here: (+)outside (-)inside.
58 | # NOTE: rays_d needs to be already normalized
59 | with torch.no_grad():
60 | device = rays_o.device
61 | if not batched:
62 | rays_o.unsqueeze_(0)
63 | rays_d.unsqueeze_(0)
64 | 
65 | B = rays_o.shape[0]
66 | N_rays = rays_o.shape[-2]
67 | 
68 | # [B, N_rays, N_steps, 1]
69 | t = torch.linspace(0., 1., N_steps, device=device)[None, None, :]
70 | if not isinstance(near, torch.Tensor):
71 | near = near * torch.ones(rays_o.shape[:-1], device=device)
72 | if not isinstance(far, torch.Tensor):
73 | far = far * torch.ones(rays_o.shape[:-1], device=device)
74 | d_proposal = near[..., None] * (1-t) + far[..., None] * t
75 | 
76 | # [B, N_rays, N_steps, 3]
77 | p_proposal = rays_o.unsqueeze(-2) + d_proposal.unsqueeze(-1) * rays_d.unsqueeze(-2)
78 | 
79 | # only query sigma
80 | pts = p_proposal
81 | 
82 | # query network
83 | # [B, N_rays, N_steps]
84 | val = surface_query_fn(pts)
85 | # [B, N_rays, N_steps]
86 | val = val - logit_tau # centered at zero
87 | 
88 | # mask: the first point is not occupied
89 | # [B, N_rays]
90 | mask_0_not_occupied = val[..., 0] > 0
91 | 
92 | # [B, N_rays, N_steps]
93 | sign_matrix = torch.cat(
94 | [
95 | torch.sign(val[..., :-1] * val[..., 1:]), # [B, N, N_steps-1]
96 | torch.ones([B, N_rays, 1], device=device) # [B, N, 1]
97 | ], dim=-1)
98 | 
99 | # [B, N_rays, N_steps]
100 | cost_matrix = sign_matrix * torch.arange(N_steps, 0, -1).float().to(device)
101 | 
102 | values, indices = torch.min(cost_matrix, -1)
103 | 
104 | # mask: at least one sign change occurred
105 | mask_sign_change = values < 0
106 | 
107 | # mask: whether the first sign change is from pos to neg (outside surface into the surface)
108 | mask_pos_to_neg = val[torch.arange(B).unsqueeze(-1), torch.arange(N_rays).unsqueeze(0), indices] > 0
109 | 
110 | mask = mask_sign_change & mask_pos_to_neg & mask_0_not_occupied
111 | 
112 | #--------- secant method
113 | # [B*N_rays, N_steps, 1]
114 | d_proposal_flat = d_proposal.view([B*N_rays, N_steps, 1])
115 | val_flat = val.view([B*N_rays, N_steps, 1])
116 | N_secant = d_proposal_flat.shape[0]
117 | 
118 | # [N_masked]
119 | d_high = d_proposal_flat[torch.arange(N_secant), indices.view(N_secant)].view([B, N_rays])[mask]
120 | f_high = val_flat[torch.arange(N_secant), indices.view(N_secant)].view([B, N_rays])[mask]
121 | 
122 | indices = torch.clamp(indices + 1, max=N_steps - 1)
123 | d_low = d_proposal_flat[torch.arange(N_secant), indices.view(N_secant)].view([B, N_rays])[mask]
124 | f_low = val_flat[torch.arange(N_secant), indices.view(N_secant)].view([B, N_rays])[mask]
125 | 
126 | # [N_masked, 3]
127 | rays_o_masked = rays_o[mask]
128 | rays_d_masked = rays_d[mask]
129 | 
130 | # TODO: for categorical representation, mask latents here
131 | 
132 | if method == 'secant' and mask.sum() > 0:
133 | d_pred = run_secant_method(
134 | f_low, f_high, d_low, d_high,
135 | rays_o_masked, rays_d_masked,
136 | surface_query_fn,
137 | N_secant_steps, logit_tau)
138 | else:
139 | d_pred = torch.ones(rays_o_masked.shape[0]).to(device)
140 | 
141 | # for sanity
142 | pt_pred = torch.ones([B, N_rays, 3]).to(device)
143 | pt_pred[mask] = rays_o_masked + d_pred.unsqueeze(-1) * rays_d_masked
144 | 
145 | d_pred_out = torch.ones([B, N_rays]).to(device)
146 | d_pred_out[mask] = d_pred
147 | 
148 | # Insert appropriate values for points where no depth is predicted
149 | if isinstance(far, torch.Tensor):
150 | far = far[mask == 0]
151 | d_pred_out[mask == 0] = np.inf if fill_inf else far # no intersection; or the first intersection is from inside to outside; or the 0-th point is occupied.
152 | d_pred_out[mask_0_not_occupied == 0] = 0 # if the 0-th point is occupied, the depth should be 0.
153 | 
154 | if not batched:
155 | d_pred_out.squeeze_(0)
156 | pt_pred.squeeze_(0)
157 | mask.squeeze_(0)
158 | mask_sign_change.squeeze_(0)
159 | 
160 | return d_pred_out, pt_pred, mask, mask_sign_change
161 | 
162 | 
163 | def sphere_tracing_surface_points(
164 | implicit_surface: ImplicitSurface,
165 | rays_o, rays_d,
166 | # function config
167 | near=0.0,
168 | far=6.0,
169 | batched = True,
170 | batched_info = {},
171 | # algorithm config
172 | N_iters = 20,
173 | ):
174 | device = rays_o.device
175 | d_preds = torch.ones([*rays_o.shape[:-1]], device=device) * near # start every ray at the near bound
176 | mask = torch.ones_like(d_preds, dtype=torch.bool, device=device)
177 | for _ in range(N_iters):
178 | pts = rays_o + rays_d * d_preds[..., :, None]
179 | surface_val = implicit_surface.forward(pts)
180 | d_preds[mask] += surface_val[mask] # sphere tracing step: advance by the queried SDF value
181 | mask[d_preds > far] = False # deactivate rays that marched past the far bound
182 | mask[d_preds < 0] = False # deactivate rays that stepped behind the origin
183 | pts = rays_o + rays_d * d_preds[..., :, None]
184 | return d_preds, pts, mask
185 | 
186 | 
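# A quick usage sketch for the ray casters above (illustrative only; the shapes
# assume batched inputs, as prepared by surface_render below):
#
#     surface = ImplicitSurface().cuda()                        # an SDF network
#     rays_o = torch.zeros([1, 1024, 3]).cuda()                 # [B, N_rays, 3]
#     rays_d = F.normalize(torch.randn([1, 1024, 3]), dim=-1).cuda()
#     d, pts, hit = sphere_tracing_surface_points(surface, rays_o, rays_d, N_iters=20)
#
# Sphere tracing repeats d <- d + sdf(o + d * v); with a valid SDF the step is
# never larger than the distance to the closest surface, so the march cannot
# overshoot through the surface, and rays leaving the [0, far] range are masked out.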
187 | def surface_render(rays_o: torch.Tensor, rays_d: torch.Tensor,
188 | model,
189 | calc_normal=True,
190 | rayschunk=8192, netchunk=1048576, batched=True, use_view_dirs=True, show_progress=False,
191 | ray_casting_algo='',
192 | ray_casting_cfgs={},
193 | **not_used_kwargs):
194 | """
195 | input:
196 | rays_o: [(B,) N_rays, 3]
197 | rays_d: [(B,) N_rays, 3] NOTE: not normalized; contains info about the ratio of len(this ray)/len(principal ray)
198 | """
199 | with torch.no_grad():
200 | device = rays_o.device
201 | if batched:
202 | DIM_BATCHIFY = 1
203 | B = rays_d.shape[0] # batch_size
204 | flat_vec_shape = [B, -1, 3]
205 | else:
206 | DIM_BATCHIFY = 0
207 | flat_vec_shape = [-1, 3]
208 | rays_o = torch.reshape(rays_o, flat_vec_shape).float()
209 | rays_d = torch.reshape(rays_d, flat_vec_shape).float()
210 | # NOTE: normalized here, so the code below can assume unit-length rays_d
211 | rays_d = F.normalize(rays_d, dim=-1)
212 | 
213 | # ---------------
214 | # Render a ray chunk
215 | # ---------------
216 | def render_rayschunk(rays_o: torch.Tensor, rays_d: torch.Tensor):
217 | if use_view_dirs:
218 | view_dirs = rays_d
219 | else:
220 | view_dirs = None
221 | if ray_casting_algo == 'root_finding':
222 | d_pred_out, pt_pred, mask, *_ = root_finding_surface_points(
223 | model.implicit_surface, rays_o, rays_d, batched=batched, **ray_casting_cfgs)
224 | elif ray_casting_algo == 'sphere_tracing':
225 | d_pred_out, pt_pred, mask = sphere_tracing_surface_points(
226 | model.implicit_surface, rays_o, rays_d, batched=batched, **ray_casting_cfgs)
227 | else:
228 | raise NotImplementedError
229 | 
230 | color, _, nablas = model.forward(pt_pred, view_dirs)
231 | color[~mask] = 0 # black
232 | # NOTE: everything is returned without grad, especially the nablas.
233 | return color.data, d_pred_out.data, nablas.data, mask.data
234 | 
235 | colors = []
236 | depths = []
237 | nablas = []
238 | masks = []
239 | for i in tqdm(range(0, rays_o.shape[DIM_BATCHIFY], rayschunk), disable=not show_progress):
240 | color_i, d_i, nablas_i, mask_i = render_rayschunk(
241 | rays_o[:, i:i+rayschunk] if batched else rays_o[i:i+rayschunk],
242 | rays_d[:, i:i+rayschunk] if batched else rays_d[i:i+rayschunk]
243 | )
244 | colors.append(color_i)
245 | depths.append(d_i)
246 | nablas.append(nablas_i)
247 | masks.append(mask_i)
248 | colors = torch.cat(colors, DIM_BATCHIFY)
249 | depths = torch.cat(depths, DIM_BATCHIFY)
250 | nablas = torch.cat(nablas, DIM_BATCHIFY)
251 | masks = torch.cat(masks, DIM_BATCHIFY)
252 | 
253 | extras = OrderedDict([
254 | ('implicit_nablas', nablas),
255 | ('mask_surface', masks)
256 | ])
257 | 
258 | if calc_normal:
259 | normals = F.normalize(nablas, dim=-1)
260 | normals[~masks] = 0 # grey (/2.+0.5)
261 | extras['normals_surface'] = normals
262 | 
263 | return colors, depths, extras
264 | 
-------------------------------------------------------------------------------- /set_env.sh: --------------------------------------------------------------------------------
1 | # put project directory into PYTHONPATH
2 | 
3 | # source $HOME/.condabashrc
4 | # conda activate nerf
5 | 
6 | DIR="$(pwd)"
7 | export PYTHONPATH="${DIR}":$PYTHONPATH
8 | echo "added $DIR to PYTHONPATH"
9 | 
-------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ventusff/neurecon/972e810ec252cfd16f630b1de6d2802d1b8de59a/tools/__init__.py -------------------------------------------------------------------------------- /tools/extract_surface.py: --------------------------------------------------------------------------------
1 | from models.base import ImplicitSurface
2 | from utils.mesh_util import extract_mesh
3 | 
4 | import torch
5 | 
6 | def main_function(args):
7 | N = args.N
8 | s = args.volume_size
9 | implicit_surface = ImplicitSurface(radius_init=args.init_r).cuda()
10 | if args.load_pt is not None:
11 | # --------- if load statedict
12 | # state_dict = torch.load("/home/PJLAB/guojianfei/latest.pt")
13 | # state_dict = torch.load("./dev_test/37/latest.pt")
14 | state_dict = torch.load(args.load_pt)
15 | imp_surface_state_dict = {k.replace('implicit_surface.',''):v for k, v in state_dict['model'].items() if 'implicit_surface.' in k}
16 | imp_surface_state_dict['obj_bounding_size'] = torch.tensor([1.0]).cuda()
17 | implicit_surface.load_state_dict(imp_surface_state_dict)
18 | if args.out is None:
19 | from datetime import datetime
20 | dt = datetime.now()
21 | args.out = 'surface_' + dt.strftime("%Y%m%d%H%M%S") + '.ply'
22 | extract_mesh(implicit_surface, s, N=N, filepath=args.out, show_progress=True, chunk=args.chunk)
23 | 
24 | 
25 | if __name__ == "__main__":
26 | import argparse
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument("--out", type=str, default=None, help='output ply file name')
29 | parser.add_argument('--N', type=int, default=512, help='resolution of the marching cube algo')
30 | parser.add_argument('--volume_size', type=float, default=2., help='voxel size to run marching cube')
31 | parser.add_argument("--load_pt", type=str, default=None, help='the trained model checkpoint .pt file')
32 | parser.add_argument("--chunk", type=int, default=16*1024, help='net chunk when querying the network. change for smaller GPU memory.')
33 | parser.add_argument("--init_r", type=float, default=1.0, help='Optional. The init radius of the implicit surface.')
34 | args = parser.parse_args()
35 | 
36 | main_function(args)
-------------------------------------------------------------------------------- /tools/vis_camera.py: --------------------------------------------------------------------------------
1 | '''
2 | camera extrinsics visualization tools
3 | modified from https://github.com/opencv/opencv/blob/master/samples/python/camera_calibration_show_extrinsics.py
4 | '''
5 | 
6 | # Python 2/3 compatibility; a __future__ import must come before any other import
7 | from __future__ import print_function
8 | 
9 | from utils.print_fn import log
10 | 
11 | import numpy as np
12 | import cv2 as cv
13 | 
14 | from numpy import linspace
15 | import matplotlib
16 | 
17 | matplotlib.use('TkAgg')
18 | 
19 | def inverse_homogeneoux_matrix(M):
20 | R = M[0:3, 0:3]
21 | T = M[0:3, 3]
22 | M_inv = np.identity(4)
23 | M_inv[0:3, 0:3] = R.T
24 | M_inv[0:3, 3] = -(R.T).dot(T)
25 | 
26 | return M_inv
27 | 
28 | 
29 | def transform_to_matplotlib_frame(cMo, X, inverse=False):
30 | M = np.identity(4)
31 | M[1, 1] = 0
32 | M[1, 2] = 1
33 | M[2, 1] = -1
34 | M[2, 2] = 0
35 | 
36 | if inverse:
37 | return M.dot(inverse_homogeneoux_matrix(cMo).dot(X))
38 | else:
39 | return M.dot(cMo.dot(X))
40 | 
41 | 
42 | def create_camera_model(camera_matrix, width, height, scale_focal, draw_frame_axis=False):
43 | fx = camera_matrix[0, 0]
44 | fy = camera_matrix[1, 1]
45 | focal = 2 / (fx + fy)
46 | f_scale = scale_focal * focal
47 | 
48 | # draw image plane
49 | X_img_plane = np.ones((4, 5))
50 | X_img_plane[0:3, 0] = [-width, height, f_scale]
51 | X_img_plane[0:3, 1] = [width, height, f_scale]
52 | X_img_plane[0:3, 2] = [width, -height, f_scale]
53 | X_img_plane[0:3, 3] = [-width, -height, f_scale]
54 | X_img_plane[0:3, 4] = [-width, height, f_scale]
55 | 
56 | # draw triangle above the image plane
57 | X_triangle = np.ones((4, 3))
58 | X_triangle[0:3, 0] = [-width, -height, f_scale]
59 | X_triangle[0:3, 1] = [0, -2*height, f_scale]
60 | X_triangle[0:3, 2] = [width, -height, f_scale]
61 | 
62 | # draw camera
63 | X_center1 = np.ones((4, 2))
64 | X_center1[0:3, 0] = [0, 0, 0]
65 | 
X_center1[0:3, 1] = [-width, height, f_scale] 65 | 66 | X_center2 = np.ones((4, 2)) 67 | X_center2[0:3, 0] = [0, 0, 0] 68 | X_center2[0:3, 1] = [width, height, f_scale] 69 | 70 | X_center3 = np.ones((4, 2)) 71 | X_center3[0:3, 0] = [0, 0, 0] 72 | X_center3[0:3, 1] = [width, -height, f_scale] 73 | 74 | X_center4 = np.ones((4, 2)) 75 | X_center4[0:3, 0] = [0, 0, 0] 76 | X_center4[0:3, 1] = [-width, -height, f_scale] 77 | 78 | # draw camera frame axis 79 | X_frame1 = np.ones((4, 2)) 80 | X_frame1[0:3, 0] = [0, 0, 0] 81 | X_frame1[0:3, 1] = [f_scale/2, 0, 0] 82 | 83 | X_frame2 = np.ones((4, 2)) 84 | X_frame2[0:3, 0] = [0, 0, 0] 85 | X_frame2[0:3, 1] = [0, f_scale/2, 0] 86 | 87 | X_frame3 = np.ones((4, 2)) 88 | X_frame3[0:3, 0] = [0, 0, 0] 89 | X_frame3[0:3, 1] = [0, 0, f_scale/2] 90 | 91 | if draw_frame_axis: 92 | return [X_img_plane, X_triangle, X_center1, X_center2, X_center3, X_center4, X_frame1, X_frame2, X_frame3] 93 | else: 94 | return [X_img_plane, X_triangle, X_center1, X_center2, X_center3, X_center4] 95 | 96 | 97 | def create_board_model(extrinsics, board_width, board_height, square_size, draw_frame_axis=False): 98 | width = board_width*square_size 99 | height = board_height*square_size 100 | 101 | # draw calibration board 102 | X_board = np.ones((4, 5)) 103 | #X_board_cam = np.ones((extrinsics.shape[0],4,5)) 104 | X_board[0:3, 0] = [0, 0, 0] 105 | X_board[0:3, 1] = [width, 0, 0] 106 | X_board[0:3, 2] = [width, height, 0] 107 | X_board[0:3, 3] = [0, height, 0] 108 | X_board[0:3, 4] = [0, 0, 0] 109 | 110 | # draw board frame axis 111 | X_frame1 = np.ones((4, 2)) 112 | X_frame1[0:3, 0] = [0, 0, 0] 113 | X_frame1[0:3, 1] = [height/2, 0, 0] 114 | 115 | X_frame2 = np.ones((4, 2)) 116 | X_frame2[0:3, 0] = [0, 0, 0] 117 | X_frame2[0:3, 1] = [0, height/2, 0] 118 | 119 | X_frame3 = np.ones((4, 2)) 120 | X_frame3[0:3, 0] = [0, 0, 0] 121 | X_frame3[0:3, 1] = [0, 0, height/2] 122 | 123 | if draw_frame_axis: 124 | return [X_board, X_frame1, X_frame2, X_frame3] 125 | else: 126 | return [X_board] 127 | 128 | 129 | def draw_camera(ax, camera_matrix, cam_width, cam_height, scale_focal, 130 | extrinsics, 131 | patternCentric=True, 132 | annotation=True): 133 | from matplotlib import cm 134 | 135 | min_values = np.zeros((3, 1)) 136 | min_values = np.inf 137 | max_values = np.zeros((3, 1)) 138 | max_values = -np.inf 139 | 140 | X_moving = create_camera_model( 141 | camera_matrix, cam_width, cam_height, scale_focal) 142 | 143 | cm_subsection = linspace(0.0, 1.0, extrinsics.shape[0]) 144 | colors = [cm.jet(x) for x in cm_subsection] 145 | 146 | for idx in range(extrinsics.shape[0]): 147 | # R, _ = cv.Rodrigues(extrinsics[idx,0:3]) 148 | # cMo = np.eye(4,4) 149 | # cMo[0:3,0:3] = R 150 | # cMo[0:3,3] = extrinsics[idx,3:6] 151 | cMo = extrinsics[idx] 152 | for i in range(len(X_moving)): 153 | X = np.zeros(X_moving[i].shape) 154 | for j in range(X_moving[i].shape[1]): 155 | X[0:4, j] = transform_to_matplotlib_frame( 156 | cMo, X_moving[i][0:4, j], patternCentric) 157 | ax.plot3D(X[0, :], X[1, :], X[2, :], color=colors[idx]) 158 | min_values = np.minimum(min_values, X[0:3, :].min(1)) 159 | max_values = np.maximum(max_values, X[0:3, :].max(1)) 160 | # modified: add an annotation of number 161 | if annotation: 162 | X = transform_to_matplotlib_frame( 163 | cMo, X_moving[0][0:4, 0], patternCentric) 164 | ax.text(X[0], X[1], X[2], "{}".format(idx), color=colors[idx]) 165 | 166 | return min_values, max_values 167 | 168 | 169 | def visualize(camera_matrix, extrinsics): 170 | 171 | ######################## plot 
params ########################
172 | cam_width = 0.064/2 # Width/2 of the displayed camera.
173 | cam_height = 0.048/2 # Height/2 of the displayed camera.
174 | scale_focal = 40 # Value to scale the focal length.
175 | 
176 | ######################## original code ########################
177 | import matplotlib.pyplot as plt
178 | from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-variable
179 | 
180 | fig = plt.figure()
181 | ax = fig.gca(projection='3d')
182 | # ax.set_aspect("equal")
183 | ax.set_aspect("auto")
184 | 
185 | min_values, max_values = draw_camera(ax, camera_matrix, cam_width, cam_height,
186 | scale_focal, extrinsics, True)
187 | 
188 | X_min = min_values[0]
189 | X_max = max_values[0]
190 | Y_min = min_values[1]
191 | Y_max = max_values[1]
192 | Z_min = min_values[2]
193 | Z_max = max_values[2]
194 | max_range = np.array([X_max-X_min, Y_max-Y_min, Z_max-Z_min]).max() / 2.0
195 | 
196 | mid_x = (X_max+X_min) * 0.5
197 | mid_y = (Y_max+Y_min) * 0.5
198 | mid_z = (Z_max+Z_min) * 0.5
199 | ax.set_xlim(mid_x - max_range, mid_x + max_range)
200 | ax.set_ylim(mid_y - max_range, mid_y + max_range)
201 | ax.set_zlim(mid_z - max_range, mid_z + max_range)
202 | 
203 | ax.set_xlabel('x')
204 | ax.set_ylabel('z')
205 | ax.set_zlabel('-y')
206 | ax.set_title('Extrinsic Parameters Visualization')
207 | 
208 | plt.show()
209 | log.info('Done')
210 | 
211 | 
212 | if __name__ == '__main__':
213 | import argparse
214 | parser = argparse.ArgumentParser()
215 | parser.add_argument("--scan_id", type=int, default=40)
216 | args = parser.parse_args()
217 | 
218 | log.info(__doc__)
219 | # NOTE: jianfei: 20210722 newly checked. The coordinates are correct.
220 | # note that the ticks on (-y) mean the opposite of the y coordinates.
221 | 
222 | ######################## modified: example code ########################
223 | from dataio.DTU import SceneDataset
224 | import torch
225 | train_dataset = SceneDataset(
226 | train_cameras=False,
227 | data_dir='./data/DTU/scan{}'.format(args.scan_id))
228 | c2w = torch.stack(train_dataset.c2w_all).data.cpu().numpy()
229 | extrinsics = np.linalg.inv(c2w) # camera extrinsics are w2c matrix
230 | camera_matrix = next(iter(train_dataset))[1]['intrinsics'].data.cpu().numpy()
231 | 
232 | 
233 | # import pickle
234 | # data = pickle.load(open('./dev_test/london/london_siren_si20_cam.pt', 'rb'))
235 | # c2ws = data['c2w']
236 | # extrinsics = np.linalg.inv(c2ws)
237 | # camera_matrix = data['intr']
238 | visualize(camera_matrix, extrinsics)
239 | cv.destroyAllWindows()
240 | 
-------------------------------------------------------------------------------- /tools/vis_ray.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | 
4 | from utils.rend_util import get_rays
5 | from dataio.DTU import SceneDataset
6 | 
7 | def plot_rays(rays_o: np.ndarray, rays_d: np.ndarray, ax):
8 | # TODO: automatically reduce the number of rays
9 | XYZUVW = np.concatenate([rays_o, rays_d], axis=-1)
10 | X, Y, Z, U, V, W = np.transpose(XYZUVW)
11 | # X2 = X+U
12 | # Y2 = Y+V
13 | # Z2 = Z+W
14 | # x_max = max(np.max(X), np.max(X2))
15 | # x_min = min(np.min(X), np.min(X2))
16 | # y_max = max(np.max(Y), np.max(Y2))
17 | # y_min = min(np.min(Y), np.min(Y2))
18 | # z_max = max(np.max(Z), np.max(Z2))
19 | # z_min = min(np.min(Z), np.min(Z2))
20 | # fig = plt.figure()
21 | # ax = fig.add_subplot(111, projection='3d')
22 | ax.quiver(X, Y, Z, U, V, W)
23 | # ax.set_xlim(x_min, x_max)
24 | #
ax.set_ylim(y_min, y_max) 25 | # ax.set_zlim(z_min, z_max) 26 | 27 | return ax 28 | 29 | dataset = SceneDataset(False, './data/DTU/scan40', downscale=32) 30 | 31 | fig = plt.figure() 32 | ax = fig.add_subplot(111, projection='3d') 33 | ax.set_xlabel('x') 34 | ax.set_ylabel('y') 35 | ax.set_zlabel('z') 36 | ax.set_xlim(-2, 2) 37 | ax.set_ylim(-2, 2) 38 | ax.set_zlim(-2, 2) 39 | H, W = (dataset.H, dataset.W) 40 | 41 | for i in range(dataset.n_images): 42 | _, model_input, _ = dataset[i] 43 | intrinsics = model_input["intrinsics"][None, ...] 44 | c2w = model_input['c2w'][None, ...] 45 | rays_o, rays_d, select_inds = get_rays(c2w, intrinsics, H, W, N_rays=1) 46 | rays_o = rays_o.data.squeeze(0).cpu().numpy() 47 | rays_d = rays_d.data.squeeze(0).cpu().numpy() 48 | ax = plot_rays(rays_o, rays_d, ax) 49 | 50 | plt.show() -------------------------------------------------------------------------------- /tools/vis_surface_and_cam.py: -------------------------------------------------------------------------------- 1 | from utils import io_util 2 | from dataio import get_data 3 | 4 | import skimage 5 | import skimage.measure 6 | import numpy as np 7 | import open3d as o3d 8 | 9 | 10 | def get_camera_frustum(img_size, K, W2C, frustum_length=0.5, color=[0., 1., 0.]): 11 | W, H = img_size 12 | hfov = np.rad2deg(np.arctan(W / 2. / K[0, 0]) * 2.) 13 | vfov = np.rad2deg(np.arctan(H / 2. / K[1, 1]) * 2.) 14 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 15 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 16 | 17 | # build view frustum for camera (I, 0) 18 | frustum_points = np.array([[0., 0., 0.], # frustum origin 19 | [-half_w, -half_h, frustum_length], # top-left image corner 20 | [half_w, -half_h, frustum_length], # top-right image corner 21 | [half_w, half_h, frustum_length], # bottom-right image corner 22 | [-half_w, half_h, frustum_length]]) # bottom-left image corner 23 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) 24 | frustum_colors = np.tile(np.array(color).reshape((1, 3)), (frustum_lines.shape[0], 1)) 25 | 26 | # frustum_colors = np.vstack((np.tile(np.array([[1., 0., 0.]]), (4, 1)), 27 | # np.tile(np.array([[0., 1., 0.]]), (4, 1)))) 28 | 29 | # transform view frustum from (I, 0) to (R, t) 30 | C2W = np.linalg.inv(W2C) 31 | frustum_points = np.dot(np.hstack((frustum_points, np.ones_like(frustum_points[:, 0:1]))), C2W.T) 32 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] 33 | 34 | return frustum_points, frustum_lines, frustum_colors 35 | 36 | def frustums2lineset(frustums): 37 | N = len(frustums) 38 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum 39 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum 40 | merged_colors = np.zeros((N*8, 3)) # each line gets a color 41 | 42 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums): 43 | merged_points[i*5:(i+1)*5, :] = frustum_points 44 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5 45 | merged_colors[i*8:(i+1)*8, :] = frustum_colors 46 | 47 | lineset = o3d.geometry.LineSet() 48 | lineset.points = o3d.utility.Vector3dVector(merged_points) 49 | lineset.lines = o3d.utility.Vector2iVector(merged_lines) 50 | lineset.colors = o3d.utility.Vector3dVector(merged_colors) 51 | 52 | return lineset 53 | 54 | 55 | # ---------------------- 56 | # plot cameras alongside with mesh 57 | # modified from NeRF++. 
https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/extract_sfm.py 58 | def visualize_cameras(colored_camera_dicts, sphere_radius, camera_size=0.1, geometry_file=None, geometry_type='mesh', backface=False): 59 | things_to_draw = [] 60 | 61 | if sphere_radius > 0: 62 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius, resolution=10) 63 | sphere = o3d.geometry.LineSet.create_from_triangle_mesh(sphere) 64 | sphere.paint_uniform_color((1, 0, 0)) 65 | things_to_draw.append(sphere) 66 | 67 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0., 0., 0.]) 68 | things_to_draw.append(coord_frame) 69 | 70 | idx = 0 71 | for camera_dict in colored_camera_dicts: 72 | idx += 1 73 | 74 | K = np.array(camera_dict['K']).reshape((4, 4)) 75 | W2C = np.array(camera_dict['W2C']).reshape((4, 4)) 76 | C2W = np.linalg.inv(W2C) 77 | img_size = camera_dict['img_size'] 78 | color = camera_dict['color'] 79 | frustums = [get_camera_frustum(img_size, K, W2C, frustum_length=camera_size, color=color)] 80 | cameras = frustums2lineset(frustums) 81 | things_to_draw.append(cameras) 82 | 83 | if geometry_file is not None: 84 | if geometry_type == 'mesh': 85 | geometry = o3d.io.read_triangle_mesh(geometry_file) 86 | geometry.compute_vertex_normals() 87 | elif geometry_type == 'pointcloud': 88 | geometry = o3d.io.read_point_cloud(geometry_file) 89 | else: 90 | raise Exception('Unknown geometry_type: ', geometry_type) 91 | 92 | things_to_draw.append(geometry) 93 | if backface: 94 | o3d.visualization.RenderOption.mesh_show_back_face = True 95 | o3d.visualization.draw_geometries(things_to_draw) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = io_util.create_args_parser() 100 | parser.add_argument("--scan_id", type=int, default=40) 101 | parser.add_argument("--mesh_file", type=str, default=None) 102 | parser.add_argument("--sphere_radius", type=float, default=3.0) 103 | parser.add_argument("--backface",action='store_true', help='render show back face') 104 | args = parser.parse_args() 105 | 106 | # load camera 107 | args, unknown = parser.parse_known_args() 108 | config = io_util.load_config(args, unknown) 109 | dataset = get_data(config) 110 | 111 | #------------- 112 | colored_camera_dicts = [] 113 | for i in range(len(dataset)): 114 | (_, model_input, ground_truth) = dataset[i] 115 | c2w = model_input['c2w'].data.cpu().numpy() 116 | intrinsics = model_input["intrinsics"].data.cpu().numpy() 117 | 118 | cam_dict = {} 119 | cam_dict['img_size'] = (dataset.W, dataset.H) 120 | cam_dict['W2C'] = np.linalg.inv(c2w) 121 | cam_dict['K'] = intrinsics 122 | # cam_dict['color'] = [0, 1, 1] 123 | cam_dict['color'] = [1, 0, 0] 124 | 125 | # if i == 0: 126 | # cam_dict['color'] = [1, 0, 0] 127 | 128 | # if i == 1: 129 | # cam_dict['color'] = [0, 1, 0] 130 | 131 | # if i == 28: 132 | # cam_dict['color'] = [1, 0, 0] 133 | 134 | colored_camera_dicts.append(cam_dict) 135 | 136 | visualize_cameras(colored_camera_dicts, args.sphere_radius, geometry_file=args.mesh_file, backface=args.backface) 137 | 138 | 139 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from models.frameworks import get_model 2 | from models.base import get_optimizer, get_scheduler 3 | from utils import rend_util, train_util, mesh_util, io_util 4 | from utils.dist_util import get_local_rank, init_env, is_master, get_rank, get_world_size 5 | from utils.print_fn import log 6 | from 
utils.logger import Logger 7 | from utils.checkpoints import CheckpointIO 8 | from dataio import get_data 9 | 10 | import os 11 | import sys 12 | import time 13 | import functools 14 | from tqdm import tqdm 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | import torch.distributed as dist 19 | from torch.utils.data.dataloader import DataLoader 20 | from torch.utils.data.distributed import DistributedSampler 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | 23 | 24 | def main_function(args): 25 | 26 | init_env(args) 27 | 28 | #---------------------------- 29 | #-------- shortcuts --------- 30 | rank = get_rank() 31 | local_rank = get_local_rank() 32 | world_size = get_world_size() 33 | i_backup = int(args.training.i_backup // world_size) if args.training.i_backup > 0 else -1 34 | i_val = int(args.training.i_val // world_size) if args.training.i_val > 0 else -1 35 | i_val_mesh = int(args.training.i_val_mesh // world_size) if args.training.i_val_mesh > 0 else -1 36 | special_i_val_mesh = [int(i // world_size) for i in [3000, 5000, 7000]] 37 | exp_dir = args.training.exp_dir 38 | mesh_dir = os.path.join(exp_dir, 'meshes') 39 | 40 | device = torch.device('cuda', local_rank) 41 | 42 | 43 | # logger 44 | logger = Logger( 45 | log_dir=exp_dir, 46 | img_dir=os.path.join(exp_dir, 'imgs'), 47 | monitoring=args.training.get('monitoring', 'tensorboard'), 48 | monitoring_dir=os.path.join(exp_dir, 'events'), 49 | rank=rank, is_master=is_master(), multi_process_logging=(world_size > 1)) 50 | 51 | log.info("=> Experiments dir: {}".format(exp_dir)) 52 | 53 | if is_master(): 54 | # backup codes 55 | io_util.backup(os.path.join(exp_dir, 'backup')) 56 | 57 | # save configs 58 | io_util.save_config(args, os.path.join(exp_dir, 'config.yaml')) 59 | 60 | dataset, val_dataset = get_data(args, return_val=True, val_downscale=args.data.get('val_downscale', 4.0)) 61 | bs = args.data.get('batch_size', None) 62 | if args.ddp: 63 | train_sampler = DistributedSampler(dataset) 64 | dataloader = torch.utils.data.DataLoader(dataset, sampler=train_sampler, batch_size=bs) 65 | val_sampler = DistributedSampler(val_dataset) 66 | valloader = torch.utils.data.DataLoader(val_dataset, sampler=val_sampler, batch_size=bs) 67 | else: 68 | dataloader = DataLoader(dataset, 69 | batch_size=bs, 70 | shuffle=True, 71 | pin_memory=args.data.get('pin_memory', False)) 72 | valloader = DataLoader(val_dataset, 73 | batch_size=1, 74 | shuffle=True) 75 | 76 | # Create model 77 | model, trainer, render_kwargs_train, render_kwargs_test, volume_render_fn = get_model(args) 78 | model.to(device) 79 | log.info(model) 80 | log.info("=> Nerf params: " + str(train_util.count_trainable_parameters(model))) 81 | 82 | render_kwargs_train['H'] = dataset.H 83 | render_kwargs_train['W'] = dataset.W 84 | render_kwargs_test['H'] = val_dataset.H 85 | render_kwargs_test['W'] = val_dataset.W 86 | 87 | # build optimizer 88 | optimizer = get_optimizer(args, model) 89 | 90 | # checkpoints 91 | checkpoint_io = CheckpointIO(checkpoint_dir=os.path.join(exp_dir, 'ckpts'), allow_mkdir=is_master()) 92 | if world_size > 1: 93 | dist.barrier() 94 | # Register modules to checkpoint 95 | checkpoint_io.register_modules( 96 | model=model, 97 | optimizer=optimizer, 98 | ) 99 | 100 | # Load checkpoints 101 | load_dict = checkpoint_io.load_file( 102 | args.training.ckpt_file, 103 | ignore_keys=args.training.ckpt_ignore_keys, 104 | only_use_keys=args.training.ckpt_only_use_keys, 105 | map_location=device) 106 | 107 | logger.load_stats('stats.p') # this 
will be used for plotting
108 | it = load_dict.get('global_step', 0)
109 | epoch_idx = load_dict.get('epoch_idx', 0)
110 | 
111 | # pretrain if needed. must be done after loading the state_dict, since it needs the 'is_pretrained' variable to be loaded.
112 | #---------------------------------------------
113 | #-------- init preparation only done in master
114 | #---------------------------------------------
115 | if is_master():
116 | pretrain_config = {'logger': logger}
117 | if 'lr_pretrain' in args.training:
118 | pretrain_config['lr'] = args.training.lr_pretrain
119 | if model.implicit_surface.pretrain_hook(pretrain_config):
120 | checkpoint_io.save(filename='latest.pt'.format(it), global_step=it, epoch_idx=epoch_idx)
121 | 
122 | # Parallel training
123 | if args.ddp:
124 | trainer = DDP(trainer, device_ids=args.device_ids, output_device=local_rank, find_unused_parameters=False)
125 | 
126 | # build scheduler
127 | scheduler = get_scheduler(args, optimizer, last_epoch=it-1)
128 | t0 = time.time()
129 | log.info('=> Start training..., it={}, lr={}, in {}'.format(it, optimizer.param_groups[0]['lr'], exp_dir))
130 | end = (it >= args.training.num_iters)
131 | with tqdm(range(args.training.num_iters), disable=not is_master()) as pbar:
132 | if is_master():
133 | pbar.update(it)
134 | while it <= args.training.num_iters and not end:
135 | try:
136 | if args.ddp:
137 | train_sampler.set_epoch(epoch_idx)
138 | for (indices, model_input, ground_truth) in dataloader:
139 | int_it = int(it // world_size)
140 | #-------------------
141 | # validate
142 | #-------------------
143 | if i_val > 0 and int_it % i_val == 0:
144 | with torch.no_grad():
145 | (val_ind, val_in, val_gt) = next(iter(valloader))
146 | 
147 | intrinsics = val_in["intrinsics"].to(device)
148 | c2w = val_in['c2w'].to(device)
149 | 
150 | # N_rays=-1 for rendering full image
151 | rays_o, rays_d, select_inds = rend_util.get_rays(
152 | c2w, intrinsics, render_kwargs_test['H'], render_kwargs_test['W'], N_rays=-1)
153 | target_rgb = val_gt['rgb'].to(device)
154 | rgb, depth_v, ret = volume_render_fn(rays_o, rays_d, calc_normal=True, detailed_output=True, **render_kwargs_test)
155 | 
156 | to_img = functools.partial(
157 | rend_util.lin2img,
158 | H=render_kwargs_test['H'], W=render_kwargs_test['W'],
159 | batched=render_kwargs_test['batched'])
160 | logger.add_imgs(to_img(target_rgb), 'val/gt_rgb', it)
161 | logger.add_imgs(to_img(rgb), 'val/predicted_rgb', it)
162 | logger.add_imgs(to_img((depth_v/(depth_v.max()+1e-10)).unsqueeze(-1)), 'val/pred_depth_volume', it)
163 | logger.add_imgs(to_img(ret['mask_volume'].unsqueeze(-1)), 'val/pred_mask_volume', it)
164 | if 'depth_surface' in ret:
165 | logger.add_imgs(to_img((ret['depth_surface']/ret['depth_surface'].max()).unsqueeze(-1)), 'val/pred_depth_surface', it)
166 | if 'mask_surface' in ret:
167 | logger.add_imgs(to_img(ret['mask_surface'].unsqueeze(-1).float()), 'val/predicted_mask', it)
168 | if hasattr(trainer, 'val'):
169 | trainer.val(logger, ret, to_img, it, render_kwargs_test)
170 | 
171 | logger.add_imgs(to_img(ret['normals_volume']/2.+0.5), 'val/predicted_normals', it)
172 | 
173 | #-------------------
174 | # validate mesh
175 | #-------------------
176 | if is_master():
177 | # NOTE: not validating the mesh before 3k iterations, as some of the DTU instances for NeuS training will not have a large enough mesh at the beginning. 
178 | if i_val_mesh > 0 and (int_it % i_val_mesh == 0 or int_it in special_i_val_mesh) and it != 0: 179 | with torch.no_grad(): 180 | io_util.cond_mkdir(mesh_dir) 181 | mesh_util.extract_mesh( 182 | model.implicit_surface, 183 | filepath=os.path.join(mesh_dir, '{:08d}.ply'.format(it)), 184 | volume_size=args.data.get('volume_size', 2.0), 185 | show_progress=is_master()) 186 | 187 | if it >= args.training.num_iters: 188 | end = True 189 | break 190 | 191 | #------------------- 192 | # train 193 | #------------------- 194 | start_time = time.time() 195 | ret = trainer.forward(args, indices, model_input, ground_truth, render_kwargs_train, it) 196 | 197 | losses = ret['losses'] 198 | extras = ret['extras'] 199 | 200 | for k, v in losses.items(): 201 | # log.info("{}:{} - > {}".format(k, v.shape, v.mean().shape)) 202 | losses[k] = torch.mean(v) 203 | 204 | optimizer.zero_grad() 205 | losses['total'].backward() 206 | # NOTE: check grad before optimizer.step() 207 | if True: 208 | grad_norms = train_util.calc_grad_norm(model=model) 209 | optimizer.step() 210 | scheduler.step(it) # NOTE: important! when world_size is not 1 211 | 212 | #------------------- 213 | # logging 214 | #------------------- 215 | # done every i_save seconds 216 | if (args.training.i_save > 0) and (time.time() - t0 > args.training.i_save): 217 | if is_master(): 218 | checkpoint_io.save(filename='latest.pt', global_step=it, epoch_idx=epoch_idx) 219 | # this will be used for plotting 220 | logger.save_stats('stats.p') 221 | t0 = time.time() 222 | 223 | if is_master(): 224 | #---------------------------------------------------------------------------- 225 | #------------------- things only done in master ----------------------------- 226 | #---------------------------------------------------------------------------- 227 | pbar.set_postfix(lr=optimizer.param_groups[0]['lr'], loss_total=losses['total'].item(), loss_img=losses['loss_img'].item()) 228 | 229 | if i_backup > 0 and int_it % i_backup == 0 and it > 0: 230 | checkpoint_io.save(filename='{:08d}.pt'.format(it), global_step=it, epoch_idx=epoch_idx) 231 | 232 | #---------------------------------------------------------------------------- 233 | #------------------- things done in every child process --------------------------- 234 | #---------------------------------------------------------------------------- 235 | 236 | #------------------- 237 | # log grads and learning rate 238 | for k, v in grad_norms.items(): 239 | logger.add('grad', k, v, it) 240 | logger.add('learning rates', 'whole', optimizer.param_groups[0]['lr'], it) 241 | 242 | #------------------- 243 | # log losses 244 | for k, v in losses.items(): 245 | logger.add('losses', k, v.data.cpu().numpy().item(), it) 246 | 247 | #------------------- 248 | # log extras 249 | names = ["radiance", "alpha", "implicit_surface", "implicit_nablas_norm", "sigma_out", "radiance_out"] 250 | for n in names: 251 | p = "whole" 252 | # key = "raw.{}".format(n) 253 | key = n 254 | if key in extras: 255 | logger.add("extras_{}".format(n), "{}.mean".format(p), extras[key].mean().data.cpu().numpy().item(), it) 256 | logger.add("extras_{}".format(n), "{}.min".format(p), extras[key].min().data.cpu().numpy().item(), it) 257 | logger.add("extras_{}".format(n), "{}.max".format(p), extras[key].max().data.cpu().numpy().item(), it) 258 | logger.add("extras_{}".format(n), "{}.norm".format(p), extras[key].norm().data.cpu().numpy().item(), it) 259 | if 'scalars' in extras: 260 | for k, v in extras['scalars'].items(): 261 | logger.add('scalars', k, 
v.mean(), it) 262 | 263 | #--------------------- 264 | # end of one iteration 265 | end_time = time.time() 266 | log.debug("=> One iteration time is {:.2f}".format(end_time - start_time)) 267 | 268 | it += world_size 269 | if is_master(): 270 | pbar.update(world_size) 271 | #--------------------- 272 | # end of one epoch 273 | epoch_idx += 1 274 | 275 | except KeyboardInterrupt: 276 | if is_master(): 277 | checkpoint_io.save(filename='latest.pt'.format(it), global_step=it, epoch_idx=epoch_idx) 278 | # this will be used for plotting 279 | logger.save_stats('stats.p') 280 | sys.exit() 281 | 282 | if is_master(): 283 | checkpoint_io.save(filename='final_{:08d}.pt'.format(it), global_step=it, epoch_idx=epoch_idx) 284 | logger.save_stats('stats.p') 285 | log.info("Everything done.") 286 | 287 | if __name__ == "__main__": 288 | # Arguments 289 | parser = io_util.create_args_parser() 290 | parser.add_argument("--ddp", action='store_true', help='whether to use DDP to train.') 291 | parser.add_argument("--port", type=int, default=None, help='master port for multi processing. (if used)') 292 | args, unknown = parser.parse_known_args() 293 | config = io_util.load_config(args, unknown) 294 | main_function(config) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ventusff/neurecon/972e810ec252cfd16f630b1de6d2802d1b8de59a/utils/__init__.py -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | from utils.print_fn import log 2 | 3 | import os 4 | import urllib 5 | import torch 6 | from torch.utils import model_zoo 7 | 8 | # torch.autograd.set_detect_anomaly(True) 9 | 10 | class CheckpointIO(object): 11 | ''' CheckpointIO class. 12 | 13 | modified from https://github.com/LMescheder/GAN_stability/blob/master/gan_training/checkpoints.py 14 | 15 | It handles saving and loading checkpoints. 16 | 17 | Args: 18 | checkpoint_dir (str): path where checkpoints are saved 19 | ''' 20 | 21 | def __init__(self, checkpoint_dir='./chkpts', allow_mkdir=True, **kwargs): 22 | self.module_dict = kwargs 23 | self.checkpoint_dir = checkpoint_dir 24 | if allow_mkdir: 25 | if not os.path.exists(checkpoint_dir): 26 | os.makedirs(checkpoint_dir) 27 | 28 | def register_modules(self, **kwargs): 29 | ''' Registers modules in current module dictionary. 30 | ''' 31 | self.module_dict.update(kwargs) 32 | 33 | def save(self, filename, **kwargs): 34 | ''' Saves the current module dictionary. 35 | 36 | Args: 37 | filename (str): name of output file 38 | ''' 39 | if not os.path.isabs(filename): 40 | filename = os.path.join(self.checkpoint_dir, filename) 41 | log.info("=> Saving ckpt to {}".format(filename)) 42 | outdict = kwargs 43 | for k, v in self.module_dict.items(): 44 | outdict[k] = v.state_dict() 45 | torch.save(outdict, filename) 46 | log.info("Done.") 47 | 48 | def load(self, filename): 49 | '''Loads a module dictionary from local file or url. 50 | 51 | Args: 52 | filename (str): name of saved module dictionary 53 | ''' 54 | if is_url(filename): 55 | return self.load_url(filename) 56 | else: 57 | return self.load_file(filename) 58 | 59 | def load_file(self, filepath, no_reload=False, ignore_keys=[], only_use_keys=None, map_location='cuda'): 60 | '''Loads a module dictionary from file. 
61 | 62 | Args: 63 | filepath (str): file path of saved module dictionary 64 | ''' 65 | 66 | assert not ((len(ignore_keys) > 0) and only_use_keys is not None), \ 67 | 'please specify at most one in [ckpt_ignore_keys, ckpt_only_use_keys]' 68 | 69 | if filepath is not None and filepath != "None": 70 | ckpts = [filepath] 71 | else: 72 | ckpts = sorted_ckpts(self.checkpoint_dir) 73 | 74 | log.info("=> Found ckpts: " + 75 | "{}".format(ckpts) if len(ckpts) < 5 else "...,{}".format(ckpts[-5:])) 76 | 77 | if len(ckpts) > 0 and not no_reload: 78 | ckpt_file = ckpts[-1] 79 | log.info('=> Loading checkpoint from local file: ' + str(ckpt_file)) 80 | state_dict = torch.load(ckpt_file, map_location=map_location) 81 | 82 | if len(ignore_keys) > 0: 83 | to_load_state_dict = {} 84 | for k in state_dict.keys(): 85 | if k in ignore_keys: 86 | log.info("=> [ckpt] Ignoring keys: {}".format(k)) 87 | else: 88 | to_load_state_dict[k] = state_dict[k] 89 | elif only_use_keys is not None: 90 | if not isinstance(only_use_keys, list): 91 | only_use_keys = [only_use_keys] 92 | to_load_state_dict = {} 93 | for k in only_use_keys: 94 | log.info("=> [ckpt] Only use keys: {}".format(k)) 95 | to_load_state_dict[k] = state_dict[k] 96 | else: 97 | to_load_state_dict = state_dict 98 | 99 | scalars = self.parse_state_dict(to_load_state_dict, ignore_keys) 100 | return scalars 101 | else: 102 | return {} 103 | 104 | def load_url(self, url): 105 | '''Load a module dictionary from url. 106 | 107 | Args: 108 | url (str): url to saved model 109 | ''' 110 | log.info(url) 111 | log.info('=> Loading checkpoint from url...') 112 | state_dict = model_zoo.load_url(url, progress=True) 113 | scalars = self.parse_state_dict(state_dict) 114 | return scalars 115 | 116 | def parse_state_dict(self, state_dict, ignore_keys): 117 | '''Parse state_dict of model and return scalars. 118 | 119 | Args: 120 | state_dict (dict): State dict of model 121 | ''' 122 | 123 | for k, v in self.module_dict.items(): 124 | if k in state_dict: 125 | v.load_state_dict(state_dict[k]) 126 | else: 127 | if k not in ignore_keys: 128 | log.info('Warning: Could not find %s in checkpoint!' 
% k) 129 | scalars = {k: v for k, v in state_dict.items() 130 | if k not in self.module_dict} 131 | return scalars 132 | 133 | 134 | def is_url(url): 135 | scheme = urllib.parse.urlparse(url).scheme 136 | return scheme in ('http', 'https') 137 | 138 | 139 | def sorted_ckpts(checkpoint_dir): 140 | ckpts = [] 141 | if os.path.exists(checkpoint_dir): 142 | latest = None 143 | final = None 144 | ckpts = [] 145 | for fname in sorted(os.listdir(checkpoint_dir)): 146 | fpath = os.path.join(checkpoint_dir, fname) 147 | if ".pt" in fname: 148 | ckpts.append(fpath) 149 | if 'latest' in fname: 150 | latest = fpath 151 | elif 'final' in fname: 152 | final = fpath 153 | if latest is not None: 154 | ckpts.remove(latest) 155 | ckpts.append(latest) 156 | if final is not None: 157 | ckpts.remove(final) 158 | ckpts.append(final) 159 | return ckpts -------------------------------------------------------------------------------- /utils/dist_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from typing import Optional 6 | 7 | import torch.distributed as dist 8 | 9 | rank = 0 # process id, for IPC 10 | local_rank = 0 # local GPU device id 11 | world_size = 1 # number of processes 12 | 13 | def init_env(args): 14 | global rank, local_rank, world_size 15 | if args.ddp: 16 | #------------- multi process running, using DDP 17 | if 'SLURM_PROCID' in os.environ: 18 | #--------- for SLURM 19 | slurm_initialize('nccl', port=args.port) 20 | else: 21 | #--------- for torch.distributed.launch 22 | dist.init_process_group(backend='nccl') 23 | 24 | rank = int(os.environ['RANK']) 25 | local_rank = int(os.environ['LOCAL_RANK']) 26 | world_size = int(os.environ['WORLD_SIZE']) 27 | torch.cuda.set_device(local_rank) 28 | args.device_ids = [local_rank] 29 | print("=> Init Env @ DDP: rank={}, world_size={}, local_rank={}.\n\tdevice_ids set to {}".format(rank, world_size, local_rank, args.device_ids)) 30 | # NOTE: important! 
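        # (torch.cuda.set_device above must run before any CUDA tensor is created,
        #  otherwise every rank would implicitly allocate on cuda:0)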
31 | else: 32 | #------------- single process running, using single GPU or DataParallel 33 | # torch.cuda.set_device(args.device_ids[0]) 34 | print("=> Init Env @ single process: use device_ids = {}".format(args.device_ids)) 35 | rank = 0 36 | local_rank = args.device_ids[0] 37 | world_size = 1 38 | torch.cuda.set_device(args.device_ids[0]) 39 | set_seed(42) 40 | 41 | 42 | def slurm_initialize(backend='nccl', port: Optional[int] = None): 43 | proc_id = int(os.environ['SLURM_PROCID']) 44 | ntasks = int(os.environ['SLURM_NTASKS']) 45 | node_list = os.environ['SLURM_NODELIST'] 46 | if '[' in node_list: 47 | beg = node_list.find('[') 48 | pos1 = node_list.find('-', beg) 49 | if pos1 < 0: 50 | pos1 = 1000 51 | pos2 = node_list.find(',', beg) 52 | if pos2 < 0: 53 | pos2 = 1000 54 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 55 | addr = node_list[8:].replace('-', '.') 56 | if port is not None: 57 | os.environ['MASTER_PORT'] = str(port) 58 | elif 'MASTER_PORT' not in os.environ: 59 | os.environ["MASTER_PORT"] = "13333" 60 | os.environ['MASTER_ADDR'] = addr 61 | os.environ['WORLD_SIZE'] = str(ntasks) 62 | os.environ['RANK'] = str(proc_id) 63 | if backend == 'nccl': 64 | dist.init_process_group(backend='nccl') 65 | else: 66 | dist.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks) 67 | rank = dist.get_rank() 68 | device = rank % torch.cuda.device_count() 69 | torch.cuda.set_device(device) 70 | os.environ['LOCAL_RANK'] = str(device) 71 | 72 | 73 | def set_seed(seed): 74 | random.seed(seed) 75 | np.random.seed(seed) 76 | torch.manual_seed(seed) 77 | torch.cuda.manual_seed_all(seed) 78 | 79 | 80 | def is_master(): 81 | return rank == 0 82 | 83 | def get_rank(): 84 | return int(os.environ.get('SLURM_PROCID', rank)) 85 | 86 | def get_local_rank(): 87 | return int(os.environ.get('LOCAL_RANK', local_rank)) 88 | 89 | def get_world_size(): 90 | return int(os.environ.get('SLURM_NTASKS', world_size)) -------------------------------------------------------------------------------- /utils/io_util.py: -------------------------------------------------------------------------------- 1 | from utils.print_fn import log 2 | 3 | import os 4 | import copy 5 | import yaml 6 | import glob 7 | import addict 8 | import shutil 9 | import imageio 10 | import argparse 11 | import functools 12 | import numpy as np 13 | 14 | 15 | import torch 16 | import skimage 17 | from skimage.transform import rescale 18 | 19 | def glob_imgs(path): 20 | imgs = [] 21 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']: 22 | imgs.extend(glob.glob(os.path.join(path, ext))) 23 | return imgs 24 | 25 | # def find_files(dir, exts=['*.png', '*.jpg']): 26 | # if os.path.isdir(dir): 27 | # # types should be ['*.png', '*.jpg'] 28 | # files_grabbed = [] 29 | # for ext in exts: 30 | # files_grabbed.extend(glob.glob(os.path.join(dir, ext))) 31 | # if len(files_grabbed) > 0: 32 | # files_grabbed = sorted(files_grabbed) 33 | # return files_grabbed 34 | # else: 35 | # return [] 36 | 37 | def load_rgb(path, downscale=1): 38 | img = imageio.imread(path) 39 | img = skimage.img_as_float32(img) 40 | if downscale != 1: 41 | img = rescale(img, 1./downscale, anti_aliasing=False, multichannel=True) 42 | 43 | # NOTE: pixel values between [-1,1] 44 | # img -= 0.5 45 | # img *= 2. 
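    # (as shipped, the two rescaling lines above stay commented out, so values
    #  remain in [0, 1]; the transpose below converts HWC -> CHW)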
46 | img = img.transpose(2, 0, 1) 47 | return img 48 | 49 | def load_mask(path, downscale=1): 50 | alpha = imageio.imread(path, as_gray=True) 51 | alpha = skimage.img_as_float32(alpha) 52 | if downscale != 1: 53 | alpha = rescale(alpha, 1./downscale, anti_aliasing=False, multichannel=False) 54 | object_mask = alpha > 127.5 55 | 56 | return object_mask 57 | 58 | def partialclass(cls, *args, **kwds): 59 | class NewCls(cls): 60 | __init__ = functools.partialmethod(cls.__init__, *args, **kwds) 61 | 62 | NewCls.__name__ = cls.__name__ # to preserve old class name. 63 | 64 | return NewCls 65 | 66 | 67 | def cond_mkdir(path): 68 | if not os.path.exists(path): 69 | os.makedirs(path) 70 | 71 | 72 | def backup(backup_dir): 73 | """ automatic backup codes 74 | """ 75 | log.info("=> Backing up... ") 76 | special_files_to_copy = [] 77 | filetypes_to_copy = [".py"] 78 | subdirs_to_copy = ["", "dataio/", "models/", "tools/", "debug_tools/", "utils/"] 79 | 80 | this_dir = "./" # TODO 81 | cond_mkdir(backup_dir) 82 | # special files 83 | [ 84 | cond_mkdir(os.path.join(backup_dir, os.path.split(file)[0])) 85 | for file in special_files_to_copy 86 | ] 87 | [ 88 | shutil.copyfile( 89 | os.path.join(this_dir, file), os.path.join(backup_dir, file) 90 | ) 91 | for file in special_files_to_copy 92 | ] 93 | # dirs 94 | for subdir in subdirs_to_copy: 95 | cond_mkdir(os.path.join(backup_dir, subdir)) 96 | files = os.listdir(os.path.join(this_dir, subdir)) 97 | files = [ 98 | file 99 | for file in files 100 | if os.path.isfile(os.path.join(this_dir, subdir, file)) 101 | and file[file.rfind("."):] in filetypes_to_copy 102 | ] 103 | [ 104 | shutil.copyfile( 105 | os.path.join(this_dir, subdir, file), 106 | os.path.join(backup_dir, subdir, file), 107 | ) 108 | for file in files 109 | ] 110 | log.info("done.") 111 | 112 | 113 | def save_video(imgs, fname, as_gif=False, fps=24, quality=8, already_np=False, gif_scale:int =512): 114 | """[summary] 115 | 116 | Args: 117 | imgs ([type]): [0 to 1] 118 | fname ([type]): [description] 119 | as_gif (bool, optional): [description]. Defaults to False. 120 | fps (int, optional): [description]. Defaults to 24. 121 | quality (int, optional): [description]. Defaults to 8. 122 | """ 123 | gif_scale = int(gif_scale) 124 | # convert to np.uint8 125 | if not already_np: 126 | imgs = (255 * np.clip( 127 | imgs.permute(0, 2, 3, 1).detach().cpu().numpy(), 0, 1))\ 128 | .astype(np.uint8) 129 | imageio.mimwrite(fname, imgs, fps=fps, quality=quality) 130 | 131 | if as_gif: # save as gif, too 132 | os.system(f'ffmpeg -i {fname} -r 15 ' 133 | f'-vf "scale={gif_scale}:-1,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" {os.path.splitext(fname)[0] + ".gif"}') 134 | 135 | 136 | def gallery(array, ncols=3): 137 | nindex, height, width, intensity = array.shape 138 | nrows = nindex//ncols 139 | # assert nindex == nrows*ncols 140 | if nindex > nrows*ncols: 141 | nrows += 1 142 | array = np.concatenate([array, np.zeros([nrows*ncols-nindex, height, width, intensity])]) 143 | # want result.shape = (height*nrows, width*ncols, intensity) 144 | result = (array.reshape(nrows, ncols, height, width, intensity) 145 | .swapaxes(1,2) 146 | .reshape(height*nrows, width*ncols, intensity)) 147 | return result 148 | 149 | 150 | # modified from tensorboardX https://github.com/lanpa/tensorboardX 151 | def figure_to_image(figures, close=True): 152 | """Render matplotlib figure to numpy format. 153 | 154 | Note that this requires the ``matplotlib`` package. 
155 | 
156 |     Args:
157 |         figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures
158 |         close (bool): Flag to automatically close the figure
159 | 
160 |     Returns:
161 |         numpy.array: image in [HWC] order (the CHW conversion below is commented out)
162 |     """
163 |     import numpy as np
164 |     try:
165 |         import matplotlib.pyplot as plt
166 |         import matplotlib.backends.backend_agg as plt_backend_agg
167 |     except ModuleNotFoundError:
168 |         print('please install matplotlib')
169 |         raise
170 | 
171 |     def render_to_rgb(figure):
172 |         canvas = plt_backend_agg.FigureCanvasAgg(figure)
173 |         canvas.draw()
174 |         data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
175 |         w, h = figure.canvas.get_width_height()
176 |         image_hwc = data.reshape([h, w, 4])[:, :, 0:3]
177 |         # image_chw = np.moveaxis(image_hwc, source=2, destination=0)
178 |         if close:
179 |             plt.close(figure)
180 |         return image_hwc
181 | 
182 |     if isinstance(figures, list):
183 |         images = [render_to_rgb(figure) for figure in figures]
184 |         return np.stack(images)
185 |     else:
186 |         image = render_to_rgb(figures)
187 |         return image
188 | 
189 | 
190 | 
191 | #-----------------------------
192 | # configs
193 | #-----------------------------
194 | class ForceKeyErrorDict(addict.Dict):
195 |     def __missing__(self, name):
196 |         raise KeyError(name)
197 | 
198 | 
199 | def load_yaml(path, default_path=None):
200 | 
201 |     with open(path, encoding='utf8') as yaml_file:
202 |         config_dict = yaml.load(yaml_file, Loader=yaml.FullLoader)
203 |         config = ForceKeyErrorDict(**config_dict)
204 | 
205 |     if default_path is not None and path != default_path:
206 |         with open(default_path, encoding='utf8') as default_yaml_file:
207 |             default_config_dict = yaml.load(
208 |                 default_yaml_file, Loader=yaml.FullLoader)
209 |             main_config = ForceKeyErrorDict(**default_config_dict)
210 | 
211 |         # def overwrite(output_config, update_with):
212 |         #     for k, v in update_with.items():
213 |         #         if not isinstance(v, dict):
214 |         #             output_config[k] = v
215 |         #         else:
216 |         #             overwrite(output_config[k], v)
217 |         # overwrite(main_config, config)
218 | 
219 |         # simpler solution
220 |         main_config.update(config)
221 |         config = main_config
222 | 
223 |     return config
224 | 
225 | 
226 | def save_config(datadict: ForceKeyErrorDict, path: str):
227 |     datadict = copy.deepcopy(datadict)
228 |     datadict.training.ckpt_file = None
229 |     datadict.training.pop('exp_dir')
230 |     with open(path, 'w', encoding='utf8') as outfile:
231 |         yaml.dump(datadict.to_dict(), outfile, default_flow_style=False)
232 | 
233 | 
234 | def update_config(config, unknown):
235 |     # update config given args
236 |     for idx, arg in enumerate(unknown):
237 |         if arg.startswith("--"):
238 |             if ':' in arg:
239 |                 k1, k2 = arg.replace("--", "").split(':')
240 |                 argtype = type(config[k1][k2])
241 |                 if argtype == bool:
242 |                     v = unknown[idx+1].lower() == 'true'
243 |                 else:
244 |                     if config[k1][k2] is not None:
245 |                         v = type(config[k1][k2])(unknown[idx+1])
246 |                     else:
247 |                         v = unknown[idx+1]
248 |                 print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
249 |                 config[k1][k2] = v
250 |             else:
251 |                 k = arg.replace('--', '')
252 |                 # NOTE: kept as a raw string on purpose; typed parsing of top-level keys (e.g. device_ids "0,1") happens downstream in load_config
253 |                 v = unknown[idx+1]
254 |                 print(f'Changing {k} ---- {config[k]} to {v}')
255 |                 config[k] = v
256 | 
257 |     return config
258 | 
259 | 
260 | def create_args_parser():
261 |     parser = argparse.ArgumentParser()
262 |     # standard configs
263 |     parser.add_argument('--config', type=str, default=None, help='Path to config file.')
264 |     parser.add_argument('--resume_dir', type=str, default=None, help='Directory of experiment to load.')
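    # NOTE: extra '--key value' / '--group:key value' pairs left over from
    # parse_known_args (see train.py's __main__) are applied onto the loaded
    # yaml by update_config() above.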
265 | return parser 266 | 267 | 268 | def load_config(args, unknown, base_config_path=None): 269 | ''' overwrite seq 270 | command line param --over--> args.config --over--> default config yaml 271 | ''' 272 | assert (args.config is not None) != (args.resume_dir is not None), "you must specify ONLY one in 'config' or 'resume_dir' " 273 | 274 | # NOTE: '--local_rank=xx' is automatically given by torch.distributed.launch (if used) 275 | # BUT: pytorch suggest to use os.environ['LOCAL_RANK'] instead, and --local_rank=xxx will be deprecated in the future. 276 | # so we are not using --local_rank at all. 277 | found_k = None 278 | for item in unknown: 279 | if 'local_rank' in item: 280 | found_k = item 281 | break 282 | if found_k is not None: 283 | unknown.remove(found_k) 284 | 285 | print("=> Parse extra configs: ", unknown) 286 | if args.resume_dir is not None: 287 | assert args.config is None, "given --config will not be used when given --resume_dir" 288 | assert '--expname' not in unknown, "given --expname with --resume_dir will lead to unexpected behavior." 289 | #--------------- 290 | # if loading from a dir, do not use base.yaml as the default; 291 | #--------------- 292 | config_path = os.path.join(args.resume_dir, 'config.yaml') 293 | config = load_yaml(config_path, default_path=None) 294 | 295 | # use configs given by command line to further overwrite current config 296 | config = update_config(config, unknown) 297 | 298 | # use the loading directory as the experiment path 299 | config.training.exp_dir = args.resume_dir 300 | print("=> Loading previous experiments in: {}".format(config.training.exp_dir)) 301 | else: 302 | #--------------- 303 | # if loading from a config file 304 | # use base.yaml as default 305 | #--------------- 306 | config = load_yaml(args.config, default_path=base_config_path) 307 | 308 | # use configs given by command line to further overwrite current config 309 | config = update_config(config, unknown) 310 | 311 | # use the expname and log_root_dir to get the experiement directory 312 | if 'exp_dir' not in config.training: 313 | config.training.exp_dir = os.path.join(config.training.log_root_dir, config.expname) 314 | 315 | # add other configs in args to config 316 | other_dict = vars(args) 317 | other_dict.pop('config') 318 | other_dict.pop('resume_dir') 319 | config.update(other_dict) 320 | 321 | if hasattr(args, 'ddp') and args.ddp: 322 | if config.device_ids != -1: 323 | print("=> Ignoring device_ids configs when using DDP. Auto set to -1.") 324 | config.device_ids = -1 325 | else: 326 | args.ddp = False 327 | # # device_ids: -1 will be parsed as using all available cuda device 328 | # # device_ids: [] will be parsed as using all available cuda device 329 | if (type(config.device_ids) == int and config.device_ids == -1) \ 330 | or (type(config.device_ids) == list and len(config.device_ids) == 0): 331 | config.device_ids = list(range(torch.cuda.device_count())) 332 | # # e.g. device_ids: 0 will be parsed as device_ids [0] 333 | elif isinstance(config.device_ids, int): 334 | config.device_ids = [config.device_ids] 335 | # # e.g. 
device_ids: 0,1 will be parsed as device_ids [0,1]
336 |     elif isinstance(config.device_ids, str):
337 |         config.device_ids = [int(m) for m in config.device_ids.split(',')]
338 |     print("=> Use cuda devices: {}".format(config.device_ids))
339 | 
340 |     return config
-------------------------------------------------------------------------------- /utils/logger.py: --------------------------------------------------------------------------------
1 | from utils import io_util
2 | from utils.print_fn import log
3 | 
4 | import os
5 | import torch
6 | import pickle
7 | import imageio
8 | import torchvision
9 | import numpy as np
10 | 
11 | import torch.distributed as dist
12 | 
13 | #---------------------------------------------------------------------------
14 | #---------------------- tensorboard / image recorder -----------------------
15 | #---------------------------------------------------------------------------
16 | 
17 | class Logger(object):
18 |     """
19 |     modified from https://github.com/LMescheder/GAN_stability/blob/master/gan_training/logger.py
20 |     """
21 |     def __init__(self,
22 |                  log_dir,
23 |                  img_dir,
24 |                  monitoring=None,
25 |                  monitoring_dir=None,
26 |                  rank=0,
27 |                  is_master=True,
28 |                  multi_process_logging=True):
29 |         self.stats = dict()
30 |         self.log_dir = log_dir
31 |         self.img_dir = img_dir
32 |         self.rank = rank
33 |         self.is_master = is_master
34 |         self.multi_process_logging = multi_process_logging
35 | 
36 |         if self.is_master:
37 |             io_util.cond_mkdir(self.log_dir)
38 |             io_util.cond_mkdir(self.img_dir)
39 |         if self.multi_process_logging:
40 |             dist.barrier()
41 | 
42 |         self.monitoring = None
43 |         self.monitoring_dir = None
44 | 
45 |         # if self.is_master:
46 | 
47 |         # NOTE: for now, we are allowing tensorboard writing on all child processes,
48 |         # as it's already nicely supported,
49 |         # and data from the different event files of different processes are automatically aggregated when visualizing.
50 |         # https://discuss.pytorch.org/t/using-tensorboard-with-distributeddataparallel/102555/7
51 |         if not (monitoring is None or monitoring == 'none'):
52 |             self.setup_monitoring(monitoring, monitoring_dir)
53 | 
54 | 
55 |     def setup_monitoring(self, monitoring, monitoring_dir):
56 |         self.monitoring = monitoring
57 |         self.monitoring_dir = monitoring_dir
58 |         if monitoring == 'tensorboard':
59 |             # NOTE: since torch 1.2
60 |             from torch.utils.tensorboard import SummaryWriter
61 |             # from tensorboardX import SummaryWriter
62 |             self.tb = SummaryWriter(self.monitoring_dir)
63 |         else:
64 |             raise NotImplementedError('Monitoring tool "%s" not supported!'
65 | % monitoring) 66 | 67 | def add(self, category, k, v, it): 68 | if category not in self.stats: 69 | self.stats[category] = {} 70 | 71 | if k not in self.stats[category]: 72 | self.stats[category][k] = [] 73 | 74 | self.stats[category][k].append((it, v)) 75 | 76 | k_name = '%s/%s' % (category, k) 77 | if self.monitoring == 'telemetry': 78 | self.tm.metric_push_async({ 79 | 'metric': k_name, 'value': v, 'it': it 80 | }) 81 | elif self.monitoring == 'tensorboard': 82 | self.tb.add_scalar(k_name, v, it) 83 | 84 | def add_vector(self, category, k, vec, it): 85 | if category not in self.stats: 86 | self.stats[category] = {} 87 | 88 | if k not in self.stats[category]: 89 | self.stats[category][k] = [] 90 | 91 | if isinstance(vec, torch.Tensor): 92 | vec = vec.data.clone().cpu().numpy() 93 | 94 | self.stats[category][k].append((it, vec)) 95 | 96 | def add_imgs(self, imgs, class_name, it): 97 | outdir = os.path.join(self.img_dir, class_name) 98 | if self.is_master and not os.path.exists(outdir): 99 | os.makedirs(outdir) 100 | if self.multi_process_logging: 101 | dist.barrier() 102 | outfile = os.path.join(outdir, '{:08d}_{}.png'.format(it, self.rank)) 103 | 104 | # imgs = imgs / 2 + 0.5 105 | imgs = torchvision.utils.make_grid(imgs) 106 | torchvision.utils.save_image(imgs.clone(), outfile, nrow=8) 107 | 108 | if self.monitoring == 'tensorboard': 109 | self.tb.add_image(class_name, imgs, global_step=it) 110 | 111 | def add_figure(self, fig, class_name, it, save_img=True): 112 | if save_img: 113 | outdir = os.path.join(self.img_dir, class_name) 114 | if self.is_master and not os.path.exists(outdir): 115 | os.makedirs(outdir) 116 | if self.multi_process_logging: 117 | dist.barrier() 118 | outfile = os.path.join(outdir, '{:08d}_{}.png'.format(it, self.rank)) 119 | 120 | image_hwc = io_util.figure_to_image(fig) 121 | imageio.imwrite(outfile, image_hwc) 122 | if self.monitoring == 'tensorboard': 123 | if len(image_hwc.shape) == 3: 124 | image_hwc = np.array(image_hwc[None, ...]) 125 | self.tb.add_images(class_name, torch.from_numpy(image_hwc), dataformats='NHWC', global_step=it) 126 | else: 127 | if self.monitoring == 'tensorboard': 128 | self.tb.add_figure(class_name, fig, it) 129 | 130 | def add_module_param(self, module_name, module, it): 131 | if self.monitoring == 'tensorboard': 132 | for name, param in module.named_parameters(): 133 | self.tb.add_histogram("{}/{}".format(module_name, name), param.detach(), it) 134 | 135 | def get_last(self, category, k, default=0.): 136 | if category not in self.stats: 137 | return default 138 | elif k not in self.stats[category]: 139 | return default 140 | else: 141 | return self.stats[category][k][-1][1] 142 | 143 | def save_stats(self, filename): 144 | filename = os.path.join(self.log_dir, filename + '_{}'.format(self.rank)) 145 | with open(filename, 'wb') as f: 146 | pickle.dump(self.stats, f) 147 | 148 | def load_stats(self, filename): 149 | filename = os.path.join(self.log_dir, filename + '_{}'.format(self.rank)) 150 | if not os.path.exists(filename): 151 | # log.info('=> File "%s" does not exist, will create new after calling save_stats()' % filename) 152 | return 153 | 154 | try: 155 | with open(filename, 'rb') as f: 156 | self.stats = pickle.load(f) 157 | log.info("=> Load file: {}".format(filename)) 158 | except EOFError: 159 | log.info('Warning: log file corrupted!') 160 | -------------------------------------------------------------------------------- /utils/mesh_util.py: 
-------------------------------------------------------------------------------- 1 | from utils.print_fn import log 2 | 3 | import time 4 | import plyfile 5 | import skimage 6 | import skimage.measure 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | 12 | 13 | def convert_sigma_samples_to_ply( 14 | input_3d_sigma_array: np.ndarray, 15 | voxel_grid_origin, 16 | volume_size, 17 | ply_filename_out, 18 | level=5.0, 19 | offset=None, 20 | scale=None,): 21 | """ 22 | Convert sdf samples to .ply 23 | 24 | :param input_3d_sdf_array: a float array of shape (n,n,n) 25 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 26 | :volume_size: a list of three floats 27 | :ply_filename_out: string, path of the filename to save to 28 | 29 | This function adapted from: https://github.com/RobotLocomotion/spartan 30 | """ 31 | start_time = time.time() 32 | 33 | verts, faces, normals, values = skimage.measure.marching_cubes( 34 | input_3d_sigma_array, level=level, spacing=volume_size 35 | ) 36 | 37 | # transform from voxel coordinates to camera coordinates 38 | # note x and y are flipped in the output of marching_cubes 39 | mesh_points = np.zeros_like(verts) 40 | mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0] 41 | mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1] 42 | mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2] 43 | 44 | # apply additional offset and scale 45 | if scale is not None: 46 | mesh_points = mesh_points / scale 47 | if offset is not None: 48 | mesh_points = mesh_points - offset 49 | 50 | # try writing to the ply file 51 | 52 | # mesh_points = np.matmul(mesh_points, np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])) 53 | # mesh_points = np.matmul(mesh_points, np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])) 54 | 55 | 56 | num_verts = verts.shape[0] 57 | num_faces = faces.shape[0] 58 | 59 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 60 | 61 | for i in range(0, num_verts): 62 | verts_tuple[i] = tuple(mesh_points[i, :]) 63 | 64 | faces_building = [] 65 | for i in range(0, num_faces): 66 | faces_building.append(((faces[i, :].tolist(),))) 67 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 68 | 69 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 70 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 71 | 72 | ply_data = plyfile.PlyData([el_verts, el_faces]) 73 | log.info("saving mesh to %s" % str(ply_filename_out)) 74 | ply_data.write(ply_filename_out) 75 | 76 | log.info( 77 | "converting to ply format and writing to file took {} s".format( 78 | time.time() - start_time 79 | ) 80 | ) 81 | 82 | def extract_mesh(implicit_surface, volume_size=2.0, level=0.0, N=512, filepath='./surface.ply', show_progress=True, chunk=16*1024): 83 | s = volume_size 84 | voxel_grid_origin = [-s/2., -s/2., -s/2.] 
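    # the query grid is an axis-aligned cube of side volume_size centered at the
    # origin, matching scenes normalized to (roughly) the unit sphere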
85 |     volume_size = [s, s, s]
86 | 
87 |     overall_index = np.arange(0, N ** 3, 1).astype(np.int64)
88 |     xyz = np.zeros([N ** 3, 3])
89 | 
90 |     # transform first 3 columns
91 |     # to be the x, y, z index
92 |     xyz[:, 2] = overall_index % N
93 |     xyz[:, 1] = (overall_index // N) % N
94 |     xyz[:, 0] = ((overall_index // N) // N) % N
95 | 
96 |     # transform first 3 columns
97 |     # to be the x, y, z coordinate
98 |     xyz[:, 0] = (xyz[:, 0] * (s/(N-1))) + voxel_grid_origin[2]
99 |     xyz[:, 1] = (xyz[:, 1] * (s/(N-1))) + voxel_grid_origin[1]
100 |     xyz[:, 2] = (xyz[:, 2] * (s/(N-1))) + voxel_grid_origin[0]
101 | 
102 |     def batchify(query_fn, inputs: np.ndarray, chunk=chunk):
103 |         out = []
104 |         for i in tqdm(range(0, inputs.shape[0], chunk), disable=not show_progress):
105 |             out_i = query_fn(torch.from_numpy(inputs[i:i+chunk]).float().cuda()).data.cpu().numpy()
106 |             out.append(out_i)
107 |         out = np.concatenate(out, axis=0)
108 |         return out
109 | 
110 |     out = batchify(implicit_surface.forward, xyz)
111 |     out = out.reshape([N, N, N])
112 |     convert_sigma_samples_to_ply(out, voxel_grid_origin, [float(v) / (N-1) for v in volume_size], filepath, level=level)  # voxel size s/(N-1) matches the query-grid stride above
113 | 
-------------------------------------------------------------------------------- /utils/print_fn.py: --------------------------------------------------------------------------------
1 | # NOTE: this file is separated to prevent circular import
2 | 
3 | import logging
4 | import sys
5 | #---------------------------------------------------------------------------
6 | #----------------------- logging instead of printing -----------------------
7 | #---------------------------------------------------------------------------
8 | 
9 | logs = set()
10 | # LOGGER
11 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
12 | RESET_SEQ = "\033[0m"
13 | COLOR_SEQ = "\033[1;%dm"
14 | 
15 | COLORS = {
16 |     'WARNING': YELLOW,
17 |     'INFO': WHITE,
18 |     'DEBUG': BLUE,
19 |     'CRITICAL': YELLOW,
20 |     'ERROR': RED
21 | }
22 | 
23 | class ColoredFormatter(logging.Formatter):
24 |     def __init__(self, msg, use_color=True):
25 |         logging.Formatter.__init__(self, msg)
26 |         self.use_color = use_color
27 | 
28 |     def format(self, record):
29 |         msg = record.msg
30 |         levelname = record.levelname
31 |         if self.use_color and levelname in COLORS and COLORS[levelname] != WHITE:
32 |             if isinstance(msg, str):
33 |                 msg_color = COLOR_SEQ % (30 + COLORS[levelname]) + msg + RESET_SEQ
34 |                 record.msg = msg_color
35 |             levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
36 |             record.levelname = levelname_color
37 |         return logging.Formatter.format(self, record)
38 | 
39 | def init_log(name, level=logging.INFO):
40 |     if (name, level) in logs:
41 |         return
42 | 
43 |     from utils.dist_util import is_master, get_rank
44 | 
45 |     logs.add((name, level))
46 |     logger = logging.getLogger(name)
47 |     logger.setLevel(level)
48 |     ch = logging.StreamHandler(stream=sys.stdout)
49 |     ch.setLevel(level)
50 | 
51 |     logger.addFilter(lambda record: is_master())
52 | 
53 |     format_str = f'%(asctime)s-rk{get_rank()}-%(filename)s#%(lineno)d:%(message)s'
54 |     formatter = ColoredFormatter(format_str)
55 |     ch.setFormatter(formatter)
56 |     logger.addHandler(ch)
57 | 
58 |     logger.propagate = False
59 | 
60 |     return logger
61 | 
62 | 
63 | log = init_log('global', logging.INFO)
-------------------------------------------------------------------------------- /utils/rend_util.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | 
4 | import torch
5 | from torch.nn import functional as F
6 | 
7 | 
8 | def load_K_Rt_from_P(P):
9 |     """
10 |     modified from IDR https://github.com/lioryariv/idr
11 |     """
12 |     out = cv2.decomposeProjectionMatrix(P)
13 |     K = out[0]
14 |     R = out[1]
15 |     t = out[2]
16 | 
17 |     K = K / K[2, 2]
18 |     intrinsics = np.eye(4)
19 |     intrinsics[:3, :3] = K
20 | 
21 |     pose = np.eye(4, dtype=np.float32)
22 |     pose[:3, :3] = R.transpose()
23 |     pose[:3, 3] = (t[:3] / t[3])[:, 0]
24 | 
25 |     return intrinsics, pose
26 | 
27 | def normalize(vec):
28 |     return vec / (np.linalg.norm(vec, axis=-1, keepdims=True) + 1e-9)
29 | 
30 | def view_matrix(
31 |     forward: np.ndarray,
32 |     up: np.ndarray,
33 |     cam_location: np.ndarray):
34 |     rot_z = normalize(forward)
35 |     rot_x = normalize(np.cross(up, rot_z))
36 |     rot_y = normalize(np.cross(rot_z, rot_x))
37 |     mat = np.stack((rot_x, rot_y, rot_z, cam_location), axis=-1)
38 |     hom_vec = np.array([[0., 0., 0., 1.]])
39 |     if len(mat.shape) > 2:
40 |         hom_vec = np.tile(hom_vec, [mat.shape[0], 1, 1])
41 |     mat = np.concatenate((mat, hom_vec), axis=-2)
42 |     return mat
43 | 
44 | def look_at(
45 |     cam_location: np.ndarray,
46 |     point: np.ndarray,
47 |     up=np.array([0., -1., 0.])          # openCV convention
48 |     # up=np.array([0., 1., 0.])         # openGL convention
49 | ):
50 |     # Cam points in positive z direction
51 |     forward = normalize(point - cam_location)       # openCV convention
52 |     # forward = normalize(cam_location - point)     # openGL convention
53 |     return view_matrix(forward, up, cam_location)
54 | 
55 | def rot_to_quat(R):
56 |     batch_size, _, _ = R.shape
57 |     q = torch.ones((batch_size, 4)).to(R.device)
58 | 
59 |     R00 = R[..., 0, 0]
60 |     R01 = R[..., 0, 1]
61 |     R02 = R[..., 0, 2]
62 |     R10 = R[..., 1, 0]
63 |     R11 = R[..., 1, 1]
64 |     R12 = R[..., 1, 2]
65 |     R20 = R[..., 2, 0]
66 |     R21 = R[..., 2, 1]
67 |     R22 = R[..., 2, 2]
68 | 
69 |     q[..., 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
70 |     q[..., 1] = (R21 - R12) / (4 * q[..., 0])
71 |     q[..., 2] = (R02 - R20) / (4 * q[..., 0])
72 |     q[..., 3] = (R10 - R01) / (4 * q[..., 0])
73 |     return q
74 | 
75 | 
76 | def quat_to_rot(q):
77 |     prefix = q.shape[:-1]
78 |     q = F.normalize(q, dim=-1)
79 |     R = torch.ones([*prefix, 3, 3]).to(q.device)
80 |     qr = q[...
,0] 81 | qi = q[..., 1] 82 | qj = q[..., 2] 83 | qk = q[..., 3] 84 | R[..., 0, 0]=1-2 * (qj**2 + qk**2) 85 | R[..., 0, 1] = 2 * (qj *qi -qk*qr) 86 | R[..., 0, 2] = 2 * (qi * qk + qr * qj) 87 | R[..., 1, 0] = 2 * (qj * qi + qk * qr) 88 | R[..., 1, 1] = 1-2 * (qi**2 + qk**2) 89 | R[..., 1, 2] = 2*(qj*qk - qi*qr) 90 | R[..., 2, 0] = 2 * (qk * qi-qj * qr) 91 | R[..., 2, 1] = 2 * (qj*qk + qi*qr) 92 | R[..., 2, 2] = 1-2 * (qi**2 + qj**2) 93 | return R 94 | 95 | def lift(x, y, z, intrinsics): 96 | device = x.device 97 | # parse intrinsics 98 | intrinsics = intrinsics.to(device) 99 | fx = intrinsics[..., 0, 0] 100 | fy = intrinsics[..., 1, 1] 101 | cx = intrinsics[..., 0, 2] 102 | cy = intrinsics[..., 1, 2] 103 | sk = intrinsics[..., 0, 1] 104 | 105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z 106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 107 | 108 | # homogeneous 109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(device)), dim=-1) 110 | 111 | 112 | def get_rays(c2w, intrinsics, H, W, N_rays=-1): 113 | device = c2w.device 114 | if c2w.shape[-1] == 7: #In case of quaternion vector representation 115 | cam_loc = c2w[..., 4:] 116 | R = quat_to_rot(c2w[...,:4]) 117 | p = torch.eye(4).repeat([*c2w.shape[0:-1],1,1]).to(device).float() 118 | p[..., :3, :3] = R 119 | p[..., :3, 3] = cam_loc 120 | else: # In case of pose matrix representation 121 | cam_loc = c2w[..., :3, 3] 122 | p = c2w 123 | 124 | prefix = p.shape[:-2] 125 | device = c2w.device 126 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) 127 | i = i.t().to(device).reshape([*[1]*len(prefix), H*W]).expand([*prefix, H*W]) 128 | j = j.t().to(device).reshape([*[1]*len(prefix), H*W]).expand([*prefix, H*W]) 129 | 130 | if N_rays > 0: 131 | N_rays = min(N_rays, H*W) 132 | # ---------- option 1: full image uniformly randomize 133 | # select_inds = torch.from_numpy( 134 | # np.random.choice(H*W, size=[*prefix, N_rays], replace=False)).to(device) 135 | # select_inds = torch.randint(0, H*W, size=[N_rays]).expand([*prefix, N_rays]).to(device) 136 | # ---------- option 2: H/W seperately randomize 137 | select_hs = torch.randint(0, H, size=[N_rays]).to(device) 138 | select_ws = torch.randint(0, W, size=[N_rays]).to(device) 139 | select_inds = select_hs * W + select_ws 140 | select_inds = select_inds.expand([*prefix, N_rays]) 141 | 142 | i = torch.gather(i, -1, select_inds) 143 | j = torch.gather(j, -1, select_inds) 144 | else: 145 | select_inds = torch.arange(H*W).to(device).expand([*prefix, H*W]) 146 | 147 | pixel_points_cam = lift(i, j, torch.ones_like(i).to(device), intrinsics=intrinsics) 148 | 149 | # permute for batch matrix product 150 | pixel_points_cam = pixel_points_cam.transpose(-1,-2) 151 | 152 | # NOTE: left-multiply. 153 | # after the above permute(), shapes of coordinates changed from [B,N,4] to [B,4,N], which ensures correct left-multiplication 154 | # p is camera 2 world matrix. 
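    # i.e. world_homo = c2w @ cam_homo for column vectors; rays_d computed below is
    # left un-normalized (z-depth parameterization), cf. the commented-out F.normalize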
155 | if len(prefix) > 0: 156 | world_coords = torch.bmm(p, pixel_points_cam).transpose(-1, -2)[..., :3] 157 | else: 158 | world_coords = torch.mm(p, pixel_points_cam).transpose(-1, -2)[..., :3] 159 | rays_d = world_coords - cam_loc[..., None, :] 160 | # ray_dirs = F.normalize(ray_dirs, dim=2) 161 | 162 | rays_o = cam_loc[..., None, :].expand_as(rays_d) 163 | 164 | return rays_o, rays_d, select_inds 165 | 166 | 167 | def near_far_from_sphere(ray_origins: torch.Tensor, ray_directions: torch.Tensor, r = 1.0, keepdim=True): 168 | """ 169 | NOTE: modified from https://github.com/Totoro97/NeuS 170 | ray_origins: camera center's coordinate 171 | ray_directions: camera rays' directions. already normalized. 172 | """ 173 | # rayso_norm_square = torch.sum(ray_origins**2, dim=-1, keepdim=True) 174 | # NOTE: (minus) the length of the line projected from [the line from camera to sphere center] to [the line of camera rays] 175 | ray_cam_dot = torch.sum(ray_origins * ray_directions, dim=-1, keepdim=keepdim) 176 | mid = -ray_cam_dot 177 | # NOTE: a convservative approximation of the half chord length from ray intersections with the sphere. 178 | # all half chord length < r 179 | near = mid - r 180 | far = mid + r 181 | 182 | near = near.clamp_min(0.0) 183 | far = far.clamp_min(r) # NOTE: instead of clamp_min(0.0), just some trick. 184 | 185 | return near, far 186 | 187 | 188 | def get_sphere_intersection(ray_origins: torch.Tensor, ray_directions: torch.Tensor, r = 1.0): 189 | """ 190 | NOTE: modified from IDR. https://github.com/lioryariv/idr 191 | ray_origins: camera center's coordinate 192 | ray_directions: camera rays' directions. already normalized. 193 | """ 194 | rayso_norm_square = torch.sum(ray_origins**2, dim=-1, keepdim=True) 195 | # (minus) the length of the line projected from [the line from camera to sphere center] to [the line of camera rays] 196 | ray_cam_dot = torch.sum(ray_origins * ray_directions, dim=-1, keepdim=True) 197 | 198 | # accurate ray-sphere intersections 199 | near = torch.zeros([*ray_origins.shape[:-1], 1]).to(ray_origins.device) 200 | far = torch.zeros([*ray_origins.shape[:-1], 1]).to(ray_origins.device) 201 | under_sqrt = ray_cam_dot ** 2 + r ** 2 - rayso_norm_square 202 | mask_intersect = under_sqrt > 0 203 | sqrt = torch.sqrt(under_sqrt[mask_intersect]) 204 | near[mask_intersect] = - sqrt - ray_cam_dot[mask_intersect] 205 | far[mask_intersect] = sqrt - ray_cam_dot[mask_intersect] 206 | 207 | near = near.clamp_min(0.0) 208 | far = far.clamp_min(0.0) 209 | 210 | return near, far, mask_intersect 211 | 212 | 213 | def get_dvals_from_radius(ray_origins: torch.Tensor, ray_directions: torch.Tensor, rs: torch.Tensor, far_end=True): 214 | """ 215 | ray_origins: camera center's coordinate 216 | ray_directions: camera rays' directions. already normalized. 217 | rs: the distance to the origin 218 | far_end: whether the point is on the far-end of the ray or on the near-end of the ray 219 | """ 220 | rayso_norm_square = torch.sum(ray_origins**2, dim=-1, keepdim=True) 221 | # NOTE: (minus) the length of the line projected from [the line from camera to sphere center] to [the line of camera rays] 222 | ray_cam_dot = torch.sum(ray_origins * ray_directions, dim=-1, keepdim=True) 223 | 224 | under_sqrt = rs**2 - (rayso_norm_square - ray_cam_dot ** 2) 225 | assert (under_sqrt > 0).all() 226 | sqrt = torch.sqrt(under_sqrt) 227 | 228 | if far_end: 229 | d_vals = -ray_cam_dot + sqrt 230 | else: 231 | d_vals = -ray_cam_dot - sqrt 232 | d_vals = torch.clamp_min(d_vals, 0.) 
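    # NOTE: the assert above requires every radius in rs to exceed the ray's closest
    # distance to the origin, i.e. each sphere must actually intersect the ray
    # (used e.g. for sampling background points on concentric spheres, NeRF++-style)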
233 | 
234 |     return d_vals
235 | 
236 | 
237 | def lin2img(tensor: torch.Tensor, H: int, W: int, batched=False, B=None):
238 |     *_, num_samples, channels = tensor.shape
239 |     assert num_samples == H * W
240 |     if batched:
241 |         if B is None:
242 |             B = tensor.shape[0]
243 |         else:
244 |             tensor = tensor.view([B, num_samples//B, channels])
245 |         return tensor.permute(0, 2, 1).view([B, channels, H, W])
246 |     else:
247 |         return tensor.permute(1, 0).view([channels, H, W])
248 | 
249 | 
250 | #----------------------------------------------------
251 | #-------- Sampling points from ray ------------------
252 | #----------------------------------------------------
253 | 
254 | # Hierarchical sampling (section 5.2)
255 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
256 |     # device = weights.get_device()
257 |     device = weights.device
258 |     # Get pdf
259 |     weights = weights + 1e-5  # prevent nans
260 |     pdf = weights / torch.sum(weights, -1, keepdim=True)
261 |     cdf = torch.cumsum(pdf, -1)
262 |     cdf = torch.cat(
263 |         [torch.zeros_like(cdf[..., :1], device=device), cdf], -1
264 |     )  # (batch, len(bins))
265 | 
266 |     # Take uniform samples
267 |     if det:
268 |         u = torch.linspace(0.0, 1.0, steps=N_importance, device=device)
269 |         u = u.expand(list(cdf.shape[:-1]) + [N_importance])
270 |     else:
271 |         u = torch.rand(list(cdf.shape[:-1]) + [N_importance], device=device)
272 |     u = u.contiguous()
273 | 
274 |     # Invert CDF
275 |     inds = torch.searchsorted(cdf.detach(), u, right=False)
276 | 
277 |     below = torch.clamp_min(inds-1, 0)
278 |     above = torch.clamp_max(inds, cdf.shape[-1]-1)
279 |     # (batch, N_importance, 2) ==> (B, batch, N_importance, 2)
280 |     inds_g = torch.stack([below, above], -1)
281 | 
282 |     matched_shape = [*inds_g.shape[:-1], cdf.shape[-1]]  # fix prefix shape
283 | 
284 |     cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
285 |     bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)  # fix prefix shape
286 | 
287 |     denom = cdf_g[..., 1] - cdf_g[..., 0]
288 |     denom[denom<eps] = 1
289 |     t = (u - cdf_g[..., 0]) / denom
290 |     samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
291 | 
292 |     return samples
293 | 
294 | 
295 | def sample_cdf(bins, cdf, N_importance, det=False, eps=1e-5):
296 |     # the same inverse-transform sampling as sample_pdf above, but with the CDF given directly
297 |     device = cdf.device
298 | 
299 |     # Take uniform samples
300 |     if det:
301 |         u = torch.linspace(0.0, 1.0, steps=N_importance, device=device)
302 |         u = u.expand(list(cdf.shape[:-1]) + [N_importance])
303 |     else:
304 |         u = torch.rand(list(cdf.shape[:-1]) + [N_importance], device=device)
305 |     u = u.contiguous()
306 | 
307 |     # Invert CDF
308 |     inds = torch.searchsorted(cdf.detach(), u, right=False)
309 | 
310 |     below = torch.clamp_min(inds-1, 0)
311 |     above = torch.clamp_max(inds, cdf.shape[-1]-1)
312 | 
313 |     # stack the bracketing indices;
314 |     # (batch, N_importance, 2) ==> (B, batch, N_importance, 2)
315 |     inds_g = torch.stack([below, above], -1)
316 | 
317 |     matched_shape = [*inds_g.shape[:-1], cdf.shape[-1]]  # fix prefix shape
318 | 
319 |     cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
320 |     bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)  # fix prefix shape
321 | 
322 |     denom = cdf_g[..., 1] - cdf_g[..., 0]
323 |     denom[denom<eps] = 1
324 |     t = (u - cdf_g[..., 0]) / denom
325 |     samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
326 | 
327 |     return samples
-------------------------------------------------------------------------------- /utils/train_util.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | # batchify_query: evaluate an implicit-network query in fixed-size chunks over the
5 | # flattened (N_rays, N_pts) dims, so peak GPU memory stays bounded no matter how
6 | # many rays/points are queried at once; shapes are restored after collation.
7 | def batchify_query(query_fn, *args: torch.Tensor, chunk=4096, dim_batchify=0):
8 |     """
9 |     Args:
10 |         query_fn: callable taking the chunked *args and returning a tensor,
11 |             a dict of tensors, or a tuple mixing both.
12 |         *args: tensors shaped [(B), N_rays, N_pts, ...]; the two dims at
13 |             dim_batchify and dim_batchify+1 are flattened before chunking.
14 |         chunk: max number of flattened entries per forward pass.
15 |         dim_batchify: index of the N_rays dim (only 0, 1 or 2 is supported).
16 | 
17 |     Returns:
18 |         Outputs collated across chunks with the [(B), N_rays, N_pts, ...]
19 |         shapes restored; a single return entry is unwrapped, several are
20 |         returned as a tuple (order preserved).
21 |     """
22 | 
23 | 
24 |     # [(B), N_rays, N_pts, ...] -> [(B), N_rays*N_pts, ...]
25 |     _N_rays = args[0].shape[dim_batchify]
26 |     _N_pts = args[0].shape[dim_batchify+1]
27 |     args = [arg.flatten(dim_batchify, dim_batchify+1) for arg in args]
28 |     _N = args[0].shape[dim_batchify]
29 |     raw_ret = []
30 |     for i in range(0, _N, chunk):
31 |         if dim_batchify == 0:
32 |             args_i = [arg[i:i+chunk] for arg in args]
33 |         elif dim_batchify == 1:
34 |             args_i = [arg[:, i:i+chunk] for arg in args]
35 |         elif dim_batchify == 2:
36 |             args_i = [arg[:, :, i:i+chunk] for arg in args]
37 |         else:
38 |             raise NotImplementedError
39 |         raw_ret_i = query_fn(*args_i)
40 |         if not isinstance(raw_ret_i, tuple):
41 |             raw_ret_i = [raw_ret_i]
42 |         raw_ret.append(raw_ret_i)
43 |     collate_raw_ret = []
44 |     num_entry = 0
45 |     for entry in zip(*raw_ret):
46 |         if isinstance(entry[0], dict):
47 |             tmp_dict = {}
48 |             for list_item in entry:
49 |                 for k, v in list_item.items():
50 |                     if k not in tmp_dict:
51 |                         tmp_dict[k] = []
52 |                     tmp_dict[k].append(v)
53 |             for k in tmp_dict.keys():
54 |                 # [(B), N_rays*N_pts, ...] -> [(B), N_rays, N_pts, ...]
55 | # tmp_dict[k] = torch.cat(tmp_dict[k], dim=dim_batchify).unflatten(dim_batchify, [_N_rays, _N_pts]) 56 | # NOTE: compatible with torch 1.6 57 | v = torch.cat(tmp_dict[k], dim=dim_batchify) 58 | tmp_dict[k] = v.reshape([*v.shape[:dim_batchify], _N_rays, _N_pts, *v.shape[dim_batchify+1:]]) 59 | entry = tmp_dict 60 | else: 61 | # [(B), N_rays*N_pts, ...] -> [(B), N_rays, N_pts, ...] 62 | # entry = torch.cat(entry, dim=dim_batchify).unflatten(dim_batchify, [_N_rays, _N_pts]) 63 | # NOTE: compatible with torch 1.6 64 | v = torch.cat(entry, dim=dim_batchify) 65 | entry = v.reshape([*v.shape[:dim_batchify], _N_rays, _N_pts, *v.shape[dim_batchify+1:]]) 66 | collate_raw_ret.append(entry) 67 | num_entry += 1 68 | if num_entry == 1: 69 | return collate_raw_ret[0] 70 | else: 71 | return tuple(collate_raw_ret) --------------------------------------------------------------------------------
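A minimal usage sketch for batchify_query, assuming the keyword signature reconstructed above (chunk, dim_batchify); toy_sdf_query is a hypothetical stand-in for a real network forward:

import torch

from utils.train_util import batchify_query


def toy_sdf_query(x):
    # hypothetical query_fn: receives an already-flattened [chunk, 3] slice and
    # returns a tensor plus a dict -- the two return kinds batchify_query collates
    return x.norm(dim=-1) - 1., {'feat': 2. * x}


pts = torch.randn(1024, 64, 3)  # [N_rays, N_pts, 3]
sdf, extras = batchify_query(toy_sdf_query, pts, chunk=4096, dim_batchify=0)
# sdf: [1024, 64]; extras['feat']: [1024, 64, 3] -- rays/points dims restored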