├── README.md
├── assets
│   ├── acid_16view.json
│   ├── acid_4view.json
│   ├── acid_8view.json
│   ├── evaluation_index_acid.json
│   ├── evaluation_index_acid_video.json
│   ├── evaluation_index_dtu_nctx2.json
│   ├── evaluation_index_dtu_nctx3.json
│   ├── evaluation_index_re10k.json
│   ├── evaluation_index_re10k_video.json
│   ├── re10k_16view.json
│   ├── re10k_4view.json
│   └── re10k_8view.json
├── config
│   ├── compute_metrics.yaml
│   ├── dataset
│   │   ├── re10k.yaml
│   │   ├── view_sampler
│   │   │   ├── all.yaml
│   │   │   ├── arbitrary.yaml
│   │   │   ├── bounded.yaml
│   │   │   └── evaluation.yaml
│   │   └── view_sampler_dataset_specific_config
│   │       ├── bounded_re10k.yaml
│   │       └── evaluation_re10k.yaml
│   ├── evaluation
│   │   ├── ablation.yaml
│   │   ├── acid.yaml
│   │   ├── acid_video.yaml
│   │   ├── re10k.yaml
│   │   └── re10k_video.yaml
│   ├── experiment
│   │   ├── acid.yaml
│   │   ├── dtu.yaml
│   │   └── re10k.yaml
│   ├── generate_evaluation_index.yaml
│   ├── loss
│   │   ├── depth.yaml
│   │   ├── lpips.yaml
│   │   └── mse.yaml
│   ├── main.yaml
│   └── model
│       ├── decoder
│       │   └── splatting_cuda.yaml
│       └── encoder
│           └── costvolume.yaml
├── figure
│   └── pipeline.png
├── requirements.txt
├── src
│   ├── config.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── data_module.py
│   │   ├── dataset.py
│   │   ├── dataset_re10k.py
│   │   ├── shims
│   │   │   ├── augmentation_shim.py
│   │   │   ├── bounds_shim.py
│   │   │   ├── crop_shim.py
│   │   │   └── patch_shim.py
│   │   ├── types.py
│   │   ├── validation_wrapper.py
│   │   └── view_sampler
│   │       ├── __init__.py
│   │       ├── view_sampler.py
│   │       ├── view_sampler_all.py
│   │       ├── view_sampler_arbitrary.py
│   │       ├── view_sampler_bounded.py
│   │       └── view_sampler_evaluation.py
│   ├── evaluation
│   │   ├── evaluation_cfg.py
│   │   ├── evaluation_index_generator.py
│   │   ├── metric_computer.py
│   │   └── metrics.py
│   ├── geometry
│   │   ├── epipolar_lines.py
│   │   └── projection.py
│   ├── global_cfg.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── loss_depth.py
│   │   ├── loss_lpips.py
│   │   └── loss_mse.py
│   ├── main.py
│   ├── misc
│   │   ├── LocalLogger.py
│   │   ├── benchmarker.py
│   │   ├── collation.py
│   │   ├── discrete_probability_distribution.py
│   │   ├── heterogeneous_pairings.py
│   │   ├── image_io.py
│   │   ├── nn_module_tools.py
│   │   ├── sh_rotation.py
│   │   ├── step_tracker.py
│   │   └── wandb_tools.py
│   ├── model
│   │   ├── decoder
│   │   │   ├── __init__.py
│   │   │   ├── cuda_splatting.py
│   │   │   ├── decoder.py
│   │   │   └── decoder_splatting_cuda.py
│   │   ├── encoder
│   │   │   ├── __init__.py
│   │   │   ├── backbone
│   │   │   │   ├── __init__.py
│   │   │   │   ├── backbone_multiview.py
│   │   │   │   ├── multiview_transformer.py
│   │   │   │   └── unimatch
│   │   │   │       ├── __init__.py
│   │   │   │       ├── attention.py
│   │   │   │       ├── backbone.py
│   │   │   │       ├── geometry.py
│   │   │   │       ├── matching.py
│   │   │   │       ├── position.py
│   │   │   │       ├── reg_refine.py
│   │   │   │       ├── transformer.py
│   │   │   │       ├── trident_conv.py
│   │   │   │       ├── unimatch.py
│   │   │   │       └── utils.py
│   │   │   ├── common
│   │   │   │   ├── gaussian_adapter.py
│   │   │   │   ├── gaussians.py
│   │   │   │   └── sampler.py
│   │   │   ├── costvolume
│   │   │   │   ├── conversions.py
│   │   │   │   ├── depth_predictor_multiview.py
│   │   │   │   └── ldm_unet
│   │   │   │       ├── __init__.py
│   │   │   │       ├── attention.py
│   │   │   │       ├── unet.py
│   │   │   │       └── util.py
│   │   │   ├── encoder.py
│   │   │   ├── encoder_costvolume.py
│   │   │   ├── epipolar
│   │   │   │   └── epipolar_sampler.py
│   │   │   └── visualization
│   │   │       ├── encoder_visualizer.py
│   │   │       ├── encoder_visualizer_costvolume.py
│   │   │       └── encoder_visualizer_costvolume_cfg.py
│   │   ├── encodings
│   │   │   └── positional_encoding.py
│   │   ├── model_wrapper.py
│   │   ├── ply_export.py
│   │   └── types.py
│   ├── scripts
│   │   ├── compute_metrics.py
│   │   ├── convert_dtu.py
│   │   ├── dump_launch_configs.py
│   │   ├── generate_dtu_evaluation_index.py
│   │   ├── generate_evaluation_index.py
│   │   ├── generate_video_evaluation_index.py
│   │   ├── test_splatter.py
│   │   └── visualize_epipolar_lines.py
│   └── visualization
│       ├── annotation.py
│       ├── camera_trajectory
│       │   ├── interpolation.py
│       │   ├── spin.py
│       │   └── wobble.py
│       ├── color_map.py
│       ├── colors.py
│       ├── drawing
│       │   ├── cameras.py
│       │   ├── coordinate_conversion.py
│       │   ├── lines.py
│       │   ├── points.py
│       │   ├── rendering.py
│       │   └── types.py
│       ├── layout.py
│       ├── validation_in_3d.py
│       └── vis_depth.py
├── test.sh
└── train.sh
/assets/evaluation_index_dtu_nctx2.json:
--------------------------------------------------------------------------------
1 | {"scan1_train_00": {"context": [33, 31], "target": [32]}, "scan1_train_01": {"context": [33, 31], "target": [24]}, "scan1_train_02": {"context": [33, 31], "target": [23]}, "scan1_train_03": {"context": [33, 31], "target": [44]}, "scan8_train_00": {"context": [33, 31], "target": [32]}, "scan8_train_01": {"context": [33, 31], "target": [24]}, "scan8_train_02":
{"context": [33, 31], "target": [23]}, "scan8_train_03": {"context": [33, 31], "target": [44]}, "scan21_train_00": {"context": [33, 31], "target": [32]}, "scan21_train_01": {"context": [33, 31], "target": [24]}, "scan21_train_02": {"context": [33, 31], "target": [23]}, "scan21_train_03": {"context": [33, 31], "target": [44]}, "scan30_train_00": {"context": [33, 31], "target": [32]}, "scan30_train_01": {"context": [33, 31], "target": [24]}, "scan30_train_02": {"context": [33, 31], "target": [23]}, "scan30_train_03": {"context": [33, 31], "target": [44]}, "scan31_train_00": {"context": [33, 31], "target": [32]}, "scan31_train_01": {"context": [33, 31], "target": [24]}, "scan31_train_02": {"context": [33, 31], "target": [23]}, "scan31_train_03": {"context": [33, 31], "target": [44]}, "scan34_train_00": {"context": [33, 31], "target": [32]}, "scan34_train_01": {"context": [33, 31], "target": [24]}, "scan34_train_02": {"context": [33, 31], "target": [23]}, "scan34_train_03": {"context": [33, 31], "target": [44]}, "scan38_train_00": {"context": [33, 31], "target": [32]}, "scan38_train_01": {"context": [33, 31], "target": [24]}, "scan38_train_02": {"context": [33, 31], "target": [23]}, "scan38_train_03": {"context": [33, 31], "target": [44]}, "scan40_train_00": {"context": [33, 31], "target": [32]}, "scan40_train_01": {"context": [33, 31], "target": [24]}, "scan40_train_02": {"context": [33, 31], "target": [23]}, "scan40_train_03": {"context": [33, 31], "target": [44]}, "scan41_train_00": {"context": [33, 31], "target": [32]}, "scan41_train_01": {"context": [33, 31], "target": [24]}, "scan41_train_02": {"context": [33, 31], "target": [23]}, "scan41_train_03": {"context": [33, 31], "target": [44]}, "scan45_train_00": {"context": [33, 31], "target": [32]}, "scan45_train_01": {"context": [33, 31], "target": [24]}, "scan45_train_02": {"context": [33, 31], "target": [23]}, "scan45_train_03": {"context": [33, 31], "target": [44]}, "scan55_train_00": {"context": [33, 31], "target": [32]}, "scan55_train_01": {"context": [33, 31], "target": [24]}, "scan55_train_02": {"context": [33, 31], "target": [23]}, "scan55_train_03": {"context": [33, 31], "target": [44]}, "scan63_train_00": {"context": [33, 31], "target": [32]}, "scan63_train_01": {"context": [33, 31], "target": [24]}, "scan63_train_02": {"context": [33, 31], "target": [23]}, "scan63_train_03": {"context": [33, 31], "target": [44]}, "scan82_train_00": {"context": [33, 31], "target": [32]}, "scan82_train_01": {"context": [33, 31], "target": [24]}, "scan82_train_02": {"context": [33, 31], "target": [23]}, "scan82_train_03": {"context": [33, 31], "target": [44]}, "scan103_train_00": {"context": [33, 31], "target": [32]}, "scan103_train_01": {"context": [33, 31], "target": [24]}, "scan103_train_02": {"context": [33, 31], "target": [23]}, "scan103_train_03": {"context": [33, 31], "target": [44]}, "scan110_train_00": {"context": [33, 31], "target": [32]}, "scan110_train_01": {"context": [33, 31], "target": [24]}, "scan110_train_02": {"context": [33, 31], "target": [23]}, "scan110_train_03": {"context": [33, 31], "target": [44]}, "scan114_train_00": {"context": [33, 31], "target": [32]}, "scan114_train_01": {"context": [33, 31], "target": [24]}, "scan114_train_02": {"context": [33, 31], "target": [23]}, "scan114_train_03": {"context": [33, 31], "target": [44]}} -------------------------------------------------------------------------------- /assets/evaluation_index_dtu_nctx3.json: 
-------------------------------------------------------------------------------- 1 | {"scan1_train_00": {"context": [33, 31, 43], "target": [32]}, "scan1_train_01": {"context": [33, 31, 43], "target": [24]}, "scan1_train_02": {"context": [33, 31, 43], "target": [23]}, "scan1_train_03": {"context": [33, 31, 43], "target": [44]}, "scan8_train_00": {"context": [33, 31, 43], "target": [32]}, "scan8_train_01": {"context": [33, 31, 43], "target": [24]}, "scan8_train_02": {"context": [33, 31, 43], "target": [23]}, "scan8_train_03": {"context": [33, 31, 43], "target": [44]}, "scan21_train_00": {"context": [33, 31, 43], "target": [32]}, "scan21_train_01": {"context": [33, 31, 43], "target": [24]}, "scan21_train_02": {"context": [33, 31, 43], "target": [23]}, "scan21_train_03": {"context": [33, 31, 43], "target": [44]}, "scan30_train_00": {"context": [33, 31, 43], "target": [32]}, "scan30_train_01": {"context": [33, 31, 43], "target": [24]}, "scan30_train_02": {"context": [33, 31, 43], "target": [23]}, "scan30_train_03": {"context": [33, 31, 43], "target": [44]}, "scan31_train_00": {"context": [33, 31, 43], "target": [32]}, "scan31_train_01": {"context": [33, 31, 43], "target": [24]}, "scan31_train_02": {"context": [33, 31, 43], "target": [23]}, "scan31_train_03": {"context": [33, 31, 43], "target": [44]}, "scan34_train_00": {"context": [33, 31, 43], "target": [32]}, "scan34_train_01": {"context": [33, 31, 43], "target": [24]}, "scan34_train_02": {"context": [33, 31, 43], "target": [23]}, "scan34_train_03": {"context": [33, 31, 43], "target": [44]}, "scan38_train_00": {"context": [33, 31, 43], "target": [32]}, "scan38_train_01": {"context": [33, 31, 43], "target": [24]}, "scan38_train_02": {"context": [33, 31, 43], "target": [23]}, "scan38_train_03": {"context": [33, 31, 43], "target": [44]}, "scan40_train_00": {"context": [33, 31, 43], "target": [32]}, "scan40_train_01": {"context": [33, 31, 43], "target": [24]}, "scan40_train_02": {"context": [33, 31, 43], "target": [23]}, "scan40_train_03": {"context": [33, 31, 43], "target": [44]}, "scan41_train_00": {"context": [33, 31, 43], "target": [32]}, "scan41_train_01": {"context": [33, 31, 43], "target": [24]}, "scan41_train_02": {"context": [33, 31, 43], "target": [23]}, "scan41_train_03": {"context": [33, 31, 43], "target": [44]}, "scan45_train_00": {"context": [33, 31, 43], "target": [32]}, "scan45_train_01": {"context": [33, 31, 43], "target": [24]}, "scan45_train_02": {"context": [33, 31, 43], "target": [23]}, "scan45_train_03": {"context": [33, 31, 43], "target": [44]}, "scan55_train_00": {"context": [33, 31, 43], "target": [32]}, "scan55_train_01": {"context": [33, 31, 43], "target": [24]}, "scan55_train_02": {"context": [33, 31, 43], "target": [23]}, "scan55_train_03": {"context": [33, 31, 43], "target": [44]}, "scan63_train_00": {"context": [33, 31, 43], "target": [32]}, "scan63_train_01": {"context": [33, 31, 43], "target": [24]}, "scan63_train_02": {"context": [33, 31, 43], "target": [23]}, "scan63_train_03": {"context": [33, 31, 43], "target": [44]}, "scan82_train_00": {"context": [33, 31, 43], "target": [32]}, "scan82_train_01": {"context": [33, 31, 43], "target": [24]}, "scan82_train_02": {"context": [33, 31, 43], "target": [23]}, "scan82_train_03": {"context": [33, 31, 43], "target": [44]}, "scan103_train_00": {"context": [33, 31, 43], "target": [32]}, "scan103_train_01": {"context": [33, 31, 43], "target": [24]}, "scan103_train_02": {"context": [33, 31, 43], "target": [23]}, "scan103_train_03": {"context": [33, 31, 43], "target": [44]}, 
"scan110_train_00": {"context": [33, 31, 43], "target": [32]}, "scan110_train_01": {"context": [33, 31, 43], "target": [24]}, "scan110_train_02": {"context": [33, 31, 43], "target": [23]}, "scan110_train_03": {"context": [33, 31, 43], "target": [44]}, "scan114_train_00": {"context": [33, 31, 43], "target": [32]}, "scan114_train_01": {"context": [33, 31, 43], "target": [24]}, "scan114_train_02": {"context": [33, 31, 43], "target": [23]}, "scan114_train_03": {"context": [33, 31, 43], "target": [44]}} -------------------------------------------------------------------------------- /config/compute_metrics.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - model/encoder: epipolar 4 | - loss: [] 5 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 6 | - override dataset/view_sampler: evaluation 7 | 8 | data_loader: 9 | train: 10 | num_workers: 0 11 | persistent_workers: true 12 | batch_size: 1 13 | seed: 1234 14 | test: 15 | num_workers: 4 16 | persistent_workers: false 17 | batch_size: 1 18 | seed: 2345 19 | val: 20 | num_workers: 0 21 | persistent_workers: true 22 | batch_size: 1 23 | seed: 3456 24 | 25 | seed: 111123 26 | -------------------------------------------------------------------------------- /config/dataset/re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: re10k 5 | roots: [datasets/re10k] 6 | make_baseline_1: true 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 19 | baseline_scale_bounds: true 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | test_times_per_scene: 1 24 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/all.yaml: -------------------------------------------------------------------------------- 1 | name: all 2 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/arbitrary.yaml: -------------------------------------------------------------------------------- 1 | name: arbitrary 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | # If you want to hard-code context views, do so here. 
/config/compute_metrics.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - dataset: re10k
3 |   - model/encoder: costvolume
4 |   - loss: []
5 |   - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}
6 |   - override dataset/view_sampler: evaluation
7 |
8 | data_loader:
9 |   train:
10 |     num_workers: 0
11 |     persistent_workers: true
12 |     batch_size: 1
13 |     seed: 1234
14 |   test:
15 |     num_workers: 4
16 |     persistent_workers: false
17 |     batch_size: 1
18 |     seed: 2345
19 |   val:
20 |     num_workers: 0
21 |     persistent_workers: true
22 |     batch_size: 1
23 |     seed: 3456
24 |
25 | seed: 111123
26 |
--------------------------------------------------------------------------------
/config/dataset/re10k.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - view_sampler: bounded
3 |
4 | name: re10k
5 | roots: [datasets/re10k]
6 | make_baseline_1: true
7 | augment: true
8 |
9 | image_shape: [180, 320]
10 | background_color: [0.0, 0.0, 0.0]
11 | cameras_are_circular: false
12 |
13 | baseline_epsilon: 1e-3
14 | max_fov: 100.0
15 |
16 | skip_bad_shape: true
17 | near: -1.
18 | far: -1.
19 | baseline_scale_bounds: true
20 | shuffle_val: true
21 | test_len: -1
22 | test_chunk_interval: 1
23 | test_times_per_scene: 1
--------------------------------------------------------------------------------
/config/dataset/view_sampler/all.yaml:
--------------------------------------------------------------------------------
1 | name: all
--------------------------------------------------------------------------------
/config/dataset/view_sampler/arbitrary.yaml:
--------------------------------------------------------------------------------
1 | name: arbitrary
2 |
3 | num_target_views: 1
4 | num_context_views: 2
5 |
6 | # If you want to hard-code context views, do so here.
7 | context_views: null
8 |
--------------------------------------------------------------------------------
/config/dataset/view_sampler/bounded.yaml:
--------------------------------------------------------------------------------
1 | name: bounded
2 |
3 | num_target_views: 1
4 | num_context_views: 2
5 |
6 | min_distance_between_context_views: 2
7 | max_distance_between_context_views: 6
8 | min_distance_to_context_views: 0
9 |
10 | warm_up_steps: 0
11 | initial_min_distance_between_context_views: 2
12 | initial_max_distance_between_context_views: 6
--------------------------------------------------------------------------------
/config/dataset/view_sampler/evaluation.yaml:
--------------------------------------------------------------------------------
1 | name: evaluation
2 |
3 | index_path: assets/evaluation_index_re10k_video.json
4 | num_context_views: 2
--------------------------------------------------------------------------------
/config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     min_distance_between_context_views: 45
6 |     max_distance_between_context_views: 192
7 |     min_distance_to_context_views: 0
8 |     warm_up_steps: 150_000
9 |     initial_min_distance_between_context_views: 25
10 |     initial_max_distance_between_context_views: 45
11 |     num_target_views: 4
--------------------------------------------------------------------------------
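
The bounded sampler above draws context views whose frame-index gap lies between the min/max distances, and bounded_re10k.yaml widens that gap from (25, 45) to (45, 192) over 150 000 warm-up steps. The authoritative logic lives in src/dataset/view_sampler/view_sampler_bounded.py; this is only a sketch of the schedule those fields describe:

import random

def context_gap_bounds(step, warm_up_steps=150_000, initial=(25, 45), final=(45, 192)):
    # Linearly widen the allowed frame gap between the two context views
    # from the initial bounds to the final bounds over the warm-up.
    t = min(step / warm_up_steps, 1.0)
    lo = round(initial[0] + t * (final[0] - initial[0]))
    hi = round(initial[1] + t * (final[1] - initial[1]))
    return lo, hi

def sample_context_pair(num_frames, step):
    lo, hi = context_gap_bounds(step)
    hi = min(hi, num_frames - 1)
    gap = random.randint(min(lo, hi), hi)
    left = random.randint(0, num_frames - 1 - gap)
    return left, left + gap
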
/config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     index_path: assets/evaluation_index_re10k.json
--------------------------------------------------------------------------------
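
The files in view_sampler_dataset_specific_config/ are pulled in by the `optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}` entry in the defaults lists: Hydra resolves the interpolation to a name such as `bounded_re10k`, and because that file is tagged `# @package _global_`, its values overlay the composed config at the root; `optional` makes missing combinations a silent no-op instead of an error. A sketch of inspecting the composed result with Hydra's compose API (assuming it is run from the repository root; `+experiment=re10k` mirrors the usual command-line override):

from hydra import compose, initialize

with initialize(config_path="config", version_base=None):
    cfg = compose(config_name="main", overrides=["+experiment=re10k"])

# ${dataset/view_sampler}_${dataset} resolved to "bounded_re10k", so its
# warm-up schedule is now visible on the dataset's view sampler.
print(cfg.dataset.view_sampler.warm_up_steps)  # -> 150000
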
/config/evaluation/ablation.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     index_path: assets/evaluation_index_re10k.json
6 |
7 | evaluation:
8 |   methods:
9 |     - name: Ours
10 |       key: ours
11 |       path: baselines/re10k/ours/frames
12 |     - name: No Epipolar
13 |       key: no_epipolar
14 |       path: baselines/re10k/re10k_ablation_no_epipolar_transformer/frames
15 |     - name: No Sampling
16 |       key: no_sampling
17 |       path: baselines/re10k/re10k_ablation_no_probabilistic_sampling/frames
18 |     - name: No Depth Enc.
19 |       key: no_depth_encoding
20 |       path: baselines/re10k/re10k_ablation_no_depth_encoding/frames
21 |     - name: Depth Reg.
22 |       key: depth_regularization
23 |       path: baselines/re10k/re10k_depth_loss/frames
24 |
25 |   side_by_side_path: null
26 |   animate_side_by_side: false
27 |   highlighted:
28 |     - scene: 67a69088a2695987
29 |       target_index: 74
30 |     - scene: e4f4574df7938f37
31 |       target_index: 26
32 |     - scene: 29e0bfbad00f0d5e
33 |       target_index: 89
34 |
35 |   output_metrics_path: baselines/re10k/evaluation_metrics_ablation.json
--------------------------------------------------------------------------------
/config/evaluation/acid.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     index_path: assets/evaluation_index_acid.json
6 |
7 | evaluation:
8 |   methods:
9 |     - name: Ours
10 |       key: ours
11 |       path: baselines/acid/ours/frames
12 |     - name: Du et al.~\cite{du2023cross}
13 |       key: du2023
14 |       path: baselines/acid/yilun/frames
15 |     - name: GPNR~\cite{suhail2022generalizable}
16 |       key: gpnr
17 |       path: baselines/acid/gpnr/frames
18 |     - name: pixelNeRF~\cite{pixelnerf}
19 |       key: pixelnerf
20 |       path: baselines/acid/pixelnerf/frames
21 |
22 |   side_by_side_path: null
23 |   animate_side_by_side: false
24 |   highlighted:
25 |     # Main Paper
26 |     # - scene: 405dcfc20f9ba5cb
27 |     #   target_index: 60
28 |     # - scene: 23bb6c5d93c40e13
29 |     #   target_index: 38
30 |     # - scene: a5ab552ede3e70a0
31 |     #   target_index: 52
32 |     # Supplemental
33 |     - scene: b896ebcbe8d18553
34 |       target_index: 61
35 |     - scene: bfbb8ec94454b84f
36 |       target_index: 89
37 |     - scene: c124e820e8123518
38 |       target_index: 111
39 |     - scene: c06467006ab8705a
40 |       target_index: 24
41 |     - scene: d4c545fb5fa1a84a
42 |       target_index: 23
43 |     - scene: d8f7d9ef3a5a4527
44 |       target_index: 138
45 |     - scene: 4139c1649bcce55a
46 |       target_index: 140
47 |     # No Space
48 |     # - scene: 20c3d098e3b28058
49 |     #   target_index: 360
50 |
51 |   output_metrics_path: baselines/acid/evaluation_metrics.json
--------------------------------------------------------------------------------
/config/evaluation/acid_video.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     index_path: assets/evaluation_index_acid_video.json
6 |
7 | evaluation:
8 |   methods:
9 |     - name: Ours
10 |       key: ours
11 |       path: baselines/acid/ours/frames_video
12 |     - name: Du et al.
13 |       key: du2023
14 |       path: baselines/acid/yilun/frames_video
15 |     - name: GPNR
16 |       key: gpnr
17 |       path: baselines/acid/gpnr/frames_video
18 |     - name: pixelNeRF
19 |       key: pixelnerf
20 |       path: baselines/acid/pixelnerf/frames_video
21 |
22 |   side_by_side_path: outputs/video/acid
23 |   animate_side_by_side: true
24 |   highlighted: []
25 |
26 |   output_metrics_path: outputs/video/acid/evaluation_metrics.json
--------------------------------------------------------------------------------
/config/evaluation/re10k.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     index_path: assets/evaluation_index_re10k.json
6 |
7 | evaluation:
8 |   methods:
9 |     - name: Ours
10 |       key: ours
11 |       path: baselines/re10k/ours/frames
12 |     - name: Du et al.~\cite{du2023cross}
13 |       key: du2023
14 |       path: baselines/re10k/yilun/frames
15 |     - name: GPNR~\cite{suhail2022generalizable}
16 |       key: gpnr
17 |       path: baselines/re10k/gpnr/frames
18 |     - name: pixelNeRF~\cite{pixelnerf}
19 |       key: pixelnerf
20 |       path: baselines/re10k/pixelnerf/frames
21 |
22 |   side_by_side_path: null
23 |   animate_side_by_side: false
24 |   highlighted:
25 |     # Main Paper
26 |     - scene: 5be4f1f46b408d68
27 |       target_index: 136
28 |     - scene: 800ea72b6988f63e
29 |       target_index: 167
30 |     - scene: d3a01038c5f21473
31 |       target_index: 201
32 |     # Supplemental
33 |     # - scene: 9e585ebbacb3e94c
34 |     #   target_index: 80
35 |     # - scene: 7a00ed342b630d31
36 |     #   target_index: 101
37 |     # - scene: 6f243139ca86b4e5
38 |     #   target_index: 54
39 |     # - scene: 6e77ac6af5163f5b
40 |     #   target_index: 153
41 |     # - scene: 5fbace6c6ca56228
42 |     #   target_index: 33
43 |     # - scene: 7a34348316608aee
44 |     #   target_index: 80
45 |     # - scene: 7a911883348688e9
46 |     #   target_index: 88
47 |
48 |
49 |   output_metrics_path: baselines/re10k/evaluation_metrics.json
--------------------------------------------------------------------------------
/config/evaluation/re10k_video.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | dataset:
4 |   view_sampler:
5 |     index_path: assets/evaluation_index_re10k_video.json
6 |
7 | evaluation:
8 |   methods:
9 |     - name: Ours
10 |       key: ours
11 |       path: baselines/re10k/ours/frames_video
12 |     - name: Du et al.
13 |       key: du2023
14 |       path: baselines/re10k/yilun/frames_video
15 |     - name: GPNR
16 |       key: gpnr
17 |       path: baselines/re10k/gpnr/frames_video
18 |     - name: pixelNeRF
19 |       key: pixelnerf
20 |       path: baselines/re10k/pixelnerf/frames_video
21 |
22 |   side_by_side_path: outputs/video/re10k
23 |   animate_side_by_side: true
24 |   highlighted: []
25 |
26 |   output_metrics_path: outputs/video/re10k/evaluation_metrics.json
--------------------------------------------------------------------------------
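
Each method entry above points the metric computation at a directory of that baseline's rendered frames. The real comparison lives in src/evaluation/metric_computer.py and src/evaluation/metrics.py; the sketch below only illustrates the shape of such a loop, and the <method path>/<scene>/<frame>.png layout is a guess, not the repository's actual naming:

from pathlib import Path

import numpy as np
from skimage.io import imread
from skimage.metrics import peak_signal_noise_ratio

def mean_psnr(method_path, gt_path):
    # Compare every rendered frame against its ground-truth counterpart.
    scores = []
    for gt_file in sorted(Path(gt_path).glob("*/*.png")):
        pred_file = Path(method_path) / gt_file.relative_to(gt_path)
        scores.append(peak_signal_noise_ratio(imread(gt_file), imread(pred_file)))
    return float(np.mean(scores))
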
/config/experiment/acid.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 |   - override /dataset: re10k
5 |   - override /model/encoder: costvolume
6 |   - override /loss: [mse, lpips]
7 |
8 | wandb:
9 |   name: acid
10 |   tags: [acid, 256x256]
11 |
12 | data_loader:
13 |   train:
14 |     batch_size: 14
15 |
16 | trainer:
17 |   max_steps: 300_001
18 |
19 | # ----- Additional params for default best model customization
20 | model:
21 |   encoder:
22 |     num_depth_candidates: 128
23 |     costvolume_unet_feat_dim: 128
24 |     costvolume_unet_channel_mult: [1,1,1]
25 |     costvolume_unet_attn_res: [4]
26 |     gaussians_per_pixel: 1
27 |     depth_unet_feat_dim: 32
28 |     depth_unet_attn_res: [16]
29 |     depth_unet_channel_mult: [1,1,1,1,1]
30 |
31 | # lpips loss
32 | loss:
33 |   lpips:
34 |     apply_after_step: 0
35 |     weight: 0.05
36 |
37 | dataset:
38 |   image_shape: [256, 256]
39 |   roots: [datasets/acid]
40 |   near: 1.
41 |   far: 100.
42 |   baseline_scale_bounds: false
43 |   make_baseline_1: false
44 |
45 | test:
46 |   eval_time_skip_steps: 5
47 |   compute_scores: true
48 |
--------------------------------------------------------------------------------
/config/experiment/dtu.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 |   - override /dataset: re10k
5 |   - override /model/encoder: costvolume
6 |   - override /loss: [mse, lpips]
7 |
8 | wandb:
9 |   name: dtu/views2
10 |   tags: [dtu, 256x256]
11 |
12 | data_loader:
13 |   train:
14 |     batch_size: 14
15 |
16 | trainer:
17 |   max_steps: 300_001
18 |
19 | # ----- Additional params for default best model customization
20 | model:
21 |   encoder:
22 |     num_depth_candidates: 128
23 |     costvolume_unet_feat_dim: 128
24 |     costvolume_unet_channel_mult: [1,1,1]
25 |     costvolume_unet_attn_res: [4]
26 |     gaussians_per_pixel: 1
27 |     depth_unet_feat_dim: 32
28 |     depth_unet_attn_res: [16]
29 |     depth_unet_channel_mult: [1,1,1,1,1]
30 |
31 | # lpips loss
32 | loss:
33 |   lpips:
34 |     apply_after_step: 0
35 |     weight: 0.05
36 |
37 | dataset:
38 |   image_shape: [256, 256]
39 |   roots: [datasets/dtu]
40 |   near: 2.125
41 |   far: 4.525
42 |   baseline_scale_bounds: false
43 |   make_baseline_1: false
44 |   test_times_per_scene: 4
45 |   skip_bad_shape: false
46 |
47 | test:
48 |   eval_time_skip_steps: 5
49 |   compute_scores: true
50 |
--------------------------------------------------------------------------------
/config/experiment/re10k.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 |   - override /dataset: re10k
5 |   - override /model/encoder: costvolume
6 |   - override /loss: [mse, lpips]
7 |
8 | wandb:
9 |   name: re10k
10 |   tags: [re10k, 256x256]
11 |
12 | data_loader:
13 |   train:
14 |     batch_size: 14
15 |
16 | trainer:
17 |   max_steps: 300_001
18 |
19 | # ----- Additional params for default best model customization
20 | model:
21 |   encoder:
22 |     num_depth_candidates: 128
23 |     costvolume_unet_feat_dim: 128
24 |     costvolume_unet_channel_mult: [1,1,1]
25 |     costvolume_unet_attn_res: [4]
26 |     gaussians_per_pixel: 1
27 |     depth_unet_feat_dim: 32
28 |     depth_unet_attn_res: [16]
29 |     depth_unet_channel_mult: [1,1,1,1,1]
30 |
31 | # lpips loss
32 | loss:
33 |   lpips:
34 |     apply_after_step: 0
35 |     weight: 0.05
36 |
37 | dataset:
38 |   image_shape: [256, 256]
39 |   roots: [datasets/re10k]
40 |   near: 1.
41 |   far: 100.
42 |   baseline_scale_bounds: false
43 |   make_baseline_1: false
44 |
45 | test:
46 |   eval_time_skip_steps: 5
47 |   compute_scores: true
48 |
--------------------------------------------------------------------------------
/config/generate_evaluation_index.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - dataset: re10k
3 |   - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}
4 |   - override dataset/view_sampler: all
5 |
6 | dataset:
7 |   overfit_to_scene: null
8 |
9 | data_loader:
10 |   train:
11 |     num_workers: 0
12 |     persistent_workers: true
13 |     batch_size: 1
14 |     seed: 1234
15 |   test:
16 |     num_workers: 8
17 |     persistent_workers: false
18 |     batch_size: 1
19 |     seed: 2345
20 |   val:
21 |     num_workers: 0
22 |     persistent_workers: true
23 |     batch_size: 1
24 |     seed: 3456
25 |
26 | index_generator:
27 |   num_target_views: 3
28 |   min_overlap: 0.6
29 |   max_overlap: 1.0
30 |   min_distance: 45
31 |   max_distance: 135
32 |   output_path: outputs/evaluation_index_re10k
33 |   save_previews: false
34 |   seed: 123
35 |
36 | seed: 456
37 |
--------------------------------------------------------------------------------
/config/loss/depth.yaml:
--------------------------------------------------------------------------------
1 | depth:
2 |   weight: 0.25
3 |   sigma_image: null
4 |   use_second_derivative: false
--------------------------------------------------------------------------------
/config/loss/lpips.yaml:
--------------------------------------------------------------------------------
1 | lpips:
2 |   weight: 0.05
3 |   apply_after_step: 150_000
--------------------------------------------------------------------------------
/config/loss/mse.yaml:
--------------------------------------------------------------------------------
1 | mse:
2 |   weight: 1.0
--------------------------------------------------------------------------------
/config/main.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - dataset: re10k
3 |   - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}
4 |   - model/encoder: costvolume
5 |   - model/decoder: splatting_cuda
6 |   - loss: [mse]
7 |
8 | wandb:
9 |   project: mvsplat
10 |   entity: placeholder
11 |   name: placeholder
12 |   mode: disabled
13 |   id: null
14 |
15 | mode: train
16 |
17 | dataset:
18 |   overfit_to_scene: null
19 |
20 | data_loader:
21 |   # Avoid having to spin up new processes to print out visualizations.
22 |   train:
23 |     num_workers: 10
24 |     persistent_workers: true
25 |     batch_size: 4
26 |     seed: 1234
27 |   test:
28 |     num_workers: 4
29 |     persistent_workers: false
30 |     batch_size: 1
31 |     seed: 2345
32 |   val:
33 |     num_workers: 1
34 |     persistent_workers: true
35 |     batch_size: 1
36 |     seed: 3456
37 |
38 | optimizer:
39 |   lr: 2.e-4
40 |   warm_up_steps: 2000
41 |   cosine_lr: true
42 |
43 | checkpointing:
44 |   load: null
45 |   every_n_train_steps: 5000  # 5000
46 |   save_top_k: 20
47 |   pretrained_model: null
48 |
49 | train:
50 |   depth_mode: null
51 |   extended_visualization: false
52 |   print_log_every_n_steps: 10
53 |
54 | test:
55 |   output_path: outputs/test
56 |   compute_scores: false
57 |   eval_time_skip_steps: 0
58 |   save_image: true
59 |   save_video: false
60 |
61 | seed: 111123
62 |
63 | trainer:
64 |   max_steps: -1
65 |   val_check_interval: 1000
66 |   gradient_clip_val: 0.5
67 |   num_sanity_val_steps: 2
68 |
69 | output_dir: null
--------------------------------------------------------------------------------
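
The optimizer block above (lr: 2.e-4, warm_up_steps: 2000, cosine_lr: true) is consumed by the Lightning module in src/model/model_wrapper.py. This sketch shows one standard reading of those three fields, a linear warm-up followed by cosine decay toward zero at trainer.max_steps; the exact schedule is whatever model_wrapper.py builds:

import math

def learning_rate(step, base_lr=2e-4, warm_up_steps=2000, max_steps=300_001):
    # Linear warm-up from 0 to base_lr, then cosine decay to 0.
    if step < warm_up_steps:
        return base_lr * step / warm_up_steps
    t = (step - warm_up_steps) / max(max_steps - warm_up_steps, 1)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))
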
/config/model/decoder/splatting_cuda.yaml:
--------------------------------------------------------------------------------
1 | name: splatting_cuda
--------------------------------------------------------------------------------
/config/model/encoder/costvolume.yaml:
--------------------------------------------------------------------------------
1 | name: costvolume
2 |
3 | opacity_mapping:
4 |   initial: 0.0
5 |   final: 0.0
6 |   warm_up: 1
7 |
8 | num_depth_candidates: 32
9 | num_surfaces: 1
10 |
11 | gaussians_per_pixel: 1
12 |
13 | gaussian_adapter:
14 |   gaussian_scale_min: 0.5
15 |   gaussian_scale_max: 15.0
16 |   sh_degree: 4
17 |
18 | d_feature: 128
19 |
20 | visualizer:
21 |   num_samples: 8
22 |   min_resolution: 256
23 |   export_ply: false
24 |
25 | # params for multi-view depth predictor
26 | unimatch_weights_path: "checkpoints/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth"
27 | multiview_trans_attn_split: 2
28 | costvolume_unet_feat_dim: 128
29 | costvolume_unet_channel_mult: [1,1,1]
30 | costvolume_unet_attn_res: []
31 | depth_unet_feat_dim: 64
32 | depth_unet_attn_res: []
33 | depth_unet_channel_mult: [1, 1, 1]
34 | downscale_factor: 4
35 | shim_patch_size: 4
36 |
37 | # below are ablation settings, keep them as false for default model
38 | wo_depth_refine: false  # Table 3: base
39 | wo_cost_volume: false  # Table 3: w/o cost volume
40 | wo_backbone_cross_attn: false  # Table 3: w/o cross-view attention
41 | wo_cost_volume_refine: false  # Table 3: w/o U-Net
42 | use_epipolar_trans: false  # Table B: w/ Epipolar Transformer
43 | view_matching: true
44 |
--------------------------------------------------------------------------------
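
num_depth_candidates sets how many depth hypotheses the cost volume in src/model/encoder/costvolume/depth_predictor_multiview.py sweeps between the dataset's near and far planes (the experiment configs raise it from 32 to 128). Plane-sweep volumes are conventionally sampled uniformly in inverse depth, so candidates are denser close to the camera; the spacing below is that convention, not a guarantee about this particular encoder:

import torch

def depth_candidates(near, far, num=128):
    # Uniform in inverse depth (disparity), then mapped back to metric depth.
    inv = torch.linspace(1.0 / near, 1.0 / far, num)
    return 1.0 / inv

candidates = depth_candidates(2.125, 4.525)  # e.g. the DTU experiment's planes
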
/figure/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/figure/pipeline.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | wheel
2 | tqdm
3 | pytorch_lightning
4 | black
5 | ruff
6 | hydra-core
7 | jaxtyping
8 | beartype
9 | wandb
10 | einops
11 | colorama
12 | scikit-image
13 | colorspacious
14 | matplotlib
15 | moviepy
16 | imageio
17 | git+https://github.com/dcharatan/diff-gaussian-rasterization-modified
18 | timm
19 | dacite
20 | lpips
21 | e3nn
22 | plyfile
23 | tabulate
24 | svg.py
25 | opencv-python==4.6.0.66
26 | sk-video
27 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import Literal, Optional, Type, TypeVar
4 |
5 | from dacite import Config, from_dict
6 | from omegaconf import DictConfig, OmegaConf
7 |
8 | from .dataset.data_module import DataLoaderCfg, DatasetCfg
9 | from .loss import LossCfgWrapper
10 | from .model.decoder import DecoderCfg
11 | from .model.encoder import EncoderCfg
12 | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg
13 |
14 |
15 | @dataclass
16 | class CheckpointingCfg:
17 |     load: Optional[str]  # Not a path, since it could be something like wandb://...
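
load_typed_config above converts the untyped Hydra DictConfig into the RootCfg dataclass tree via dacite, so the rest of the code gets attribute access with real types instead of string keys. A self-contained demonstration of the same pattern on a toy dataclass (not the repo's RootCfg):

from dataclasses import dataclass

from dacite import from_dict
from omegaconf import OmegaConf

@dataclass
class OptimizerCfg:
    lr: float
    warm_up_steps: int

cfg = OmegaConf.create({"lr": 2e-4, "warm_up_steps": 2000})
typed = from_dict(OptimizerCfg, OmegaConf.to_container(cfg))
assert typed.lr == 2e-4  # now a plain, type-checked dataclass
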
18 |     every_n_train_steps: int
19 |     save_top_k: int
20 |     pretrained_model: Optional[str]
21 |
22 |
23 | @dataclass
24 | class ModelCfg:
25 |     decoder: DecoderCfg
26 |     encoder: EncoderCfg
27 |
28 |
29 | @dataclass
30 | class TrainerCfg:
31 |     max_steps: int
32 |     val_check_interval: int | float | None
33 |     gradient_clip_val: int | float | None
34 |     num_sanity_val_steps: int
35 |
36 |
37 | @dataclass
38 | class RootCfg:
39 |     wandb: dict
40 |     mode: Literal["train", "test"]
41 |     dataset: DatasetCfg
42 |     data_loader: DataLoaderCfg
43 |     model: ModelCfg
44 |     optimizer: OptimizerCfg
45 |     checkpointing: CheckpointingCfg
46 |     trainer: TrainerCfg
47 |     loss: list[LossCfgWrapper]
48 |     test: TestCfg
49 |     train: TrainCfg
50 |     seed: int
51 |
52 |
53 | TYPE_HOOKS = {
54 |     Path: Path,
55 | }
56 |
57 |
58 | T = TypeVar("T")
59 |
60 |
61 | def load_typed_config(
62 |     cfg: DictConfig,
63 |     data_class: Type[T],
64 |     extra_type_hooks: dict = {},
65 | ) -> T:
66 |     return from_dict(
67 |         data_class,
68 |         OmegaConf.to_container(cfg),
69 |         config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}),
70 |     )
71 |
72 |
73 | def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
74 |     # The dummy allows the union to be converted.
75 |     @dataclass
76 |     class Dummy:
77 |         dummy: LossCfgWrapper
78 |
79 |     return [
80 |         load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
81 |         for k, v in joined.items()
82 |     ]
83 |
84 |
85 | def load_typed_root_config(cfg: DictConfig) -> RootCfg:
86 |     return load_typed_config(
87 |         cfg,
88 |         RootCfg,
89 |         {list[LossCfgWrapper]: separate_loss_cfg_wrappers},
90 |     )
91 |
--------------------------------------------------------------------------------
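
get_dataset below dispatches on cfg.name through the DATASETS registry, so supporting a new dataset means adding one entry whose constructor accepts (cfg, stage, view_sampler); see the hypothetical registration sketched after the file (DatasetToy is made up, and the import assumes the repository root is on PYTHONPATH):

from torch.utils.data import Dataset

from src.dataset import DATASETS

class DatasetToy(Dataset):
    # Matches the call DATASETS[cfg.name](cfg, stage, view_sampler) in get_dataset.
    def __init__(self, cfg, stage, view_sampler):
        self.cfg, self.stage, self.view_sampler = cfg, stage, view_sampler

    def __len__(self):
        return 0

DATASETS["toy"] = DatasetToy
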
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/dataset_re10k.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/dataset_re10k.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/dataset_re10k.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/types.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/types.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/types.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/types.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/types.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/types.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/validation_wrapper.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/validation_wrapper.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/validation_wrapper.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/validation_wrapper.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/__pycache__/validation_wrapper.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/__pycache__/validation_wrapper.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | image_shape: list[int] 9 | background_color: list[float] 10 | cameras_are_circular: bool 11 | 
overfit_to_scene: str | None 12 | view_sampler: ViewSamplerCfg 13 | -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/augmentation_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/augmentation_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/augmentation_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/augmentation_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/augmentation_shim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/augmentation_shim.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/bounds_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/bounds_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/bounds_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/bounds_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/bounds_shim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/bounds_shim.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/crop_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/crop_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/crop_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/crop_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/crop_shim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/crop_shim.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/patch_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/patch_shim.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/patch_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/patch_shim.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/shims/__pycache__/patch_shim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/shims/__pycache__/patch_shim.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 
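# (That is, with probability 0.5 the example passes through unchanged; otherwise
# reflect_views flips each image along its width axis and conjugates the
# extrinsics by a reflection of the camera x-axis.)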
30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /src/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum( 30 | intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" 31 | ) 32 | 33 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 34 | # don't make much sense anyway. 35 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 36 | 37 | return baselines / (disparity * mean_pixel_size) 38 | 39 | 40 | def apply_bounds_shim( 41 | batch: BatchedExample, 42 | near_disparity: float, 43 | far_disparity: float, 44 | ) -> BatchedExample: 45 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 46 | assumes that all of an example's views are of roughly the same thing. 47 | """ 48 | 49 | context = batch["context"] 50 | _, cv, _, h, w = context["image"].shape 51 | 52 | # Compute near and far planes using the context views. 
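# (Worked example with illustrative numbers: for a 256x256 image with normalized
# intrinsics fx = fy = 1.0, one pixel at depth 1 spans 1/256 ≈ 0.0039 units, so a
# maximum baseline of 0.1 and near_disparity = 25 px yield
# near ≈ 0.1 / (25 * 0.0039) ≈ 1.0.)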
53 | near = compute_depth_for_disparity( 54 | context["extrinsics"], 55 | context["intrinsics"], 56 | (h, w), 57 | near_disparity, 58 | ) 59 | far = compute_depth_for_disparity( 60 | context["extrinsics"], 61 | context["intrinsics"], 62 | (h, w), 63 | far_disparity, 64 | ) 65 | 66 | target = batch["target"] 67 | _, tv, _, _, _ = target["image"].shape 68 | return { 69 | **batch, 70 | "context": { 71 | **context, 72 | "near": repeat(near, "b -> b v", v=cv), 73 | "far": repeat(far, "b -> b v", v=cv), 74 | }, 75 | "target": { 76 | **target, 77 | "near": repeat(near, "b -> b v", v=tv), 78 | "far": repeat(far, "b -> b v", v=tv), 79 | }, 80 | } 81 | -------------------------------------------------------------------------------- /src/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | 8 | from ..types import AnyExample, AnyViews 9 | 10 | 11 | def rescale( 12 | image: Float[Tensor, "3 h_in w_in"], 13 | shape: tuple[int, int], 14 | ) -> Float[Tensor, "3 h_out w_out"]: 15 | h, w = shape 16 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 17 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 18 | image_new = Image.fromarray(image_new) 19 | image_new = image_new.resize((w, h), Image.LANCZOS) 20 | image_new = np.array(image_new) / 255 21 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 22 | return rearrange(image_new, "h w c -> c h w") 23 | 24 | 25 | def center_crop( 26 | images: Float[Tensor, "*#batch c h w"], 27 | intrinsics: Float[Tensor, "*#batch 3 3"], 28 | shape: tuple[int, int], 29 | ) -> tuple[ 30 | Float[Tensor, "*#batch c h_out w_out"], # updated images 31 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 32 | ]: 33 | *_, h_in, w_in = images.shape 34 | h_out, w_out = shape 35 | 36 | # Note that odd input dimensions induce half-pixel misalignments. 37 | row = (h_in - h_out) // 2 38 | col = (w_in - w_out) // 2 39 | 40 | # Center-crop the image. 41 | images = images[..., :, row : row + h_out, col : col + w_out] 42 | 43 | # Adjust the intrinsics to account for the cropping. 44 | intrinsics = intrinsics.clone() 45 | intrinsics[..., 0, 0] *= w_in / w_out # fx 46 | intrinsics[..., 1, 1] *= h_in / h_out # fy 47 | 48 | return images, intrinsics 49 | 50 | 51 | def rescale_and_crop( 52 | images: Float[Tensor, "*#batch c h w"], 53 | intrinsics: Float[Tensor, "*#batch 3 3"], 54 | shape: tuple[int, int], 55 | ) -> tuple[ 56 | Float[Tensor, "*#batch c h_out w_out"], # updated images 57 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 58 | ]: 59 | *_, h_in, w_in = images.shape 60 | h_out, w_out = shape 61 | assert h_out <= h_in and w_out <= w_in 62 | 63 | scale_factor = max(h_out / h_in, w_out / w_in) 64 | h_scaled = round(h_in * scale_factor) 65 | w_scaled = round(w_in * scale_factor) 66 | assert h_scaled == h_out or w_scaled == w_out 67 | 68 | # Reshape the images to the correct size. Assume we don't have to worry about 69 | # changing the intrinsics based on how the images are rounded.
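# (Example: rescaling 360x640 inputs to a 256x256 target gives scale_factor =
# max(256/360, 256/640) ≈ 0.711, hence h_scaled = 256 and w_scaled = 455; the
# excess width is then removed by center_crop below.)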
70 | *batch, c, h, w = images.shape 71 | images = images.reshape(-1, c, h, w) 72 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 73 | images = images.reshape(*batch, c, h_scaled, w_scaled) 74 | 75 | return center_crop(images, intrinsics, shape) 76 | 77 | 78 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 79 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 80 | return { 81 | **views, 82 | "image": images, 83 | "intrinsics": intrinsics, 84 | } 85 | 86 | 87 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 88 | """Crop images in the example.""" 89 | return { 90 | **example, 91 | "context": apply_crop_shim_to_views(example["context"], shape), 92 | "target": apply_crop_shim_to_views(example["target"], shape), 93 | } 94 | -------------------------------------------------------------------------------- /src/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 
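# The quoted shapes use jaxtyping notation: "batch _ 4 4", for instance, denotes a
# batch of per-view 4x4 extrinsics whose view dimension ("_") is unconstrained.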
14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | 24 | 25 | class BatchedExample(TypedDict, total=False): 26 | target: BatchedViews 27 | context: BatchedViews 28 | scene: list[str] 29 | 30 | 31 | class UnbatchedViews(TypedDict, total=False): 32 | extrinsics: Float[Tensor, "_ 4 4"] 33 | intrinsics: Float[Tensor, "_ 3 3"] 34 | image: Float[Tensor, "_ 3 height width"] 35 | near: Float[Tensor, " _"] 36 | far: Float[Tensor, " _"] 37 | index: Int64[Tensor, " _"] 38 | 39 | 40 | class UnbatchedExample(TypedDict, total=False): 41 | target: UnbatchedViews 42 | context: UnbatchedViews 43 | scene: str 44 | 45 | 46 | # A data shim modifies the example after it's been returned from the data loader. 47 | DataShim = Callable[[BatchedExample], BatchedExample] 48 | 49 | AnyExample = BatchedExample | UnbatchedExample 50 | AnyViews = BatchedViews | UnbatchedViews 51 | -------------------------------------------------------------------------------- /src/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ValidationWrapper(Dataset): 8 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 9 | visualization step. 10 | """ 11 | 12 | dataset: Dataset 13 | dataset_iterator: Optional[Iterator] 14 | length: int 15 | 16 | def __init__(self, dataset: Dataset, length: int) -> None: 17 | super().__init__() 18 | self.dataset = dataset 19 | self.length = length 20 | self.dataset_iterator = None 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index: int): 26 | if isinstance(self.dataset, IterableDataset): 27 | if self.dataset_iterator is None: 28 | self.dataset_iterator = iter(self.dataset) 29 | return next(self.dataset_iterator) 30 | 31 | random_index = torch.randint(0, len(self.dataset), tuple()) 32 | return self.dataset[random_index.item()] 33 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ...misc.step_tracker import StepTracker 4 | from ..types import Stage 5 | from .view_sampler import ViewSampler 6 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 7 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 8 | from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 9 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 10 | 11 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 12 | "all": ViewSamplerAll, 13 | "arbitrary": ViewSamplerArbitrary, 14 | "bounded": ViewSamplerBounded, 15 | "evaluation": ViewSamplerEvaluation, 16 | } 17 | 18 | ViewSamplerCfg = ( 19 | ViewSamplerArbitraryCfg 20 | | ViewSamplerBoundedCfg 21 | | ViewSamplerEvaluationCfg 22 | | ViewSamplerAllCfg 23 | ) 24 | 25 | 26 | def get_view_sampler( 27 | cfg: ViewSamplerCfg, 28 | stage: Stage, 29 | overfit: bool, 30 | cameras_are_circular: bool, 31 | 
step_tracker: StepTracker | None, 32 | ) -> ViewSampler[Any]: 33 | return VIEW_SAMPLERS[cfg.name]( 34 | cfg, 35 | stage, 36 | overfit, 37 | cameras_are_circular, 38 | step_tracker, 39 | ) 40 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/__init__.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/__init__.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/__init__.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/__init__.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_all.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_all.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- 
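A usage sketch for the view-sampler factory defined above (the values are illustrative, the camera tensors are random placeholders rather than real poses, and the imports assume the repository root is on PYTHONPATH):

import torch

from src.dataset.view_sampler import get_view_sampler
from src.dataset.view_sampler.view_sampler_arbitrary import ViewSamplerArbitraryCfg

# "arbitrary" samples random context/target indices unless fixed lists are given.
cfg = ViewSamplerArbitraryCfg(
    name="arbitrary",
    num_context_views=2,
    num_target_views=4,
    context_views=None,  # e.g., [0, 10] would pin the context views instead
    target_views=None,
)
sampler = get_view_sampler(cfg, "train", False, False, None)

v = 20  # number of views available in the scene
extrinsics = torch.eye(4).repeat(v, 1, 1)  # placeholder camera-to-world matrices
intrinsics = torch.eye(3).repeat(v, 1, 1)  # placeholder normalized intrinsics
context_idx, target_idx = sampler.sample("example_scene", extrinsics, intrinsics)

Because VIEW_SAMPLERS is keyed by cfg.name, supporting a new sampler only requires registering its class there and extending the ViewSamplerCfg union.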
/src/dataset/view_sampler/__pycache__/view_sampler_all.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_all.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_all.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_all.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_arbitrary.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_arbitrary.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_arbitrary.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_arbitrary.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_arbitrary.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_arbitrary.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_bounded.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_bounded.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_bounded.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_bounded.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_bounded.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_bounded.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_evaluation.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_evaluation.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_evaluation.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_evaluation.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/__pycache__/view_sampler_evaluation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/dataset/view_sampler/__pycache__/view_sampler_evaluation.cpython-310.pyc -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from ...misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | **kwargs, 43 | ) -> tuple[ 44 | Int64[Tensor, " context_view"], # indices for context views 45 | Int64[Tensor, " target_view"], # indices for target views 46 | ]: 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def num_target_views(self) -> int: 52 | pass 53 | 54 | @property 55 | @abstractmethod 56 | def num_context_views(self) -> int: 57 | pass 58 | 59 | @property 60 | def global_step(self) -> int: 61 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 62 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses 
import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | ) -> tuple[ 24 | Int64[Tensor, " context_view"], # indices for context views 25 | Int64[Tensor, " target_view"], # indices for target views 26 | ]: 27 | v, _, _ = extrinsics.shape 28 | all_frames = torch.arange(v, device=device) 29 | return all_frames, all_frames 30 | 31 | @property 32 | def num_context_views(self) -> int: 33 | return 0 34 | 35 | @property 36 | def num_target_views(self) -> int: 37 | return 0 38 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerArbitraryCfg: 13 | name: Literal["arbitrary"] 14 | num_context_views: int 15 | num_target_views: int 16 | context_views: list[int] | None 17 | target_views: list[int] | None 18 | 19 | 20 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 21 | def sample( 22 | self, 23 | scene: str, 24 | extrinsics: Float[Tensor, "view 4 4"], 25 | intrinsics: Float[Tensor, "view 3 3"], 26 | device: torch.device = torch.device("cpu"), 27 | ) -> tuple[ 28 | Int64[Tensor, " context_view"], # indices for context views 29 | Int64[Tensor, " target_view"], # indices for target views 30 | ]: 31 | """Arbitrarily sample context and target views.""" 32 | num_views, _, _ = extrinsics.shape 33 | 34 | index_context = torch.randint( 35 | 0, 36 | num_views, 37 | size=(self.cfg.num_context_views,), 38 | device=device, 39 | ) 40 | 41 | # Allow the context views to be fixed. 42 | if self.cfg.context_views is not None: 43 | assert len(self.cfg.context_views) == self.cfg.num_context_views 44 | index_context = torch.tensor( 45 | self.cfg.context_views, dtype=torch.int64, device=device 46 | ) 47 | 48 | index_target = torch.randint( 49 | 0, 50 | num_views, 51 | size=(self.cfg.num_target_views,), 52 | device=device, 53 | ) 54 | 55 | # Allow the target views to be fixed. 
56 | if self.cfg.target_views is not None: 57 | assert len(self.cfg.target_views) == self.cfg.num_target_views 58 | index_target = torch.tensor( 59 | self.cfg.target_views, dtype=torch.int64, device=device 60 | ) 61 | 62 | return index_context, index_target 63 | 64 | @property 65 | def num_context_views(self) -> int: 66 | return self.cfg.num_context_views 67 | 68 | @property 69 | def num_target_views(self) -> int: 70 | return self.cfg.num_target_views 71 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from ...misc.step_tracker import StepTracker 13 | from ..types import Stage 14 | from .view_sampler import ViewSampler 15 | 16 | 17 | @dataclass 18 | class ViewSamplerEvaluationCfg: 19 | name: Literal["evaluation"] 20 | index_path: Path 21 | num_context_views: int 22 | 23 | 24 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 25 | index: dict[str, IndexEntry | None] 26 | 27 | def __init__( 28 | self, 29 | cfg: ViewSamplerEvaluationCfg, 30 | stage: Stage, 31 | is_overfitting: bool, 32 | cameras_are_circular: bool, 33 | step_tracker: StepTracker | None, 34 | ) -> None: 35 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 36 | 37 | dacite_config = Config(cast=[tuple]) 38 | with cfg.index_path.open("r") as f: 39 | self.index = { 40 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 41 | for k, v in json.load(f).items() 42 | } 43 | 44 | def sample( 45 | self, 46 | scene: str, 47 | extrinsics: Float[Tensor, "view 4 4"], 48 | intrinsics: Float[Tensor, "view 3 3"], 49 | device: torch.device = torch.device("cpu"), 50 | ) -> tuple[ 51 | Int64[Tensor, " context_view"], # indices for context views 52 | Int64[Tensor, " target_view"], # indices for target views 53 | ]: 54 | entry = self.index.get(scene) 55 | if entry is None: 56 | raise ValueError(f"No indices available for scene {scene}.") 57 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 58 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 59 | return context_indices, target_indices 60 | 61 | @property 62 | def num_context_views(self) -> int: 63 | return 0 64 | 65 | @property 66 | def num_target_views(self) -> int: 67 | return 0 68 | -------------------------------------------------------------------------------- /src/evaluation/__pycache__/evaluation_index_generator.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/evaluation/__pycache__/evaluation_index_generator.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/evaluation_index_generator.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/evaluation/__pycache__/evaluation_index_generator.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/evaluation_index_generator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/evaluation/__pycache__/evaluation_index_generator.cpython-310.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/metrics.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/evaluation/__pycache__/metrics.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/metrics.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/evaluation/__pycache__/metrics.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/evaluation/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /src/evaluation/evaluation_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | 5 | @dataclass 6 | class MethodCfg: 7 | name: str 8 | key: str 9 | path: Path 10 | 11 | 12 | @dataclass 13 | class SceneCfg: 14 | scene: str 15 | target_index: int 16 | 17 | 18 | @dataclass 19 | class EvaluationCfg: 20 | methods: list[MethodCfg] 21 | side_by_side_path: Path | None 22 | animate_side_by_side: bool 23 | highlighted: list[SceneCfg] 24 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from skimage.metrics import structural_similarity 8 | from torch import Tensor 9 | 10 | 11 | @torch.no_grad() 12 | def compute_psnr( 13 | ground_truth: Float[Tensor, "batch channel height width"], 14 | predicted: Float[Tensor, "batch channel height width"], 15 | ) -> Float[Tensor, " batch"]: 16 | ground_truth = ground_truth.clip(min=0, max=1) 17 | predicted = predicted.clip(min=0, max=1) 18 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 19 | return -10 * mse.log10() 20 | 21 | 22 | @cache 23 | def get_lpips(device: torch.device) -> LPIPS: 24 | return LPIPS(net="vgg").to(device) 25 | 26 | 27 | @torch.no_grad() 28 | def compute_lpips( 29 | ground_truth: Float[Tensor, "batch channel 
height width"], 30 | predicted: Float[Tensor, "batch channel height width"], 31 | ) -> Float[Tensor, " batch"]: 32 | value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True) 33 | return value[:, 0, 0, 0] 34 | 35 | 36 | @torch.no_grad() 37 | def compute_ssim( 38 | ground_truth: Float[Tensor, "batch channel height width"], 39 | predicted: Float[Tensor, "batch channel height width"], 40 | ) -> Float[Tensor, " batch"]: 41 | ssim = [ 42 | structural_similarity( 43 | gt.detach().cpu().numpy(), 44 | hat.detach().cpu().numpy(), 45 | win_size=11, 46 | gaussian_weights=True, 47 | channel_axis=0, 48 | data_range=1.0, 49 | ) 50 | for gt, hat in zip(ground_truth, predicted) 51 | ] 52 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 53 | -------------------------------------------------------------------------------- /src/geometry/__pycache__/epipolar_lines.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/geometry/__pycache__/epipolar_lines.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/geometry/__pycache__/epipolar_lines.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/geometry/__pycache__/epipolar_lines.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/geometry/__pycache__/epipolar_lines.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/geometry/__pycache__/epipolar_lines.cpython-310.pyc -------------------------------------------------------------------------------- /src/geometry/__pycache__/projection.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/geometry/__pycache__/projection.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/geometry/__pycache__/projection.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/geometry/__pycache__/projection.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/geometry/__pycache__/projection.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/geometry/__pycache__/projection.cpython-310.pyc -------------------------------------------------------------------------------- /src/global_cfg.py: -------------------------------------------------------------------------------- 1 | from 
typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import Loss 2 | from .loss_depth import LossDepth, LossDepthCfgWrapper 3 | from .loss_lpips import LossLpips, LossLpipsCfgWrapper 4 | from .loss_mse import LossMse, LossMseCfgWrapper 5 | 6 | LOSSES = { 7 | LossDepthCfgWrapper: LossDepth, 8 | LossLpipsCfgWrapper: LossLpips, 9 | LossMseCfgWrapper: LossMse, 10 | } 11 | 12 | LossCfgWrapper = LossDepthCfgWrapper | LossLpipsCfgWrapper | LossMseCfgWrapper 13 | 14 | 15 | def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: 16 | return [LOSSES[type(cfg)](cfg) for cfg in cfgs] 17 | -------------------------------------------------------------------------------- /src/loss/__pycache__/__init__.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/__init__.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/__init__.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/__init__.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss.cpython-310.pyc 
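A usage sketch for the loss registry defined above (the weights are illustrative rather than the project's defaults, and the imports assume the repository root is on PYTHONPATH):

from src.loss import get_losses
from src.loss.loss_lpips import LossLpipsCfg, LossLpipsCfgWrapper
from src.loss.loss_mse import LossMseCfg, LossMseCfgWrapper

# Each wrapper dataclass has exactly one field; Loss.__init__ unwraps it and
# reuses the field's name ("mse", "lpips") as the loss's name.
losses = get_losses([
    LossMseCfgWrapper(mse=LossMseCfg(weight=1.0)),
    LossLpipsCfgWrapper(lpips=LossLpipsCfg(weight=0.05, apply_after_step=0)),
])

Each returned loss is an nn.Module, so a training step can sum loss(prediction, batch, gaussians, global_step) over the list; these wrappers are also exactly what separate_loss_cfg_wrappers in src/config.py produces from the loss section of the OmegaConf config.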
-------------------------------------------------------------------------------- /src/loss/__pycache__/loss_depth.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_depth.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_depth.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_depth.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_depth.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_depth.cpython-310.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_lpips.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_lpips.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_lpips.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_lpips.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_lpips.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_lpips.cpython-310.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_mse.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_mse.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_mse.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_mse.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/loss_mse.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/loss/__pycache__/loss_mse.cpython-310.pyc -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import fields 3 | from typing import Generic, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | 12 | T_cfg = TypeVar("T_cfg") 13 | T_wrapper = TypeVar("T_wrapper") 14 | 15 | 16 | class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): 17 | cfg: T_cfg 18 | name: str 19 | 20 | def __init__(self, cfg: T_wrapper) -> None: 21 | super().__init__() 22 | 23 | # Extract the configuration from the wrapper. 24 | (field,) = fields(type(cfg)) 25 | self.cfg = getattr(cfg, field.name) 26 | self.name = field.name 27 | 28 | @abstractmethod 29 | def forward( 30 | self, 31 | prediction: DecoderOutput, 32 | batch: BatchedExample, 33 | gaussians, 34 | global_step: int, 35 | ) -> Float[Tensor, ""]: 36 | pass 37 | -------------------------------------------------------------------------------- /src/loss/loss_depth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | from .loss import Loss 12 | 13 | 14 | @dataclass 15 | class LossDepthCfg: 16 | weight: float 17 | sigma_image: float | None 18 | use_second_derivative: bool 19 | 20 | 21 | @dataclass 22 | class LossDepthCfgWrapper: 23 | depth: LossDepthCfg 24 | 25 | 26 | class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]): 27 | def forward( 28 | self, 29 | prediction: DecoderOutput, 30 | batch: BatchedExample, 31 | gaussians, 32 | global_step: int, 33 | ) -> Float[Tensor, ""]: 34 | # Scale the depth between the near and far planes. 35 | near = batch["target"]["near"][..., None, None].log() 36 | far = batch["target"]["far"][..., None, None].log() 37 | depth = prediction.depth.minimum(far).maximum(near) 38 | depth = (depth - near) / (far - near) 39 | 40 | # Compute the difference between neighboring pixels in each direction. 41 | depth_dx = depth.diff(dim=-1) 42 | depth_dy = depth.diff(dim=-2) 43 | 44 | # If desired, compute a 2nd derivative. 45 | if self.cfg.use_second_derivative: 46 | depth_dx = depth_dx.diff(dim=-1) 47 | depth_dy = depth_dy.diff(dim=-2) 48 | 49 | # If desired, add bilateral filtering. 
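# (Each depth gradient is attenuated by exp(-color_gradient * sigma_image), where
# color_gradient is the max-over-channels image difference, so the smoothness
# penalty relaxes wherever the ground-truth image itself has an edge: an
# edge-aware prior in the spirit of a bilateral filter.)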
50 | if self.cfg.sigma_image is not None: 51 | color_gt = batch["target"]["image"] 52 | color_dx = reduce(color_gt.diff(dim=-1), "b v c h w -> b v h w", "max") 53 | color_dy = reduce(color_gt.diff(dim=-2), "b v c h w -> b v h w", "max") 54 | if self.cfg.use_second_derivative: 55 | color_dx = color_dx[..., :, 1:].maximum(color_dx[..., :, :-1]) 56 | color_dy = color_dy[..., 1:, :].maximum(color_dy[..., :-1, :]) 57 | depth_dx = depth_dx * torch.exp(-color_dx * self.cfg.sigma_image) 58 | depth_dy = depth_dy * torch.exp(-color_dy * self.cfg.sigma_image) 59 | 60 | return self.cfg.weight * (depth_dx.abs().mean() + depth_dy.abs().mean()) 61 | -------------------------------------------------------------------------------- /src/loss/loss_lpips.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from torch import Tensor 8 | 9 | from ..dataset.types import BatchedExample 10 | from ..misc.nn_module_tools import convert_to_buffer 11 | from ..model.decoder.decoder import DecoderOutput 12 | from ..model.types import Gaussians 13 | from .loss import Loss 14 | 15 | 16 | @dataclass 17 | class LossLpipsCfg: 18 | weight: float 19 | apply_after_step: int 20 | 21 | 22 | @dataclass 23 | class LossLpipsCfgWrapper: 24 | lpips: LossLpipsCfg 25 | 26 | 27 | class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): 28 | lpips: LPIPS 29 | 30 | def __init__(self, cfg: LossLpipsCfgWrapper) -> None: 31 | super().__init__(cfg) 32 | 33 | self.lpips = LPIPS(net="vgg") 34 | convert_to_buffer(self.lpips, persistent=False) 35 | 36 | def forward( 37 | self, 38 | prediction: DecoderOutput, 39 | batch: BatchedExample, 40 | gaussians, 41 | global_step: int, 42 | ) -> Float[Tensor, ""]: 43 | image = batch["target"]["image"] 44 | 45 | # Before the specified step, don't apply the loss. 
--------------------------------------------------------------------------------
/src/loss/loss_lpips.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass

import torch
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor

from ..dataset.types import BatchedExample
from ..misc.nn_module_tools import convert_to_buffer
from ..model.decoder.decoder import DecoderOutput
from ..model.types import Gaussians
from .loss import Loss


@dataclass
class LossLpipsCfg:
    weight: float
    apply_after_step: int


@dataclass
class LossLpipsCfgWrapper:
    lpips: LossLpipsCfg


class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
    lpips: LPIPS

    def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        convert_to_buffer(self.lpips, persistent=False)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        image = batch["target"]["image"]

        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0, dtype=torch.float32, device=image.device)

        loss = self.lpips.forward(
            rearrange(prediction.color, "b v c h w -> (b v) c h w"),
            rearrange(image, "b v c h w -> (b v) c h w"),
            normalize=True,
        )
        return self.cfg.weight * loss.mean()
--------------------------------------------------------------------------------
/src/loss/loss_mse.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass

from jaxtyping import Float
from torch import Tensor

from ..dataset.types import BatchedExample
from ..model.decoder.decoder import DecoderOutput
from ..model.types import Gaussians
from .loss import Loss


@dataclass
class LossMseCfg:
    weight: float


@dataclass
class LossMseCfgWrapper:
    mse: LossMseCfg


class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]):
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        delta = prediction.color - batch["target"]["image"]
        return self.cfg.weight * (delta**2).mean()
--------------------------------------------------------------------------------
/src/misc/LocalLogger.py:
--------------------------------------------------------------------------------
import shutil
from pathlib import Path
from typing import Any, Optional

from PIL import Image
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities import rank_zero_only

LOG_PATH = Path("outputs/local")


class LocalLogger(Logger):
    def __init__(self, log_path=None) -> None:
        super().__init__()
        self.experiment = None
        # Clear any logs left over from a previous run.
        shutil.rmtree(LOG_PATH, ignore_errors=True)
        self.log_path = log_path if log_path is not None else LOG_PATH

    @property
    def name(self):
        return "LocalLogger"

    @property
    def version(self):
        return 0

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        pass

    @rank_zero_only
    def log_image(
        self,
        key: str,
        images: list[Any],
        step: Optional[int] = None,
        **kwargs,
    ):
        # The function signature is the same as the wandb logger's, but the step is
        # actually required.
        assert step is not None
        for index, image in enumerate(images):
            path = self.log_path / f"{key}/{index:0>2}_{step:0>6}.png"
            path.parent.mkdir(exist_ok=True, parents=True)
            Image.fromarray(image).save(path)
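A sketch of how `LocalLogger` might be handed to a Lightning `Trainer` (hypothetical usage, not shown in this repo); `log_image` then writes PNGs under `outputs/local/<key>/`:

    import pytorch_lightning as pl

    trainer = pl.Trainer(logger=LocalLogger())
    # Inside a LightningModule step, with `image` an HxWx3 uint8 numpy array:
    # self.logger.log_image("render", [image], step=self.global_step)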
--------------------------------------------------------------------------------
/src/misc/benchmarker.py:
--------------------------------------------------------------------------------
import json
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from time import time

import numpy as np
import torch


class Benchmarker:
    def __init__(self):
        self.execution_times = defaultdict(list)

    @contextmanager
    def time(self, tag: str, num_calls: int = 1):
        try:
            start_time = time()
            yield
        finally:
            end_time = time()
            for _ in range(num_calls):
                self.execution_times[tag].append((end_time - start_time) / num_calls)

    def dump(self, path: Path) -> None:
        path.parent.mkdir(exist_ok=True, parents=True)
        with path.open("w") as f:
            json.dump(dict(self.execution_times), f)

    def dump_memory(self, path: Path) -> None:
        path.parent.mkdir(exist_ok=True, parents=True)
        with path.open("w") as f:
            json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f)

    def summarize(self) -> None:
        for tag, times in self.execution_times.items():
            print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call")

    def clear_history(self) -> None:
        self.execution_times = defaultdict(list)
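A minimal sketch of `Benchmarker.time`; passing `num_calls` amortizes one timed block over several logical calls (the `sleep` below is a stand-in for real work):

    from time import sleep

    benchmarker = Benchmarker()
    with benchmarker.time("render", num_calls=4):
        sleep(0.1)  # Pretend this renders four views in one batched call.
    benchmarker.summarize()  # render: 4 calls, avg. ~0.025 seconds per call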
--------------------------------------------------------------------------------
/src/misc/collation.py:
--------------------------------------------------------------------------------
from typing import Callable, Dict, Union

from torch import Tensor

Tree = Union[Dict[str, "Tree"], Tensor]


def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree:
    """Merge nested dictionaries of tensors."""
    if isinstance(trees[0], Tensor):
        return merge_fn(trees)
    else:
        return {
            key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0]
        }
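A usage sketch for `collate`: any function that merges a list of tensors works as `merge_fn`, e.g. `torch.stack` to add a batch dimension (shapes here are illustrative):

    import torch

    trees = [
        {"image": torch.zeros(3, 4, 4), "meta": {"near": torch.tensor(0.1)}},
        {"image": torch.ones(3, 4, 4), "meta": {"near": torch.tensor(0.2)}},
    ]
    batch = collate(trees, torch.stack)
    print(batch["image"].shape)         # torch.Size([2, 3, 4, 4])
    print(batch["meta"]["near"].shape)  # torch.Size([2])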
()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /src/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /src/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | import skvideo.io 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as tf 9 | from einops import rearrange, repeat 10 | from jaxtyping import Float, UInt8 11 | from matplotlib.figure import Figure 12 | from PIL import Image 13 | from torch import Tensor 14 | 15 | FloatImage = Union[ 16 | Float[Tensor, "height width"], 17 | Float[Tensor, "channel height width"], 18 | Float[Tensor, "batch channel height width"], 19 | ] 20 | 21 | 22 | def fig_to_image( 23 | fig: Figure, 24 | dpi: int = 100, 25 | device: torch.device = torch.device("cpu"), 26 | ) -> Float[Tensor, "3 height width"]: 27 | buffer = io.BytesIO() 28 | fig.savefig(buffer, format="raw", dpi=dpi) 29 | buffer.seek(0) 30 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 31 | h = int(fig.bbox.bounds[3]) 32 | w = int(fig.bbox.bounds[2]) 33 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 34 | buffer.close() 35 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 36 | 37 | 38 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 39 | # Handle batched images. 40 | if image.ndim == 4: 41 | image = rearrange(image, "b c h w -> c h (b w)") 42 | 43 | # Handle single-channel images. 44 | if image.ndim == 2: 45 | image = rearrange(image, "h w -> () h w") 46 | 47 | # Ensure that there are 3 or 4 channels. 
--------------------------------------------------------------------------------
/src/misc/image_io.py:
--------------------------------------------------------------------------------
import io
from pathlib import Path
from typing import Union

import numpy as np
import skvideo.io
import torch
import torchvision.transforms as tf
from einops import rearrange, repeat
from jaxtyping import Float, UInt8
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor

FloatImage = Union[
    Float[Tensor, "height width"],
    Float[Tensor, "channel height width"],
    Float[Tensor, "batch channel height width"],
]


def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    buffer = io.BytesIO()
    fig.savefig(buffer, format="raw", dpi=dpi)
    buffer.seek(0)
    data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    h = int(fig.bbox.bounds[3])
    w = int(fig.bbox.bounds[2])
    data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
    buffer.close()
    return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3]


def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    # Handle batched images.
    if image.ndim == 4:
        image = rearrange(image, "b c h w -> c h (b w)")

    # Handle single-channel images.
    if image.ndim == 2:
        image = rearrange(image, "h w -> () h w")

    # Ensure that there are 3 or 4 channels.
    channel, _, _ = image.shape
    if channel == 1:
        image = repeat(image, "() h w -> c h w", c=3)
    assert image.shape[0] in (3, 4)

    image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
    return rearrange(image, "c h w -> h w c").cpu().numpy()


def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Save an image. Assumed to be in range 0-1."""

    # Create the parent directory if it doesn't already exist.
    path = Path(path)
    path.parent.mkdir(exist_ok=True, parents=True)

    # Save the image.
    Image.fromarray(prep_image(image)).save(path)


def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    return tf.ToTensor()(Image.open(path))[:3]


def save_video(
    images: list[FloatImage],
    path: Union[Path, str],
) -> None:
    """Save a video. Frames are assumed to be in range 0-1."""

    # Create the parent directory if it doesn't already exist.
    path = Path(path)
    path.parent.mkdir(exist_ok=True, parents=True)

    # Encode the frames as a video.
    frames = [prep_image(image) for image in images]
    writer = skvideo.io.FFmpegWriter(
        path,
        outputdict={"-pix_fmt": "yuv420p", "-crf": "21", "-vf": "setpts=1.*PTS"},
    )
    for frame in frames:
        writer.writeFrame(frame)
    writer.close()
--------------------------------------------------------------------------------
/src/misc/nn_module_tools.py:
--------------------------------------------------------------------------------
from torch import nn


def convert_to_buffer(module: nn.Module, persistent: bool = True):
    # Recurse over child modules.
    for name, child in list(module.named_children()):
        convert_to_buffer(child, persistent)

    # Also re-save buffers to change persistence.
    for name, parameter_or_buffer in (
        *module.named_parameters(recurse=False),
        *module.named_buffers(recurse=False),
    ):
        value = parameter_or_buffer.detach().clone()
        delattr(module, name)
        module.register_buffer(name, value, persistent=persistent)
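A sketch of what `convert_to_buffer` achieves: parameters become buffers, so they still move with `.to(device)` but are invisible to optimizers, and with `persistent=False` they are also left out of checkpoints. This is how `LossLpips` above freezes its VGG weights:

    from torch import nn

    layer = nn.Linear(4, 4)
    convert_to_buffer(layer, persistent=False)
    print(len(list(layer.parameters())))  # 0 - nothing left to optimize
    print(len(list(layer.buffers())))     # 2 - weight and bias live on as buffers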
i", 28 | ) 29 | result.append(sh_rotated) 30 | 31 | return torch.cat(result, dim=-1) 32 | 33 | 34 | if __name__ == "__main__": 35 | from pathlib import Path 36 | 37 | import matplotlib.pyplot as plt 38 | from e3nn.o3 import spherical_harmonics 39 | from matplotlib import cm 40 | from scipy.spatial.transform.rotation import Rotation as R 41 | 42 | device = torch.device("cuda") 43 | 44 | # Generate random spherical harmonics coefficients. 45 | degree = 4 46 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 47 | 48 | def plot_sh(sh_coefficients, path: Path) -> None: 49 | phi = torch.linspace(0, torch.pi, 100, device=device) 50 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 51 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 52 | x = torch.sin(phi) * torch.cos(theta) 53 | y = torch.sin(phi) * torch.sin(theta) 54 | z = torch.cos(phi) 55 | xyz = torch.stack([x, y, z], dim=-1) 56 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 57 | result = einsum(sh, sh_coefficients, "... n, n -> ...") 58 | result = (result - result.min()) / (result.max() - result.min()) 59 | 60 | # Set the aspect ratio to 1 so our sphere looks spherical 61 | fig = plt.figure(figsize=plt.figaspect(1.0)) 62 | ax = fig.add_subplot(111, projection="3d") 63 | ax.plot_surface( 64 | x.cpu().numpy(), 65 | y.cpu().numpy(), 66 | z.cpu().numpy(), 67 | rstride=1, 68 | cstride=1, 69 | facecolors=cm.seismic(result.cpu().numpy()), 70 | ) 71 | # Turn off the axis planes 72 | ax.set_axis_off() 73 | path.parent.mkdir(exist_ok=True, parents=True) 74 | plt.savefig(path) 75 | 76 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 77 | rotation = torch.tensor( 78 | R.from_euler("x", angle.item()).as_matrix(), device=device 79 | ) 80 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 81 | 82 | print("Done!") 83 | -------------------------------------------------------------------------------- /src/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /src/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 
--------------------------------------------------------------------------------
/src/misc/wandb_tools.py:
--------------------------------------------------------------------------------
from pathlib import Path

import wandb


def version_to_int(artifact) -> int:
    """Convert versions of the form vX to X. For example, v12 to 12."""
    return int(artifact.version[1:])


def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    api = wandb.Api()
    run = api.run(run_id)

    # Find the latest saved model checkpoint.
    chosen = None
    for artifact in run.logged_artifacts():
        if artifact.type != "model" or artifact.state != "COMMITTED":
            continue

        # If no version is specified, use the latest.
        if version is None:
            if chosen is None or version_to_int(artifact) > version_to_int(chosen):
                chosen = artifact

        # If a specific version is specified, look for it.
        elif version == artifact.version:
            chosen = artifact
            break

    if chosen is None:
        raise ValueError(f"No matching model checkpoint found for run {run_id}.")

    # Download the checkpoint.
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"


def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None:
    if path is None:
        return None

    if not str(path).startswith("wandb://"):
        return Path(path)

    run_id, *version = path[len("wandb://") :].split(":")
    if len(version) == 0:
        version = None
    elif len(version) == 1:
        version = version[0]
    else:
        raise ValueError("Invalid version specifier!")

    project = wandb_cfg["project"]
    return download_checkpoint(
        f"{project}/{run_id}",
        Path("checkpoints"),
        version,
    )
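A sketch of the two path forms `update_checkpoint_path` accepts (the run id and project name here are made up):

    # Plain paths pass through untouched.
    ckpt = update_checkpoint_path("checkpoints/model.ckpt", {"project": "ggn"})
    # -> Path("checkpoints/model.ckpt")

    # "wandb://<run_id>[:<version>]" downloads the artifact and returns its checkpoint.
    ckpt = update_checkpoint_path("wandb://abc123:v7", {"project": "ggn"})
    # -> Path("checkpoints/ggn/abc123/model.ckpt")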
--------------------------------------------------------------------------------
/src/model/decoder/__init__.py:
--------------------------------------------------------------------------------
from ...dataset import DatasetCfg
from .decoder import Decoder
from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg

DECODERS = {
    "splatting_cuda": DecoderSplattingCUDA,
}

DecoderCfg = DecoderSplattingCUDACfg


def get_decoder(decoder_cfg: DecoderCfg, dataset_cfg: DatasetCfg) -> Decoder:
    return DECODERS[decoder_cfg.name](decoder_cfg, dataset_cfg)
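The registry above dispatches on `cfg.name`, which the config system fills in from config/model/decoder/splatting_cuda.yaml. A hypothetical wiring sketch (`decoder_cfg` and `dataset_cfg` would come from the parsed config, not be built by hand):

    decoder = get_decoder(decoder_cfg, dataset_cfg)
    assert isinstance(decoder, DecoderSplattingCUDA)  # "splatting_cuda" is the only entry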
--------------------------------------------------------------------------------
/src/model/decoder/decoder.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

from jaxtyping import Float
from torch import Tensor, nn

from ...dataset import DatasetCfg
from ..types import Gaussians

DepthRenderingMode = Literal[
    "depth",
    "log",
    "disparity",
    "relative_disparity",
]


@dataclass
class DecoderOutput:
    color: Float[Tensor, "batch view 3 height width"]
    depth: Float[Tensor, "batch view height width"] | None


T = TypeVar("T")


class Decoder(nn.Module, ABC, Generic[T]):
    cfg: T
    dataset_cfg: DatasetCfg

    def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        pass
--------------------------------------------------------------------------------
/src/model/encoder/__init__.py:
--------------------------------------------------------------------------------
from typing import Optional

from .encoder import Encoder
from .encoder_costvolume import EncoderCostVolume, EncoderCostVolumeCfg
from .visualization.encoder_visualizer import EncoderVisualizer
from .visualization.encoder_visualizer_costvolume import EncoderVisualizerCostVolume

ENCODERS = {
    "costvolume": (EncoderCostVolume, EncoderVisualizerCostVolume),
}

EncoderCfg = EncoderCostVolumeCfg


def get_encoder(cfg: EncoderCfg) -> tuple[Encoder, Optional[EncoderVisualizer]]:
    encoder, visualizer = ENCODERS[cfg.name]
    encoder = encoder(cfg)
    if visualizer is not None:
        visualizer = visualizer(cfg.visualizer, encoder)
    return encoder, visualizer
--------------------------------------------------------------------------------
/src/model/encoder/backbone/__init__.py:
--------------------------------------------------------------------------------
from .backbone_multiview import BackboneMultiview
--------------------------------------------------------------------------------
/src/model/encoder/backbone/unimatch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/model/encoder/backbone/unimatch/__init__.py
--------------------------------------------------------------------------------
/src/model/encoder/backbone/unimatch/position.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py

import math

import torch
import torch.nn as nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # x = tensor_list.tensors  # [B, C, H, W]
        # mask = tensor_list.mask  # [B, H, W], input with padding, valid as 0
        b, c, h, w = x.size()
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
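A shape check for `PositionEmbeddingSine` (sizes illustrative): with `num_pos_feats=64`, the output stacks 64 y-channels and 64 x-channels into 128 channels, regardless of the input channel count:

    import torch

    pos_enc = PositionEmbeddingSine(num_pos_feats=64)
    x = torch.randn(2, 96, 32, 40)  # [B, C, H, W]
    print(pos_enc(x).shape)         # torch.Size([2, 128, 32, 40])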
/src/model/encoder/backbone/unimatch/trident_conv.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair


class MultiScaleTridentConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.strides = [_pair(stride) for stride in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            # Run every branch with its own stride/padding but shared weights.
            outputs = [
                F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for input, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Test time with a single selected branch: use that branch's stride
            # and padding. (The original guard `if self.test_branch_idx == -1`
            # could never hold in this else-branch, so it always fell through
            # to the last branch's settings.)
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx],
                    self.paddings[self.test_branch_idx],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
--------------------------------------------------------------------------------
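Behavior sketch (illustrative only; the sizes and branch setup are hypothetical): two branches share one 3x3 kernel but run at strides 1 and 2, so in training mode the outputs differ only in resolution.

import torch

conv = MultiScaleTridentConv(
    in_channels=8,
    out_channels=16,
    kernel_size=3,
    strides=[1, 2],
    paddings=1,
    num_branch=2,
)
conv.train()  # training mode runs every branch
xs = [torch.randn(1, 8, 32, 32), torch.randn(1, 8, 32, 32)]
y1, y2 = conv(xs)
assert y1.shape == (1, 16, 32, 32)
assert y2.shape == (1, 16, 16, 16)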
/src/model/encoder/common/gaussians.py:
--------------------------------------------------------------------------------
import torch
from einops import rearrange
from jaxtyping import Float
from torch import Tensor


# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    # Order changed to match the scipy (x, y, z, w) quaternion format!
    i, j, k, r = torch.unbind(quaternions, dim=-1)
    two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return rearrange(o, "... (i j) -> ... i j", i=3, j=3)


def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    # Covariance of a 3D Gaussian: R S S^T R^T, with S the diagonal scale matrix.
    scale = scale.diag_embed()
    rotation = quaternion_to_matrix(rotation_xyzw)
    return (
        rotation
        @ scale
        @ rearrange(scale, "... i j -> ... j i")
        @ rearrange(rotation, "... i j -> ... j i")
    )
--------------------------------------------------------------------------------
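Sanity check (illustrative only, not a file in the repository): with the identity quaternion in (x, y, z, w) order, the covariance R S S^T R^T reduces to diag(scale^2).

import torch

scale = torch.tensor([1.0, 2.0, 3.0])
identity_xyzw = torch.tensor([0.0, 0.0, 0.0, 1.0])  # scipy (x, y, z, w) order
cov = build_covariance(scale, identity_xyzw)
assert torch.allclose(cov, torch.diag(scale**2), atol=1e-5)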
j i") 44 | ) 45 | -------------------------------------------------------------------------------- /src/model/encoder/common/sampler.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float, Int64, Shaped 2 | from torch import Tensor, nn 3 | 4 | from ....misc.discrete_probability_distribution import ( 5 | gather_discrete_topk, 6 | sample_discrete_distribution, 7 | ) 8 | 9 | 10 | class Sampler(nn.Module): 11 | def forward( 12 | self, 13 | probabilities: Float[Tensor, "*batch bucket"], 14 | num_samples: int, 15 | deterministic: bool, 16 | ) -> tuple[ 17 | Int64[Tensor, "*batch 1"], # index 18 | Float[Tensor, "*batch 1"], # probability density 19 | ]: 20 | return ( 21 | gather_discrete_topk(probabilities, num_samples) 22 | if deterministic 23 | else sample_discrete_distribution(probabilities, num_samples) 24 | ) 25 | 26 | def gather( 27 | self, 28 | index: Int64[Tensor, "*batch sample"], 29 | target: Shaped[Tensor, "..."], # *batch bucket *shape 30 | ) -> Shaped[Tensor, "..."]: # *batch sample *shape 31 | """Gather from the target according to the specified index. Handle the 32 | broadcasting needed for the gather to work. See the comments for the actual 33 | expected input/output shapes since jaxtyping doesn't support multiple variadic 34 | lengths in annotations. 35 | """ 36 | bucket_dim = index.ndim - 1 37 | while len(index.shape) < len(target.shape): 38 | index = index[..., None] 39 | broadcasted_index_shape = list(target.shape) 40 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 41 | index = index.broadcast_to(broadcasted_index_shape) 42 | return target.gather(dim=bucket_dim, index=index) 43 | -------------------------------------------------------------------------------- /src/model/encoder/costvolume/__pycache__/conversions.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/model/encoder/costvolume/__pycache__/conversions.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/model/encoder/costvolume/__pycache__/conversions.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/model/encoder/costvolume/__pycache__/conversions.cpython-310.opt-jaxtyping983a4111806314cc973c4ea00fb072bf6.pyc -------------------------------------------------------------------------------- /src/model/encoder/costvolume/__pycache__/conversions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/model/encoder/costvolume/__pycache__/conversions.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/encoder/costvolume/__pycache__/depth_predictor_multiview.cpython-310.opt-jaxtyping883a4111806314cc973c4ea00fb072bf6.pyc: -------------------------------------------------------------------------------- 
/src/model/encoder/costvolume/conversions.py:
--------------------------------------------------------------------------------
from jaxtyping import Float
from torch import Tensor


def relative_disparity_to_depth(
    relative_disparity: Float[Tensor, "*#batch"],
    near: Float[Tensor, "*#batch"],
    far: Float[Tensor, "*#batch"],
    eps: float = 1e-10,
) -> Float[Tensor, " *batch"]:
    """Convert relative disparity, where 0 is near and 1 is far, to depth."""
    disp_near = 1 / (near + eps)
    disp_far = 1 / (far + eps)
    return 1 / ((1 - relative_disparity) * (disp_near - disp_far) + disp_far + eps)


def depth_to_relative_disparity(
    depth: Float[Tensor, "*#batch"],
    near: Float[Tensor, "*#batch"],
    far: Float[Tensor, "*#batch"],
    eps: float = 1e-10,
) -> Float[Tensor, " *batch"]:
    """Convert depth to relative disparity, where 0 is near and 1 is far."""
    disp_near = 1 / (near + eps)
    disp_far = 1 / (far + eps)
    disp = 1 / (depth + eps)
    return 1 - (disp - disp_far) / (disp_near - disp_far + eps)
--------------------------------------------------------------------------------
/src/model/encoder/costvolume/ldm_unet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shengjun-zhang/GGN/1563ba69db17b6923944a2b31e4f37ede19b6549/src/model/encoder/costvolume/ldm_unet/__init__.py
--------------------------------------------------------------------------------
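Because the two functions in conversions.py are affine in disparity (inverse depth), they invert each other up to the epsilons. A round trip with hypothetical bounds:

import torch

near, far = torch.tensor(1.0), torch.tensor(100.0)
depth = torch.tensor([1.0, 2.0, 10.0, 100.0])
disp = depth_to_relative_disparity(depth, near, far)
depth_again = relative_disparity_to_depth(disp, near, far)
assert torch.allclose(depth_again, depth, rtol=1e-4)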
/src/model/encoder/encoder.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from torch import nn

from ...dataset.types import BatchedViews, DataShim
from ..types import Gaussians

T = TypeVar("T")


class Encoder(nn.Module, ABC, Generic[T]):
    cfg: T

    def __init__(self, cfg: T) -> None:
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        context: BatchedViews,
        deterministic: bool,
    ):
        pass

    def get_data_shim(self) -> DataShim:
        """The default shim doesn't modify the batch."""
        return lambda x: x
--------------------------------------------------------------------------------
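Minimal subclass sketch (the names are hypothetical, not part of the repository): a concrete encoder only has to accept a typed config and map context views to Gaussians in forward(); get_data_shim() can be inherited unless the encoder needs to preprocess batches.

from dataclasses import dataclass


@dataclass
class DummyEncoderCfg:
    num_gaussians: int


class DummyEncoder(Encoder[DummyEncoderCfg]):
    def forward(self, context: BatchedViews, deterministic: bool) -> Gaussians:
        # A real implementation would predict means, covariances, harmonics,
        # and opacities from the context views here.
        raise NotImplementedError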
/src/model/encoder/visualization/encoder_visualizer.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from jaxtyping import Float
from torch import Tensor

T_cfg = TypeVar("T_cfg")
T_encoder = TypeVar("T_encoder")


class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    cfg: T_cfg
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        self.cfg = cfg
        self.encoder = encoder

    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        pass
--------------------------------------------------------------------------------
/src/model/encoder/visualization/encoder_visualizer_costvolume_cfg.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass

# This is in a separate file to avoid circular imports.


@dataclass
class EncoderVisualizerCostVolumeCfg:
    num_samples: int
    min_resolution: int
    export_ply: bool
--------------------------------------------------------------------------------
/src/model/encodings/positional_encoding.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from einops import einsum, rearrange, repeat
from jaxtyping import Float
from torch import Tensor


class PositionalEncoding(nn.Module):
    """For the sake of simplicity, this encodes values in the range [0, 1]."""

    frequencies: Float[Tensor, "frequency phase"]
    phases: Float[Tensor, "frequency phase"]

    def __init__(self, num_octaves: int):
        super().__init__()
        octaves = torch.arange(num_octaves).float()

        # The lowest frequency has a period of 1.
        frequencies = 2 * torch.pi * 2**octaves
        frequencies = repeat(frequencies, "f -> f p", p=2)
        self.register_buffer("frequencies", frequencies, persistent=False)

        # Choose the phases to match sine and cosine.
        phases = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32)
        phases = repeat(phases, "p -> f p", f=num_octaves)
        self.register_buffer("phases", phases, persistent=False)

    def forward(
        self,
        samples: Float[Tensor, "*batch dim"],
    ) -> Float[Tensor, "*batch embedded_dim"]:
        samples = einsum(samples, self.frequencies, "... d, f p -> ... d f p")
        return rearrange(torch.sin(samples + self.phases), "... d f p -> ... (d f p)")

    def d_out(self, dimensionality: int):
        return self.frequencies.numel() * dimensionality
--------------------------------------------------------------------------------
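Shape check (illustrative only, not a file in the repository): each of the d input dimensions is expanded into num_octaves * 2 sine/cosine features, which is exactly what d_out reports.

import torch

enc = PositionalEncoding(num_octaves=2)
samples = torch.rand(5, 3)  # values in [0, 1]
embedded = enc(samples)
assert enc.d_out(3) == 12  # 3 dims * 2 octaves * 2 phases
assert embedded.shape == (5, 12)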
(d f p)") 34 | 35 | def d_out(self, dimensionality: int): 36 | return self.frequencies.numel() * dimensionality 37 | -------------------------------------------------------------------------------- /src/model/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | extrinsics: Float[Tensor, "4 4"], 28 | means: Float[Tensor, "gaussian 3"], 29 | scales: Float[Tensor, "gaussian 3"], 30 | rotations: Float[Tensor, "gaussian 4"], 31 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 32 | opacities: Float[Tensor, " gaussian"], 33 | path: Path, 34 | ): 35 | # Shift the scene so that the median Gaussian is at the origin. 36 | means = means - means.median(dim=0).values 37 | 38 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 39 | scale_factor = means.abs().quantile(0.95, dim=0).max() 40 | means = means / scale_factor 41 | scales = scales / scale_factor 42 | 43 | # Define a rotation that makes +Z be the world up vector. 44 | rotation = [ 45 | [0, 0, 1], 46 | [-1, 0, 0], 47 | [0, -1, 0], 48 | ] 49 | rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device) 50 | 51 | # The Polycam viewer seems to start at a 45 degree angle. Since we want to be 52 | # looking directly at the object, we compose a 45 degree rotation onto the above 53 | # rotation. 54 | adjustment = torch.tensor( 55 | R.from_rotvec([0, 0, -45], True).as_matrix(), 56 | dtype=torch.float32, 57 | device=means.device, 58 | ) 59 | rotation = adjustment @ rotation 60 | 61 | # We also want to see the scene in camera space (as the default view). We therefore 62 | # compose the w2c rotation onto the above rotation. 63 | rotation = rotation @ extrinsics[:3, :3].inverse() 64 | 65 | # Apply the rotation to the means (Gaussian positions). 66 | means = einsum(rotation, means, "i j, ... j -> ... i") 67 | 68 | # Apply the rotation to the Gaussian rotations. 69 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 70 | rotations = rotation.detach().cpu().numpy() @ rotations 71 | rotations = R.from_matrix(rotations).as_quat() 72 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 73 | rotations = np.stack((w, x, y, z), axis=-1) 74 | 75 | # Since our axes are swizzled for the spherical harmonics, we only export the DC 76 | # band. 
/src/model/types.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass

from jaxtyping import Float
from torch import Tensor


@dataclass
class Gaussians:
    means: Float[Tensor, "batch gaussian dim"]
    covariances: Float[Tensor, "batch gaussian dim dim"]
    harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
    opacities: Float[Tensor, "batch gaussian"]
--------------------------------------------------------------------------------
/src/scripts/compute_metrics.py:
--------------------------------------------------------------------------------
import json
from dataclasses import dataclass
from pathlib import Path

import hydra
import torch
from jaxtyping import install_import_hook
from omegaconf import DictConfig
from pytorch_lightning import Trainer

# Configure beartype and jaxtyping.
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_config
    from src.dataset.data_module import DataLoaderCfg, DataModule, DatasetCfg
    from src.evaluation.evaluation_cfg import EvaluationCfg
    from src.evaluation.metric_computer import MetricComputer
    from src.global_cfg import set_cfg


@dataclass
class RootCfg:
    evaluation: EvaluationCfg
    dataset: DatasetCfg
    data_loader: DataLoaderCfg
    seed: int
    output_metrics_path: Path


@hydra.main(
    version_base=None,
    config_path="../../config",
    config_name="compute_metrics",
)
def evaluate(cfg_dict: DictConfig):
    cfg = load_typed_config(cfg_dict, RootCfg)
    set_cfg(cfg_dict)
    torch.manual_seed(cfg.seed)
    trainer = Trainer(max_epochs=-1, accelerator="gpu")
    computer = MetricComputer(cfg.evaluation)
    data_module = DataModule(cfg.dataset, cfg.data_loader)
    metrics = trainer.test(computer, datamodule=data_module)
    cfg.output_metrics_path.parent.mkdir(exist_ok=True, parents=True)
    with cfg.output_metrics_path.open("w") as f:
        json.dump(metrics[0], f)


if __name__ == "__main__":
    evaluate()
--------------------------------------------------------------------------------
/src/scripts/dump_launch_configs.py:
--------------------------------------------------------------------------------
from pathlib import Path

import yaml
from colorama import Fore

if __name__ == "__main__":
    # Go to the repo directory.
    x = Path.cwd()
    while not (x / ".git").exists():
        x = x.parent

    # Hackily load JSON with comments and trailing commas.
    with (x / ".vscode/launch.json").open("r") as f:
        launch_with_comments = f.readlines()
    launch = [
        line for line in launch_with_comments if not line.strip().startswith("//")
    ]
    launch = "".join(launch)
    launch = yaml.safe_load(launch)

    for cfg in launch["configurations"]:
        print(f"{Fore.CYAN}{cfg['name']}{Fore.RESET}")

        arg_str = " ".join(cfg.get("args", []))
        if "env" in cfg:
            env_str = " ".join([f"{key}={value}" for key, value in cfg["env"].items()])
        else:
            env_str = ""

        command = f"{env_str} python3 -m {cfg['module']} {arg_str}".strip()
        print(f"{command}\n")
--------------------------------------------------------------------------------
/src/scripts/generate_dtu_evaluation_index.py:
--------------------------------------------------------------------------------
"""Built upon: https://github.com/autonomousvision/murf/blob/main/datasets/dtu.py"""

import argparse
import json
import os
from dataclasses import asdict, dataclass
from glob import glob

import numpy as np
import torch
from einops import rearrange, repeat
from tqdm import tqdm


@dataclass
class IndexEntry:
    context: tuple[int, ...]
    target: tuple[int, ...]


def sorted_test_src_views_fixed(cam2worlds_dict, test_views, train_views):
    # Use the same fixed source views for all test views, instead of different
    # source views for different test views.
    cam_pos_trains = np.stack([cam2worlds_dict[x] for x in train_views])[
        :, :3, 3
    ]  # [V, 3], V src views
    cam_pos_target = np.stack([cam2worlds_dict[x] for x in test_views])[
        :, :3, 3
    ]  # [N, 3], N test views in total
    dis = np.sum(np.abs(cam_pos_trains[:, None] - cam_pos_target[None]), axis=(1, 2))
    src_idx = np.argsort(dis)
    src_idx = [train_views[x] for x in src_idx]

    return src_idx


def main(args):
    data_dir = os.path.join("datasets", args.dataset_name, "test")

    # Load view pairs.
    # Adopted from: https://github.com/autonomousvision/murf/blob/main/datasets/dtu.py#L95
    test_views = [32, 24, 23, 44]
    train_views = [i for i in range(49) if i not in test_views]

    index = {}
    for torch_file in tqdm(sorted(glob(os.path.join(data_dir, "*.torch")))):
        scene_datas = torch.load(torch_file)
        for scene_data in scene_datas:
            cameras = scene_data["cameras"]
            scene_name = scene_data["key"]

            # Calculate the nearest camera index.
            w2c = repeat(
                torch.eye(4, dtype=torch.float32),
                "h w -> b h w",
                b=cameras.shape[0],
            ).clone()
            w2c[:, :3] = rearrange(cameras[:, 6:], "b (h w) -> b h w", h=3, w=4)
            opencv_c2ws = w2c.inverse()
            xyzs = opencv_c2ws[:, :3, -1].unsqueeze(0)  # 1, N, 3
            cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
            cameras_dist_index = torch.argsort(cameras_dist_matrix, dim=-1).squeeze(0)

            cam2worlds_dict = {k: v for k, v in enumerate(opencv_c2ws)}
            nearest_fixed_views = sorted_test_src_views_fixed(
                cam2worlds_dict, test_views, train_views
            )

            selected_pts = test_views
            for seq_idx, cur_mid in enumerate(selected_pts):
                cur_nn_index = nearest_fixed_views
                contexts = tuple([int(x) for x in cur_nn_index[: args.n_contexts]])
                targets = (cur_mid,)
                index[f"{scene_name}_{seq_idx:02d}"] = IndexEntry(
                    context=contexts,
                    target=targets,
                )

    # Save the index to a file.
    out_path = f"assets/evaluation_index_{args.dataset_name}_nctx{args.n_contexts}.json"
    with open(out_path, "w") as f:
        json.dump({k: None if v is None else asdict(v) for k, v in index.items()}, f)
    print(f"Dumped index to: {out_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_contexts", type=int, default=2, help="number of context views")
    parser.add_argument("--dataset_name", type=str, default="dtu")

    params = parser.parse_args()

    main(params)
--------------------------------------------------------------------------------
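For reference, each dumped entry is an IndexEntry serialized with asdict(); the view indices below are invented for illustration.

from dataclasses import asdict

entry = IndexEntry(context=(31, 25), target=(32,))
assert asdict(entry) == {"context": (31, 25), "target": (32,)}
# json.dump then writes it as {"context": [31, 25], "target": [32]}.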
out_path = f"assets/evaluation_index_{args.dataset_name}_nctx{args.n_contexts}.json" 83 | with open(out_path, "w") as f: 84 | json.dump({k: None if v is None else asdict(v) 85 | for k, v in index.items()}, f) 86 | print(f"Dumped index to: {out_path}") 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--n_contexts", type=int, 93 | default=2, help="output directory") 94 | parser.add_argument("--dataset_name", type=str, default="dtu") 95 | 96 | params = parser.parse_args() 97 | 98 | main(params) 99 | -------------------------------------------------------------------------------- /src/scripts/generate_evaluation_index.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import hydra 4 | import torch 5 | from jaxtyping import install_import_hook 6 | from omegaconf import DictConfig 7 | from pytorch_lightning import Trainer 8 | 9 | # Configure beartype and jaxtyping. 10 | with install_import_hook( 11 | ("src",), 12 | ("beartype", "beartype"), 13 | ): 14 | from src.config import load_typed_config 15 | from src.dataset import DatasetCfg 16 | from src.dataset.data_module import DataLoaderCfg, DataModule 17 | from src.evaluation.evaluation_index_generator import ( 18 | EvaluationIndexGenerator, 19 | EvaluationIndexGeneratorCfg, 20 | ) 21 | from src.global_cfg import set_cfg 22 | 23 | 24 | @dataclass 25 | class RootCfg: 26 | dataset: DatasetCfg 27 | data_loader: DataLoaderCfg 28 | index_generator: EvaluationIndexGeneratorCfg 29 | seed: int 30 | 31 | 32 | @hydra.main( 33 | version_base=None, 34 | config_path="../../config", 35 | config_name="generate_evaluation_index", 36 | ) 37 | def train(cfg_dict: DictConfig): 38 | cfg = load_typed_config(cfg_dict, RootCfg) 39 | set_cfg(cfg_dict) 40 | torch.manual_seed(cfg.seed) 41 | trainer = Trainer(max_epochs=1, accelerator="gpu", devices="auto", strategy="auto") 42 | data_module = DataModule(cfg.dataset, cfg.data_loader, None) 43 | evaluation_index_generator = EvaluationIndexGenerator(cfg.index_generator) 44 | trainer.test(evaluation_index_generator, datamodule=data_module) 45 | evaluation_index_generator.save_index() 46 | 47 | 48 | if __name__ == "__main__": 49 | train() 50 | -------------------------------------------------------------------------------- /src/scripts/generate_video_evaluation_index.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--index_input', type=str, help='depth directory') 8 | parser.add_argument('--index_output', type=str, help='dataset directory') 9 | args = parser.parse_args() 10 | 11 | 12 | # INDEX_INPUT = Path("assets/evaluation_index_re10k.json") 13 | # INDEX_OUTPUT = Path("assets/evaluation_index_re10k_video.json") 14 | INDEX_INPUT = Path(args.index_input) 15 | INDEX_OUTPUT = Path(args.index_output) 16 | 17 | 18 | if __name__ == "__main__": 19 | with INDEX_INPUT.open("r") as f: 20 | index_input = json.load(f) 21 | 22 | index_output = {} 23 | for scene, scene_index_input in index_input.items(): 24 | # Handle scenes for which there's no index. 25 | if scene_index_input is None: 26 | index_output[scene] = None 27 | continue 28 | 29 | # Add all intermediate frames as target frames. 
/src/scripts/generate_video_evaluation_index.py:
--------------------------------------------------------------------------------
import argparse
import json
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("--index_input", type=str, help="path to the input evaluation index (JSON)")
parser.add_argument("--index_output", type=str, help="path to write the video evaluation index (JSON)")
args = parser.parse_args()


# Typical paths:
# INDEX_INPUT = Path("assets/evaluation_index_re10k.json")
# INDEX_OUTPUT = Path("assets/evaluation_index_re10k_video.json")
INDEX_INPUT = Path(args.index_input)
INDEX_OUTPUT = Path(args.index_output)


if __name__ == "__main__":
    with INDEX_INPUT.open("r") as f:
        index_input = json.load(f)

    index_output = {}
    for scene, scene_index_input in index_input.items():
        # Handle scenes for which there's no index.
        if scene_index_input is None:
            index_output[scene] = None
            continue

        # Add all intermediate frames as target frames.
        a, b = scene_index_input["context"]
        index_output[scene] = {
            "context": [a, b],
            "target": list(range(a, b + 1)),
        }

    with INDEX_OUTPUT.open("w") as f:
        json.dump(index_output, f)
--------------------------------------------------------------------------------
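Before/after sketch of one hypothetical entry: the two context frames stay fixed, and every frame between them (inclusive) becomes a target.

index_input = {"scene_a": {"context": [10, 14], "target": [12]}, "scene_b": None}
# After the loop above, index_output is:
# {"scene_a": {"context": [10, 14], "target": [10, 11, 12, 13, 14]}, "scene_b": None}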
/src/visualization/annotation.py:
--------------------------------------------------------------------------------
from pathlib import Path
from string import ascii_letters, digits, punctuation

import numpy as np
import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor

from .layout import vcat

EXPECTED_CHARACTERS = digits + punctuation + ascii_letters


def draw_label(
    text: str,
    font: Path,
    font_size: int,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Draw a black label on a white background with no border."""
    try:
        font = ImageFont.truetype(str(font), font_size)
    except OSError:
        font = ImageFont.load_default()
    left, _, right, _ = font.getbbox(text)
    width = right - left
    _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS)
    height = bottom - top
    image = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(image)
    draw.text((0, 0), text, font=font, fill="black")
    image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device)
    return rearrange(image, "h w c -> c h w")


def add_label(
    image: Float[Tensor, "3 height width"],
    label: str,
    font: Path = Path("assets/Inter-Regular.otf"),
    font_size: int = 24,
) -> Float[Tensor, "3 height_with_label width_with_label"]:
    return vcat(
        draw_label(label, font, font_size, image.device),
        image,
        align="left",
        gap=4,
    )
--------------------------------------------------------------------------------
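Usage sketch (illustrative only; add_label falls back to PIL's default font if assets/Inter-Regular.otf is missing, and it assumes src.visualization.layout.vcat is importable):

import torch

image = torch.rand(3, 128, 256)  # [3, height, width]
labeled = add_label(image, "context view 0")
assert labeled.shape[0] == 3
assert labeled.shape[1] > 128  # the label strip is stacked on top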
/src/visualization/camera_trajectory/spin.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from einops import repeat
from jaxtyping import Float
from scipy.spatial.transform import Rotation as R
from torch import Tensor


def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    # Translate back along the camera's look vector.
    tf_translation = torch.eye(4, dtype=torch.float32, device=device)
    tf_translation[:2] *= -1
    tf_translation[2, 3] = -radius

    # Generate the transformation for the azimuth.
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1)

    azimuth = R.from_rotvec(rotation_vectors).as_matrix()
    azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device)
    tf_azimuth = torch.eye(4, dtype=torch.float32, device=device)
    tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone()
    tf_azimuth[:, :3, :3] = azimuth

    # Generate the transformation for the elevation (given in degrees).
    elevation_rad = np.deg2rad(elevation)
    elevation_rot = R.from_rotvec(np.array([elevation_rad, 0, 0], dtype=np.float32))
    elevation_rot = torch.tensor(elevation_rot.as_matrix(), dtype=torch.float32, device=device)
    tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
    tf_elevation[:3, :3] = elevation_rot

    return tf_azimuth @ tf_elevation @ tf_translation
--------------------------------------------------------------------------------
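Usage sketch (illustrative parameters): one 4 x 4 pose per frame of a full orbit at 30 degrees elevation and radius 2.

import torch

poses = generate_spin(
    num_frames=60,
    device=torch.device("cpu"),
    elevation=30.0,  # degrees
    radius=2.0,
)
assert poses.shape == (60, 4, 4)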
/src/visualization/camera_trajectory/wobble.py:
--------------------------------------------------------------------------------
import torch
from einops import rearrange
from jaxtyping import Float
from torch import Tensor


@torch.no_grad()
def generate_wobble_transformation(
    radius: Float[Tensor, "*#batch"],
    t: Float[Tensor, " time_step"],
    num_rotations: int = 1,
    scale_radius_with_t: bool = True,
) -> Float[Tensor, "*batch time_step 4 4"]:
    # Generate a translation in the image plane.
    tf = torch.eye(4, dtype=torch.float32, device=t.device)
    tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
    radius = radius[..., None]
    if scale_radius_with_t:
        radius = radius * t
    tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius
    tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius
    return tf


@torch.no_grad()
def generate_wobble(
    extrinsics: Float[Tensor, "*#batch 4 4"],
    radius: Float[Tensor, "*#batch"],
    t: Float[Tensor, " time_step"],
) -> Float[Tensor, "*batch time_step 4 4"]:
    tf = generate_wobble_transformation(radius, t)
    return rearrange(extrinsics, "... i j -> ... () i j") @ tf
--------------------------------------------------------------------------------
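Usage sketch (illustrative only): wobble a single identity pose along a small circle in its image plane over 30 time steps.

import torch

extrinsics = torch.eye(4)[None]  # [1, 4, 4]
radius = torch.tensor([0.05])    # [1]
t = torch.linspace(0, 1, 30)     # [30]
trajectory = generate_wobble(extrinsics, radius, t)
assert trajectory.shape == (1, 30, 4, 4)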
/src/visualization/color_map.py:
--------------------------------------------------------------------------------
import torch
from colorspacious import cspace_convert
from einops import rearrange
from jaxtyping import Float
from matplotlib import cm
from torch import Tensor


def apply_color_map(
    x: Float[Tensor, " *batch"],
    color_map: str = "inferno",
) -> Float[Tensor, "*batch 3"]:
    cmap = cm.get_cmap(color_map)

    # Convert to NumPy so that Matplotlib color maps can be used.
    mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3]

    # Convert back to the original format.
    return torch.tensor(mapped, device=x.device, dtype=torch.float32)


def apply_color_map_to_image(
    image: Float[Tensor, "*batch height width"],
    color_map: str = "inferno",
) -> Float[Tensor, "*batch 3 height width"]:
    image = apply_color_map(image, color_map)
    return rearrange(image, "... h w c -> ... c h w")


def apply_color_map_2d(
    x: Float[Tensor, "*#batch"],
    y: Float[Tensor, "*#batch"],
) -> Float[Tensor, "*batch 3"]:
    red = cspace_convert((189, 0, 0), "sRGB255", "CIELab")
    blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab")
    white = cspace_convert((255, 255, 255), "sRGB255", "CIELab")
    x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None]
    y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None]

    # Interpolate between red and blue on the x axis.
    interpolated = x_np * red + (1 - x_np) * blue

    # Interpolate between color and white on the y axis.
    interpolated = y_np * interpolated + (1 - y_np) * white

    # Convert to RGB.
    rgb = cspace_convert(interpolated, "CIELab", "sRGB1")
    return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1)
--------------------------------------------------------------------------------
/src/visualization/colors.py:
--------------------------------------------------------------------------------
from PIL import ImageColor

# https://sashamaps.net/docs/resources/20-colors/
DISTINCT_COLORS = [
    "#e6194b",
    "#3cb44b",
    "#ffe119",
    "#4363d8",
    "#f58231",
    "#911eb4",
    "#46f0f0",
    "#f032e6",
    "#bcf60c",
    "#fabebe",
    "#008080",
    "#e6beff",
    "#9a6324",
    "#fffac8",
    "#800000",
    "#aaffc3",
    "#808000",
    "#ffd8b1",
    "#000075",
    "#808080",
    "#ffffff",
    "#000000",
]


def get_distinct_color(index: int) -> tuple[float, float, float]:
    hex_color = DISTINCT_COLORS[index % len(DISTINCT_COLORS)]
    return tuple(x / 255 for x in ImageColor.getcolor(hex_color, "RGB"))
--------------------------------------------------------------------------------
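A short sketch of the color-map helpers on dummy data; the tensor shapes and random inputs are illustrative:

    import torch
    from src.visualization.color_map import apply_color_map, apply_color_map_2d

    depth = torch.rand(64, 64)        # values assumed to lie in [0, 1]
    colored = apply_color_map(depth)  # (64, 64, 3)
    # Two scalars per pixel: x blends red against blue, y blends the result against white.
    colored_2d = apply_color_map_2d(torch.rand(64, 64), torch.rand(64, 64))  # (64, 64, 3)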
/src/visualization/drawing/coordinate_conversion.py:
--------------------------------------------------------------------------------
from typing import Optional, Protocol, runtime_checkable

import torch
from jaxtyping import Float
from torch import Tensor

from .types import Pair, sanitize_pair


@runtime_checkable
class ConversionFunction(Protocol):
    def __call__(
        self,
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        pass


def generate_conversions(
    shape: tuple[int, int],
    device: torch.device,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> tuple[
    ConversionFunction,  # conversion from world coordinates to pixel coordinates
    ConversionFunction,  # conversion from pixel coordinates to world coordinates
]:
    h, w = shape
    x_range = sanitize_pair((0, w) if x_range is None else x_range, device)
    y_range = sanitize_pair((0, h) if y_range is None else y_range, device)
    minima, maxima = torch.stack((x_range, y_range), dim=-1)
    wh = torch.tensor((w, h), dtype=torch.float32, device=device)

    def convert_world_to_pixel(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        return (xy - minima) / (maxima - minima) * wh

    def convert_pixel_to_world(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        return xy / wh * (maxima - minima) + minima

    return convert_world_to_pixel, convert_pixel_to_world
--------------------------------------------------------------------------------
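With the default ranges, the two conversions are inverses of each other, which the following sketch (illustrative image size and points) checks:

    import torch
    from src.visualization.drawing.coordinate_conversion import generate_conversions

    to_pixel, to_world = generate_conversions((480, 640), torch.device("cpu"))
    xy = torch.tensor([[0.0, 0.0], [320.0, 240.0]])
    assert torch.allclose(to_world(to_pixel(xy)), xy)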
/src/visualization/drawing/lines.py:
--------------------------------------------------------------------------------
from typing import Literal, Optional

import torch
from einops import einsum, repeat
from jaxtyping import Float
from torch import Tensor

from .coordinate_conversion import generate_conversions
from .rendering import render_over_image
from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector


def draw_lines(
    image: Float[Tensor, "3 height width"],
    start: Vector,
    end: Vector,
    color: Vector,
    width: Scalar,
    cap: Literal["butt", "round", "square"] = "round",
    num_msaa_passes: int = 1,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> Float[Tensor, "3 height width"]:
    device = image.device
    start = sanitize_vector(start, 2, device)
    end = sanitize_vector(end, 2, device)
    color = sanitize_vector(color, 3, device)
    width = sanitize_scalar(width, device)
    (num_lines,) = torch.broadcast_shapes(
        start.shape[0],
        end.shape[0],
        color.shape[0],
        width.shape,
    )

    # Convert world-space points to pixel space.
    _, h, w = image.shape
    world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
    start = world_to_pixel(start)
    end = world_to_pixel(end)

    def color_function(
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        # Define a vector between the start and end points.
        delta = end - start
        delta_norm = delta.norm(dim=-1, keepdim=True)
        u_delta = delta / delta_norm

        # Define a vector between each sample and the start point.
        indicator = xy - start[:, None]

        # Determine whether each sample is inside the line in the parallel direction.
        extra = 0.5 * width[:, None] if cap == "square" else 0
        parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s")
        parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra)

        # Determine whether each sample is inside the line perpendicularly.
        perpendicular = indicator - parallel[..., None] * u_delta[:, None]
        perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None]

        inside_line = parallel_inside_line & perpendicular_inside_line

        # Compute round caps.
        if cap == "round":
            near_start = indicator.norm(dim=-1) < 0.5 * width[:, None]
            inside_line |= near_start
            end_indicator = xy - end[:, None]
            near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None]
            inside_line |= near_end

        # Determine the sample's color: the highest-indexed line covering a
        # sample wins, and coverage becomes the alpha channel.
        selectable_color = color.broadcast_to((num_lines, 3))
        arrangement = inside_line * torch.arange(num_lines, device=device)[:, None]
        top_color = selectable_color.gather(
            dim=0,
            index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3),
        )
        rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1)

        return rgba

    return render_over_image(image, color_function, device, num_passes=num_msaa_passes)
--------------------------------------------------------------------------------
/src/visualization/drawing/points.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
from einops import repeat
from jaxtyping import Float
from torch import Tensor

from .coordinate_conversion import generate_conversions
from .rendering import render_over_image
from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector


def draw_points(
    image: Float[Tensor, "3 height width"],
    points: Vector,
    color: Vector = (1, 1, 1),
    radius: Scalar = 1,
    inner_radius: Scalar = 0,
    num_msaa_passes: int = 1,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> Float[Tensor, "3 height width"]:
    device = image.device
    points = sanitize_vector(points, 2, device)
    color = sanitize_vector(color, 3, device)
    radius = sanitize_scalar(radius, device)
    inner_radius = sanitize_scalar(inner_radius, device)
    (num_points,) = torch.broadcast_shapes(
        points.shape[0],
        color.shape[0],
        radius.shape,
        inner_radius.shape,
    )

    # Convert world-space points to pixel space.
    _, h, w = image.shape
    world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
    points = world_to_pixel(points)

    def color_function(
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        # Compute the offset between each sample and each point, then keep the
        # samples that fall inside the annulus between inner_radius and radius.
        delta = xy[:, None] - points[None]
        delta_norm = delta.norm(dim=-1)
        mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None])

        # Determine the sample's color: the highest-indexed point covering a
        # sample wins, and coverage becomes the alpha channel.
        selectable_color = color.broadcast_to((num_points, 3))
        arrangement = mask * torch.arange(num_points, device=device)
        top_color = selectable_color.gather(
            dim=0,
            index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3),
        )
        rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1)

        return rgba

    return render_over_image(image, color_function, device, num_passes=num_msaa_passes)
--------------------------------------------------------------------------------
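A minimal composition sketch for the two drawing helpers above; the canvas size, coordinates, and colors are illustrative, and with the default x_range/y_range the world coordinates coincide with pixel coordinates:

    import torch
    from src.visualization.drawing.lines import draw_lines
    from src.visualization.drawing.points import draw_points

    canvas = torch.zeros(3, 256, 256)  # black RGB canvas
    canvas = draw_lines(canvas, start=[32, 32], end=[224, 224], color=[0, 1, 0], width=3)
    canvas = draw_points(canvas, points=[[32, 32], [224, 224]], color=[1, 0, 0], radius=6)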
/src/visualization/drawing/types.py:
--------------------------------------------------------------------------------
from typing import Iterable, Union

import torch
from einops import repeat
from jaxtyping import Float, Shaped
from torch import Tensor

Real = Union[float, int]

Vector = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, "3"],
    Shaped[Tensor, "batch 3"],
]


def sanitize_vector(
    vector: Vector,
    dim: int,
    device: torch.device,
) -> Float[Tensor, "*#batch dim"]:
    if isinstance(vector, Tensor):
        vector = vector.type(torch.float32).to(device)
    else:
        vector = torch.tensor(vector, dtype=torch.float32, device=device)
    while vector.ndim < 2:
        vector = vector[None]
    if vector.shape[-1] == 1:
        vector = repeat(vector, "... () -> ... c", c=dim)
    assert vector.shape[-1] == dim
    assert vector.ndim == 2
    return vector


Scalar = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, ""],
    Shaped[Tensor, " batch"],
]


def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]:
    if isinstance(scalar, Tensor):
        scalar = scalar.type(torch.float32).to(device)
    else:
        scalar = torch.tensor(scalar, dtype=torch.float32, device=device)
    while scalar.ndim < 1:
        scalar = scalar[None]
    assert scalar.ndim == 1
    return scalar


Pair = Union[
    Iterable[Real],
    Shaped[Tensor, "2"],
]


def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]:
    if isinstance(pair, Tensor):
        pair = pair.type(torch.float32).to(device)
    else:
        pair = torch.tensor(pair, dtype=torch.float32, device=device)
    assert pair.shape == (2,)
    return pair
--------------------------------------------------------------------------------
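A quick demo of the sanitizers' broadcasting behavior (the inputs are illustrative):

    import torch
    from src.visualization.drawing.types import sanitize_scalar, sanitize_vector

    device = torch.device("cpu")
    sanitize_vector(1, 3, device)          # tensor([[1., 1., 1.]]): a scalar is repeated to dim
    sanitize_vector([1, 0, 0], 3, device)  # tensor([[1., 0., 0.]])
    sanitize_scalar(2, device)             # tensor([2.])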
/src/visualization/vis_depth.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import cv2
import matplotlib as mpl
import matplotlib.cm as cm


# https://github.com/autonomousvision/unimatch/blob/master/utils/visualization.py


def vis_disparity(disp):
    # Normalize disparity to [0, 255] and apply OpenCV's inferno color map.
    disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)

    return disp_vis


def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'):
    # Visualize inverse depth, clipping the normalization at the 95th percentile.
    assert isinstance(disp, torch.Tensor)

    disp = disp.numpy()
    vmax = np.percentile(disp, 95)
    normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
    colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8)  # [H, W, 3]

    if return_numpy:
        return colormapped_im

    viz = torch.from_numpy(colormapped_im).permute(2, 0, 1)  # [3, H, W]

    return viz
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
# CUDA_VISIBLE_DEVICES=3 python -m src.main +experiment=re10k \
#     checkpointing.load=outputs/2024-05-11/00-07-44/checkpoints/epoch_9-step_75000.ckpt \
#     mode=test \
#     dataset/view_sampler=evaluation \
#     test.compute_scores=true \
#     dataset.view_sampler.index_path=assets/re10k_4view.json


CUDA_VISIBLE_DEVICES=0 python -m src.main +experiment=acid \
    checkpointing.load=/data1/zsj/PixelGS/new/effsplat/outputs/2024-05-11/00-00-38/checkpoints/epoch_54-step_75000.ckpt \
    mode=test \
    dataset/view_sampler=evaluation \
    test.compute_scores=true \
    dataset.view_sampler.index_path=assets/acid_4view.json
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python -m src.main +experiment=re10k data_loader.train.batch_size=6 checkpointing.load=checkpoints/acid_continue_train.ckpt
--------------------------------------------------------------------------------