├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── framework.png └── teaser.png ├── config ├── distill_scannet.yaml ├── eval.yaml ├── fusion_mipnerf360.yaml ├── fusion_mvimgnet.yaml ├── fusion_panoptic.yaml ├── fusion_scannet.yaml ├── official_train.yaml └── view_scannet.yaml ├── dataset ├── __init__.py ├── augmentation.py ├── feature_dataset.py ├── fusion_utils.py └── scannet │ ├── __init__.py │ ├── label_mapping.py │ ├── scannet_constants.py │ ├── scannetv2-labels.modified.tsv │ ├── scannetv2_test.txt │ ├── scannetv2_train.txt │ └── scannetv2_val.txt ├── distill.py ├── environment.yml ├── eval_segmentation.py ├── fusion.py ├── model ├── __init__.py ├── gaussian_model.py ├── lseg │ ├── LICENSE │ ├── README.MD │ ├── additional_utils │ │ ├── encoding_models.py │ │ └── models.py │ ├── data │ │ └── __init__.py │ ├── fewshot_data │ │ ├── README.md │ │ ├── common │ │ │ ├── evaluation.py │ │ │ ├── logger.py │ │ │ ├── utils.py │ │ │ └── vis.py │ │ ├── data │ │ │ ├── assets │ │ │ │ ├── architecture.png │ │ │ │ └── qualitative_results.png │ │ │ ├── coco.py │ │ │ ├── dataset.py │ │ │ ├── fss.py │ │ │ ├── pascal.py │ │ │ └── splits │ │ │ │ ├── coco │ │ │ │ ├── trn │ │ │ │ │ ├── fold0.pkl │ │ │ │ │ ├── fold1.pkl │ │ │ │ │ ├── fold2.pkl │ │ │ │ │ └── fold3.pkl │ │ │ │ └── val │ │ │ │ │ ├── fold0.pkl │ │ │ │ │ ├── fold1.pkl │ │ │ │ │ ├── fold2.pkl │ │ │ │ │ └── fold3.pkl │ │ │ │ ├── fss │ │ │ │ ├── test.txt │ │ │ │ ├── trn.txt │ │ │ │ └── val.txt │ │ │ │ └── pascal │ │ │ │ ├── trn │ │ │ │ ├── fold0.txt │ │ │ │ ├── fold1.txt │ │ │ │ ├── fold2.txt │ │ │ │ └── fold3.txt │ │ │ │ └── val │ │ │ │ ├── fold0.txt │ │ │ │ ├── fold1.txt │ │ │ │ ├── fold2.txt │ │ │ │ └── fold3.txt │ │ ├── model │ │ │ ├── base │ │ │ │ ├── conv4d.py │ │ │ │ ├── correlation.py │ │ │ │ └── feature.py │ │ │ ├── hsnet.py │ │ │ └── learner.py │ │ ├── test.py │ │ └── train.py │ ├── inputs │ │ └── cat1.jpeg │ ├── label_files │ │ ├── ade20k_objectInfo150.txt │ │ ├── fewshot_coco.txt │ │ ├── fewshot_fss.txt │ │ └── fewshot_pascal.txt │ ├── lseg │ │ ├── lseg_module.py │ │ ├── lseg_module_zs.py │ │ ├── lsegmentation_module.py │ │ ├── lsegmentation_module_zs.py │ │ └── models │ │ │ ├── lseg_blocks.py │ │ │ ├── lseg_blocks_zs.py │ │ │ ├── lseg_net.py │ │ │ ├── lseg_net_zs.py │ │ │ ├── lseg_vit.py │ │ │ └── lseg_vit_zs.py │ ├── lseg_app.py │ ├── lseg_demo.ipynb │ ├── modules │ │ ├── lseg_module.py │ │ ├── lseg_module_zs.py │ │ ├── lsegmentation_module.py │ │ ├── lsegmentation_module_zs.py │ │ └── models │ │ │ ├── lseg_blocks.py │ │ │ ├── lseg_blocks_zs.py │ │ │ ├── lseg_net.py │ │ │ ├── lseg_net_zs.py │ │ │ ├── lseg_vit.py │ │ │ └── lseg_vit_zs.py │ ├── prepare_ade20k.py │ ├── setup.cfg │ ├── setup.py │ ├── test_lseg.py │ ├── test_lseg_zs.py │ ├── train_lseg.py │ └── utils.py ├── lseg_predictor.py ├── mink_unet.py ├── openseg_predictor.py ├── render_utils.py ├── renderer.py ├── resnet_base.py ├── samclip_predictor.py ├── vlpart │ ├── __init__.py │ ├── swintransformer.py │ ├── text_encoder.py │ ├── vlpart.py │ ├── vlpart_fast_rcnn.py │ ├── vlpart_roi_heads.py │ └── vocab.py └── vlpart_predictor.py ├── requirements.txt ├── scene ├── __init__.py ├── blender_loader.py ├── camera.py ├── colmap_loader.py ├── scannet_loader.py └── scene.py ├── submodules ├── channel-rasterization │ ├── .gitignore │ ├── CMakeLists.txt │ ├── LICENSE.md │ ├── README.md │ ├── channel_rasterization │ │ └── __init__.py │ ├── cuda_rasterizer │ │ ├── auxiliary.h │ │ ├── backward.cu │ │ ├── backward.h │ │ ├── config.h │ │ ├── forward.cu │ │ ├── forward.h │ │ 
├── rasterizer.h │ │ ├── rasterizer_impl.cu │ │ └── rasterizer_impl.h │ ├── ext.cpp │ ├── rasterize_points.cu │ ├── rasterize_points.h │ ├── setup.py │ └── third_party │ │ └── stbi_image_write.h ├── rgbd-rasterization │ ├── CMakeLists.txt │ ├── LICENSE.md │ ├── README.md │ ├── cuda_rasterizer │ │ ├── auxiliary.h │ │ ├── backward.cu │ │ ├── backward.h │ │ ├── config.h │ │ ├── forward.cu │ │ ├── forward.h │ │ ├── rasterizer.h │ │ ├── rasterizer_impl.cu │ │ └── rasterizer_impl.h │ ├── ext.cpp │ ├── rasterize_points.cu │ ├── rasterize_points.h │ ├── rgbd_rasterization │ │ └── __init__.py │ ├── setup.py │ └── third_party │ │ └── stbi_image_write.h ├── segment-anything │ ├── .flake8 │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── linter.sh │ ├── segment_anything │ │ ├── __init__.py │ │ ├── automask.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── build_sam_hq.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── image_encoder.py │ │ │ ├── mask_decoder.py │ │ │ ├── mask_decoder_hq.py │ │ │ ├── prompt_encoder.py │ │ │ ├── sam.py │ │ │ └── transformer.py │ │ ├── predictor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── amg.py │ │ │ ├── onnx.py │ │ │ └── transforms.py │ ├── setup.cfg │ └── setup.py └── simple-knn │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn.cu │ ├── simple_knn.h │ ├── simple_knn │ └── .gitkeep │ ├── spatial.cu │ └── spatial.h ├── tools ├── scannet_sens_reader.py └── unzip_label_filt.py ├── train.py ├── utils ├── camera_utils.py ├── dataset_utils.py ├── general_utils.py ├── graphics_utils.py ├── loss_utils.py ├── metric.py ├── sh_utils.py └── system_utils.py └── view_viser.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | weights/ 3 | results_distill/ 4 | semantic/ 5 | eval_samples/ 6 | fusion/ 7 | *.sh 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | .vscode 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
169 | #.idea/ 170 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/channel-rasterization/third_party/glm"] 2 | path = submodules/channel-rasterization/third_party/glm 3 | url = https://github.com/g-truc/glm/ 4 | [submodule "submodules/rgbd-rasterization/third_party/glm"] 5 | path = submodules/rgbd-rasterization/third_party/glm 6 | url = https://github.com/g-truc/glm/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/assets/framework.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/assets/teaser.png -------------------------------------------------------------------------------- /config/distill_scannet.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 1 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 1 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | dynamic: False 21 | load_iteration: 10000 22 | device: "cuda:0" 23 | 24 | fusion: 25 | out_dir: "/PATH/TO/YOUR/OWN" 26 | 27 | distill: 28 | exp_name: openseg_new 29 | model_3d: MinkUNet34A 30 | voxel_size: 0.02 31 | aug: True 32 | feature_type: all 33 | lr: 0.001 34 | epochs: 100 35 | loss_type: cosine 36 | schedule_milestones: [20, 40, 60, 80, 100] 37 | schedule_gamma: 0.3 38 | batch_size: 1 39 | num_workers: 16 40 | test_interval: 10 41 | save_interval: 10 42 | -------------------------------------------------------------------------------- /config/eval.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 1 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 0 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | load_iteration: -1 21 | device: "cuda:0" 22 | 23 | fusion: 24 | out_dir: "/PATH/TO/YOUR/OWN" 25 | model_2d: openseg 26 | 27 | distill: 28 | model_3d: MinkUNet34A 29 | model_dir: "/PATH/TO/YOUR/OWN" 30 | text_model: openseg 31 | voxel_size: 0.02 32 | iteration: 100 33 | feature_type: all 34 | 35 | eval: 36 | eval_mode: labelmap # choose from 2d, 3d, 2d_and_3d, pretrained, labelmap 37 | width: 648 38 | height: 484 39 | pred_on_3d: True 40 | feature_fusion: concat # choose from concat, argmax 41 | 42 | -------------------------------------------------------------------------------- /config/fusion_mipnerf360.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 1 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 0 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | dynamic: False 21 | load_iteration: -1 22 | device: "cuda:0" 23 | 24 | fusion: 25 | img_dim: [779, 519] 26 | num_workers: 8 27 | model_2d: openseg 
# choose from openseg, lseg, samclip, vlpart 28 | depth: render # choose from image, render, surface, none 29 | depth_scale: 1000.0 30 | visibility_threshold: 0.05 31 | cut_boundary: 10 32 | n_split_points: 9999999 #80000 33 | num_rand_file_per_scene: 1 #5 34 | out_dir: "./fusion" -------------------------------------------------------------------------------- /config/fusion_mvimgnet.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 0.5 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 0 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | dynamic: False 21 | load_iteration: -1 22 | device: "cuda:0" 23 | 24 | fusion: 25 | img_dim: [540, 960] 26 | num_workers: 8 27 | model_2d: openseg # choose from openseg, lseg, samclip, vlpart 28 | depth: render # choose from image, render, surface, none 29 | depth_scale: 1000.0 30 | visibility_threshold: 0.02 31 | outlier_threshold: -2.0 32 | cut_boundary: 10 33 | n_split_points: 99999999 #50000 34 | num_rand_file_per_scene: 1 #5 35 | out_dir: "/PATH/TO/YOUR/OWN" 36 | -------------------------------------------------------------------------------- /config/fusion_panoptic.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 1 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 0 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | dynamic: True 21 | load_iteration: -1 22 | num_timesteps: 150 23 | device: "cuda:0" 24 | 25 | fusion: 26 | img_dim: [640, 360] 27 | num_workers: 8 28 | model_2d: vlpart # choose from openseg, lseg, samclip, vlpart 29 | depth: surface # choose from image, render, surface, none 30 | depth_scale: 1000.0 31 | visibility_threshold: 0.01 32 | cut_boundary: 10 33 | outlier_threshold: -2.0 34 | n_split_points: 9999999 35 | num_rand_file_per_scene: 1 36 | out_dir: "/PATH/TO/YOUR/OWN" 37 | -------------------------------------------------------------------------------- /config/fusion_scannet.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 1 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 0 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | dynamic: False 21 | load_iteration: -1 22 | device: "cuda:0" 23 | 24 | fusion: 25 | img_dim: [648, 484] 26 | num_workers: 8 27 | model_2d: openseg # choose from openseg, lseg, samclip, vlpart 28 | depth: render # choose from image, render, surface, none 29 | depth_scale: 1000.0 30 | visibility_threshold: 0.05 31 | cut_boundary: 10 32 | n_split_points: 999999999 # train: 80000, eval: 999999999 (large enough) 33 | num_rand_file_per_scene: 1 # train: 5, eval: 1 34 | out_dir: 
"/PATH/TO/YOUR/OWN" 35 | -------------------------------------------------------------------------------- /config/official_train.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | test_cameras: True 4 | colmap_images: "images" 5 | colmap_eval_hold: 8 6 | downscale_ratio: 1 7 | white_background: False 8 | device: "cuda:0" 9 | 10 | pipeline: 11 | convert_shs_python: False 12 | compute_cov3d_python: False 13 | debug: False 14 | seed: 0 15 | 16 | model: 17 | sh_degree: 3 18 | model_dir: ~ 19 | load_iteration: -1 20 | device: "cuda:0" 21 | random_init: False 22 | 23 | train: 24 | exp_name: "EXP_NAME" 25 | iterations: 30000 26 | num_workers: 8 27 | test_iterations: [100, 7000, 30000] 28 | save_iterations: [7000, 30000] 29 | checkpoint_iterations: [] 30 | cut_edge: False # If ScanNet, set True to cut the 1% black edge 31 | 32 | position_lr_init: 0.00016 33 | position_lr_final: 0.0000016 34 | position_lr_delay_mult: 0.01 35 | position_lr_max_steps: 10000 36 | feature_lr: 0.0025 37 | opacity_lr: 0.05 38 | scaling_lr: 0.005 39 | rotation_lr: 0.001 40 | percent_dense: 0.01 41 | lambda_dssim: 0.2 42 | densification_interval: 100 # 1000 # 100 # 2000 43 | opacity_reset_interval: 3000 #2000 # 3000 # 2000 44 | densify_from_iter: 500 45 | densify_until_iter: 15000 # 10000 # 15000 #10000 46 | densify_grad_threshold: 0.0002 47 | random_background: False -------------------------------------------------------------------------------- /config/view_scannet.yaml: -------------------------------------------------------------------------------- 1 | scene: 2 | scene_path: "/PATH/TO/YOUR/OWN" 3 | dataset_name: "cocomap" 4 | test_cameras: False 5 | colmap_images: "images" 6 | colmap_eval_hold: 8 7 | downscale_ratio: 1 8 | white_background: False 9 | device: "cuda:0" 10 | 11 | pipeline: 12 | convert_shs_python: False 13 | compute_cov3d_python: False 14 | debug: False 15 | seed: 0 16 | 17 | model: 18 | sh_degree: 3 19 | model_dir: "/PATH/TO/YOUR/OWN" 20 | dynamic: False 21 | load_iteration: -1 22 | device: "cuda:0" 23 | 24 | render: 25 | fusion_dir: "/PATH/TO/YOUR/OWN/*.pt" 26 | model_2d: openseg # choose from openseg, lseg, samclip, vlpart 27 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """Code heavily based on https://github.com/pengsongyou/openscene""" 2 | -------------------------------------------------------------------------------- /dataset/feature_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | from dataset.fusion_utils import Voxelizer 7 | from dataset.augmentation import ElasticDistortion, RandomHorizontalFlip, Compose 8 | from utils.dataset_utils import load_gaussian_ply 9 | 10 | 11 | class FeatureDataset(Dataset): 12 | # Augmentation arguments 13 | SCALE_AUGMENTATION_BOUND = (0.9, 1.1) 14 | ROTATION_AUGMENTATION_BOUND = ( 15 | (-np.pi / 64, np.pi / 64), 16 | (-np.pi / 64, np.pi / 64), 17 | (-np.pi, np.pi), 18 | ) 19 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0)) 20 | ELASTIC_DISTORT_PARAMS = ((0.2, 0.4), (0.8, 1.6)) 21 | 22 | ROTATION_AXIS = "z" 23 | 24 | def __init__( 25 | self, gaussians_dir, point_dir, gaussian_iterations=30000, voxel_size=0.02, aug=False, feature_type="all" 26 | ): 27 | self.aug = aug 
28 | self.feature_type = feature_type 29 | self.scenes = os.listdir(gaussians_dir) 30 | self.scenes.sort() 31 | 32 | self.data = [] 33 | for scene in self.scenes: 34 | features = os.listdir(os.path.join(point_dir, scene)) 35 | features.sort() 36 | for feature in features: 37 | ply_path = os.path.join( 38 | gaussians_dir, 39 | scene, 40 | "point_cloud", 41 | f"iteration_{gaussian_iterations}", 42 | "point_cloud.ply", 43 | ) 44 | feature_path = os.path.join(point_dir, scene, feature) 45 | self.data.append([ply_path, feature_path, 0]) 46 | 47 | self.voxelizer = Voxelizer( 48 | voxel_size=voxel_size, 49 | clip_bound=None, 50 | use_augmentation=aug, 51 | scale_augmentation_bound=self.SCALE_AUGMENTATION_BOUND, 52 | rotation_augmentation_bound=self.ROTATION_AUGMENTATION_BOUND, 53 | translation_augmentation_ratio_bound=self.TRANSLATION_AUGMENTATION_RATIO_BOUND, 54 | ) 55 | 56 | self.prevoxel_transforms = Compose([ElasticDistortion(self.ELASTIC_DISTORT_PARAMS)]) 57 | self.input_transforms = Compose([RandomHorizontalFlip(self.ROTATION_AXIS, is_temporal=False)]) 58 | 59 | def __getitem__(self, index): 60 | with torch.no_grad(): 61 | ply_path, feature_path, head_id = self.data[index] 62 | locs, features = load_gaussian_ply(ply_path, self.feature_type) 63 | gt = torch.load(feature_path) 64 | features_gt, mask_chunk = gt["feat"], gt["mask_full"] 65 | 66 | # numpy transforms 67 | if self.aug: 68 | locs = self.prevoxel_transforms(locs) 69 | 70 | locs, features, _, inds_reconstruct, vox_ind = self.voxelizer.voxelize( 71 | locs, features, None, return_ind=True 72 | ) 73 | 74 | vox_ind = torch.from_numpy(vox_ind) 75 | mask = mask_chunk[vox_ind] 76 | mask_ind = mask_chunk.nonzero(as_tuple=False)[:, 0] 77 | index1 = -torch.ones(mask_chunk.shape[0], dtype=int) 78 | index1[mask_ind] = mask_ind 79 | 80 | index1 = index1[vox_ind] 81 | chunk_ind = index1[index1 != -1] 82 | 83 | index2 = torch.zeros(mask_chunk.shape[0]) 84 | index2[mask_ind] = 1 85 | index3 = torch.cumsum(index2, dim=0, dtype=int) 86 | 87 | indices = index3[chunk_ind] - 1 88 | features_gt = features_gt[indices] 89 | 90 | if self.aug: 91 | locs, features, _ = self.input_transforms(locs, features, None) 92 | 93 | locs = torch.from_numpy(locs).int() 94 | locs = torch.cat([torch.ones(locs.shape[0], 1, dtype=torch.int), locs], dim=1) 95 | features = torch.from_numpy(features).float() 96 | 97 | return locs, features, features_gt, mask, head_id 98 | 99 | def __len__(self): 100 | return len(self.data) 101 | -------------------------------------------------------------------------------- /dataset/scannet/__init__.py: -------------------------------------------------------------------------------- 1 | """Code from https://github.com/ScanNet/ScanNet/tree/master/BenchmarkScripts/ScanNet200""" 2 | -------------------------------------------------------------------------------- /dataset/scannet/label_mapping.py: -------------------------------------------------------------------------------- 1 | """Read the scannetv2-labels.combined.tsv and convert the label to other ids""" 2 | import os 3 | import csv 4 | 5 | 6 | def read_label_mapping(filename, label_from="id", label_to="nyu40id"): 7 | assert os.path.isfile(filename) 8 | mapping = dict() 9 | with open(filename) as csvfile: 10 | reader = csv.DictReader(csvfile, delimiter="\t") 11 | for row in reader: 12 | mapping[row[label_from]] = int(row[label_to]) 13 | 14 | # if ints convert 15 | def represents_int(s): 16 | try: 17 | int(s) 18 | return True 19 | except ValueError: 20 | return False 21 | 22 | if 
represents_int(list(mapping.keys())[0]): 23 | mapping = {int(k): v for k, v in mapping.items()} 24 | return mapping 25 | -------------------------------------------------------------------------------- /dataset/scannet/scannet_constants.py: -------------------------------------------------------------------------------- 1 | SCANNET20_CLASS_LABELS = ( 2 | "wall", 3 | "floor", 4 | "cabinet", 5 | "bed", 6 | "chair", 7 | "sofa", 8 | "table", 9 | "door", 10 | "window", 11 | "bookshelf", 12 | "picture", 13 | "counter", 14 | "desk", 15 | "curtain", 16 | "refridgerator", 17 | "shower curtain", 18 | "toilet", 19 | "sink", 20 | "bathtub", 21 | ) 22 | 23 | COCOMAP_CLASS_LABELS = ( 24 | "wall", 25 | "floor", 26 | "cabinet", 27 | "bed", 28 | "chair", 29 | "sofa", 30 | "table", 31 | "door", 32 | "window", 33 | "shelves", 34 | "counter", 35 | "curtain", 36 | "ceiling", 37 | "refridgerator", 38 | "television", 39 | "person", 40 | "toilet", 41 | "sink", 42 | "lamp", 43 | "bag", 44 | ) 45 | 46 | COLORMAP = [ 47 | (0.0, 0.0, 0.0), 48 | (174.0, 199.0, 232.0), 49 | (152.0, 223.0, 138.0), 50 | (31.0, 119.0, 180.0), 51 | (255.0, 187.0, 120.0), 52 | (188.0, 189.0, 34.0), 53 | (140.0, 86.0, 75.0), 54 | (255.0, 152.0, 150.0), 55 | (214.0, 39.0, 40.0), 56 | (197.0, 176.0, 213.0), 57 | (148.0, 103.0, 189.0), 58 | (196.0, 156.0, 148.0), 59 | (23.0, 190.0, 207.0), 60 | (247.0, 182.0, 210.0), 61 | (219.0, 219.0, 141.0), 62 | (255.0, 127.0, 14.0), 63 | (158.0, 218.0, 229.0), 64 | (44.0, 160.0, 44.0), 65 | (112.0, 128.0, 144.0), 66 | (227.0, 119.0, 194.0), 67 | (213.0, 92.0, 176.0), 68 | (94.0, 106.0, 211.0), 69 | (82.0, 84.0, 163.0), 70 | (100.0, 85.0, 144.0), 71 | (66.0, 188.0, 102.0), 72 | (140.0, 57.0, 197.0), 73 | (202.0, 185.0, 52.0), 74 | (51.0, 176.0, 203.0), 75 | (200.0, 54.0, 131.0), 76 | (92.0, 193.0, 61.0), 77 | (78.0, 71.0, 183.0), 78 | (172.0, 114.0, 82.0), 79 | (91.0, 163.0, 138.0), 80 | (153.0, 98.0, 156.0), 81 | (140.0, 153.0, 101.0), 82 | (100.0, 125.0, 154.0), 83 | (178.0, 127.0, 135.0), 84 | (146.0, 111.0, 194.0), 85 | (96.0, 207.0, 209.0), 86 | ] 87 | -------------------------------------------------------------------------------- /dataset/scannet/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | 
scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sega 2 | channels: 3 | - nvidia/label/cuda-11.8.0 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - cuda=11.8.0 8 | - pytorch-cuda=11.8 9 | - plyfile=0.8.1 10 | - python=3.9.18 11 | - pytorch=2.1.1 12 | - numpy=1.23.5 13 | - torchaudio=2.1.1 14 | - torchvision=0.16.1 15 | - pip=22.3.1 16 | - setuptools=66.0.0 17 | - tqdm -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_model import GaussianModel 2 | from .renderer import render, render_chn 3 | -------------------------------------------------------------------------------- /model/lseg/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Intelligent Systems Lab Org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /model/lseg/data/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import itertools 4 | import functools 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | import torchvision.transforms as torch_transforms 9 | import encoding.datasets as enc_ds 10 | 11 | encoding_datasets = { 12 | x: functools.partial(enc_ds.get_dataset, x) 13 | for x in ["coco", "ade20k", "pascal_voc", "pascal_aug", "pcontext", "citys"] 14 | } 15 | 16 | 17 | def get_dataset(name, **kwargs): 18 | if name in encoding_datasets: 19 | return encoding_datasets[name.lower()](**kwargs) 20 | assert False, f"dataset {name} not found" 21 | 22 | 23 | def get_available_datasets(): 24 | return list(encoding_datasets.keys()) 25 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/common/evaluation.py: -------------------------------------------------------------------------------- 1 | r""" Evaluate mask prediction """ 2 | import torch 3 | 4 | 5 | class Evaluator: 6 | r""" Computes intersection and union between prediction and ground-truth """ 7 | @classmethod 8 | def initialize(cls): 9 | cls.ignore_index = 255 10 | 11 | @classmethod 12 | def classify_prediction(cls, pred_mask, gt_mask, query_ignore_idx=None): 13 | # gt_mask = batch.get('query_mask') 14 | 15 | # # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020)) 16 | # query_ignore_idx = batch.get('query_ignore_idx') 17 | if query_ignore_idx is not None: 18 | assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0 19 | query_ignore_idx *= cls.ignore_index 20 | gt_mask = gt_mask + query_ignore_idx 21 | pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index 22 | 23 | # compute intersection and union of each episode in a batch 24 | area_inter, area_pred, area_gt = [], [], [] 25 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask): 26 | _inter = _pred_mask[_pred_mask == _gt_mask] 27 | if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1) 28 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device) 29 | else: 30 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1) 31 | area_inter.append(_area_inter) 32 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1)) 33 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1)) 34 | area_inter = torch.stack(area_inter).t() 35 | area_pred = torch.stack(area_pred).t() 36 | area_gt = torch.stack(area_gt).t() 37 | area_union = area_pred + area_gt - area_inter 38 | 39 | return area_inter, area_union 40 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/common/utils.py: -------------------------------------------------------------------------------- 1 | r""" Helper functions """ 2 | import random 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def fix_randseed(seed): 9 | r""" Set random seeds for reproducibility """ 10 | if seed is None: 11 | seed = int(random.random() * 1e5) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | torch.backends.cudnn.benchmark = False 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def mean(x): 21 | return sum(x) / len(x) if len(x) > 0 else 0.0 22 | 23 | 24 | def to_cuda(batch): 25 | for key, value in batch.items(): 26 | if 
isinstance(value, torch.Tensor): 27 | batch[key] = value.cuda() 28 | return batch 29 | 30 | 31 | def to_cpu(tensor): 32 | return tensor.detach().clone().cpu() 33 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/common/vis.py: -------------------------------------------------------------------------------- 1 | r""" Visualize model predictions """ 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | 8 | from fewshot_data.common import utils 9 | 10 | 11 | class Visualizer: 12 | 13 | @classmethod 14 | def initialize(cls, visualize): 15 | cls.visualize = visualize 16 | if not visualize: 17 | return 18 | 19 | cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)} 20 | for key, value in cls.colors.items(): 21 | cls.colors[key] = tuple([c / 255 for c in cls.colors[key]]) 22 | 23 | # cls.mean_img = [0.485, 0.456, 0.406] 24 | # cls.std_img = [0.229, 0.224, 0.225] 25 | cls.mean_img = [0.5] * 3 26 | cls.std_img = [0.5] * 3 27 | cls.to_pil = transforms.ToPILImage() 28 | cls.vis_path = './vis/' 29 | if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path) 30 | 31 | @classmethod 32 | def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None): 33 | spt_img_b = utils.to_cpu(spt_img_b) 34 | spt_mask_b = utils.to_cpu(spt_mask_b) 35 | qry_img_b = utils.to_cpu(qry_img_b) 36 | qry_mask_b = utils.to_cpu(qry_mask_b) 37 | pred_mask_b = utils.to_cpu(pred_mask_b) 38 | cls_id_b = utils.to_cpu(cls_id_b) 39 | 40 | for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \ 41 | enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)): 42 | iou = iou_b[sample_idx] if iou_b is not None else None 43 | cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou) 44 | 45 | @classmethod 46 | def to_numpy(cls, tensor, type): 47 | if type == 'img': 48 | return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8) 49 | elif type == 'mask': 50 | return np.array(tensor).astype(np.uint8) 51 | else: 52 | raise Exception('Undefined tensor type: %s' % type) 53 | 54 | @classmethod 55 | def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None): 56 | 57 | spt_color = cls.colors['blue'] 58 | qry_color = cls.colors['red'] 59 | pred_color = cls.colors['red'] 60 | 61 | spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs] 62 | spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs] 63 | spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks] 64 | spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)] 65 | 66 | qry_img = cls.to_numpy(qry_img, 'img') 67 | qry_pil = cls.to_pil(qry_img) 68 | qry_mask = cls.to_numpy(qry_mask, 'mask') 69 | pred_mask = cls.to_numpy(pred_mask, 'mask') 70 | pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color)) 71 | qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color)) 72 | 73 | merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil]) 74 | 75 | iou = iou.item() if iou else 0.0 76 | merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg') 
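    # Note: visualize_prediction above saves one side-by-side panel per episode to ./vis/,
    # named "<batch>_<sample>_class-<id>_iou-<iou>.jpg": the masked support images, the
    # predicted mask, and the ground-truth mask. merge_image_pair and apply_mask below
    # handle the horizontal concatenation and the alpha blending, respectively.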
77 | 78 | @classmethod 79 | def merge_image_pair(cls, pil_imgs): 80 | r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """ 81 | 82 | canvas_width = sum([pil.size[0] for pil in pil_imgs]) 83 | canvas_height = max([pil.size[1] for pil in pil_imgs]) 84 | canvas = Image.new('RGB', (canvas_width, canvas_height)) 85 | 86 | xpos = 0 87 | for pil in pil_imgs: 88 | canvas.paste(pil, (xpos, 0)) 89 | xpos += pil.size[0] 90 | 91 | return canvas 92 | 93 | @classmethod 94 | def apply_mask(cls, image, mask, color, alpha=0.5): 95 | r""" Apply mask to the given image. """ 96 | for c in range(3): 97 | image[:, :, c] = np.where(mask == 1, 98 | image[:, :, c] * 99 | (1 - alpha) + alpha * color[c] * 255, 100 | image[:, :, c]) 101 | return image 102 | 103 | @classmethod 104 | def unnormalize(cls, img): 105 | img = img.clone() 106 | for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img): 107 | im_channel.mul_(std).add_(mean) 108 | return img 109 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/assets/architecture.png -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/assets/qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/assets/qualitative_results.png -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/coco.py: -------------------------------------------------------------------------------- 1 | r""" COCO-20i few-shot semantic segmentation dataset """ 2 | import os 3 | import pickle 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetCOCO(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize): 14 | self.split = 'val' if split in ['val', 'test'] else 'trn' 15 | self.fold = fold 16 | self.nfolds = 4 17 | self.nclass = 80 18 | self.benchmark = 'coco' 19 | self.shot = shot 20 | self.split_coco = split if split == 'val2014' else 'train2014' 21 | self.base_path = os.path.join(datapath, 'COCO2014') 22 | self.transform = transform 23 | self.use_original_imgsize = use_original_imgsize 24 | 25 | self.class_ids = self.build_class_ids() 26 | self.img_metadata_classwise = self.build_img_metadata_classwise() 27 | self.img_metadata = self.build_img_metadata() 28 | 29 | def __len__(self): 30 | return len(self.img_metadata) if self.split == 'trn' else 1000 31 | 32 | def __getitem__(self, idx): 33 | # ignores idx during training & testing and perform uniform sampling over object classes to form an episode 34 | # (due to the large size of the COCO dataset) 35 | query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame() 36 | 37 | query_img = self.transform(query_img) 38 | query_mask = query_mask.float() 39 | if not self.use_original_imgsize: 40 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], 
mode='nearest').squeeze() 41 | 42 | if self.shot: 43 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 44 | for midx, smask in enumerate(support_masks): 45 | support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 46 | support_masks = torch.stack(support_masks) 47 | 48 | 49 | batch = {'query_img': query_img, 50 | 'query_mask': query_mask, 51 | 'query_name': query_name, 52 | 53 | 'org_query_imsize': org_qry_imsize, 54 | 55 | 'support_imgs': support_imgs, 56 | 'support_masks': support_masks, 57 | 'support_names': support_names, 58 | 'class_id': torch.tensor(class_sample)} 59 | 60 | return batch 61 | 62 | def build_class_ids(self): 63 | nclass_trn = self.nclass // self.nfolds 64 | class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)] 65 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 66 | class_ids = class_ids_trn if self.split == 'trn' else class_ids_val 67 | 68 | return class_ids 69 | 70 | def build_img_metadata_classwise(self): 71 | with open('fewshot_data/data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f: 72 | img_metadata_classwise = pickle.load(f) 73 | return img_metadata_classwise 74 | 75 | def build_img_metadata(self): 76 | img_metadata = [] 77 | for k in self.img_metadata_classwise.keys(): 78 | img_metadata += self.img_metadata_classwise[k] 79 | return sorted(list(set(img_metadata))) 80 | 81 | def read_mask(self, name): 82 | mask_path = os.path.join(self.base_path, 'annotations', name) 83 | mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png'))) 84 | return mask 85 | 86 | def load_frame(self): 87 | class_sample = np.random.choice(self.class_ids, 1, replace=False)[0] 88 | query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 89 | query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB') 90 | query_mask = self.read_mask(query_name) 91 | 92 | org_qry_imsize = query_img.size 93 | 94 | query_mask[query_mask != class_sample + 1] = 0 95 | query_mask[query_mask == class_sample + 1] = 1 96 | 97 | support_names = [] 98 | if self.shot: 99 | while True: # keep sampling support set if query == support 100 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 101 | if query_name != support_name: support_names.append(support_name) 102 | if len(support_names) == self.shot: break 103 | 104 | support_imgs = [] 105 | support_masks = [] 106 | if self.shot: 107 | for support_name in support_names: 108 | support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB')) 109 | support_mask = self.read_mask(support_name) 110 | support_mask[support_mask != class_sample + 1] = 0 111 | support_mask[support_mask == class_sample + 1] = 1 112 | support_masks.append(support_mask) 113 | 114 | return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize 115 | 116 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/dataset.py: -------------------------------------------------------------------------------- 1 | r""" Dataloader builder for few-shot semantic segmentation dataset """ 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | from fewshot_data.data.pascal import DatasetPASCAL 6 | from fewshot_data.data.coco import 
DatasetCOCO 7 | from fewshot_data.data.fss import DatasetFSS 8 | 9 | 10 | class FSSDataset: 11 | @classmethod 12 | def initialize(cls, img_size, datapath, use_original_imgsize, imagenet_norm=False): 13 | cls.datasets = { 14 | 'pascal': DatasetPASCAL, 15 | 'coco': DatasetCOCO, 16 | 'fss': DatasetFSS, 17 | } 18 | 19 | if imagenet_norm: 20 | cls.img_mean = [0.485, 0.456, 0.406] 21 | cls.img_std = [0.229, 0.224, 0.225] 22 | print('use norm: {}, {}'.format(cls.img_mean, cls.img_std)) 23 | else: 24 | cls.img_mean = [0.5] * 3 25 | cls.img_std = [0.5] * 3 26 | print('use norm: {}, {}'.format(cls.img_mean, cls.img_std)) 27 | 28 | cls.datapath = datapath 29 | cls.use_original_imgsize = use_original_imgsize 30 | 31 | cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)), 32 | transforms.ToTensor(), 33 | transforms.Normalize(cls.img_mean, cls.img_std)]) 34 | 35 | @classmethod 36 | def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1): 37 | shuffle = split == 'trn' 38 | nworker = nworker if split == 'trn' else 0 39 | dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, shot=shot, use_original_imgsize=cls.use_original_imgsize) 40 | dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker) 41 | 42 | return dataloader 43 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/trn/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/trn/fold0.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/trn/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/trn/fold1.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/trn/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/trn/fold2.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/trn/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/trn/fold3.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/val/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/val/fold0.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/val/fold1.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/val/fold1.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/val/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/val/fold2.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/coco/val/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/fewshot_data/data/splits/coco/val/fold3.pkl -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/fss/test.txt: -------------------------------------------------------------------------------- 1 | bus 2 | hotel_slipper 3 | burj_al 4 | reflex_camera 5 | abe's_flyingfish 6 | oiltank_car 7 | doormat 8 | fish_eagle 9 | barber_shaver 10 | motorbike 11 | feather_clothes 12 | wandering_albatross 13 | rice_cooker 14 | delta_wing 15 | fish 16 | nintendo_switch 17 | bustard 18 | diver 19 | minicooper 20 | cathedrale_paris 21 | big_ben 22 | combination_lock 23 | villa_savoye 24 | american_alligator 25 | gym_ball 26 | andean_condor 27 | leggings 28 | pyramid_cube 29 | jet_aircraft 30 | meatloaf 31 | reel 32 | swan 33 | osprey 34 | crt_screen 35 | microscope 36 | rubber_eraser 37 | arrow 38 | monkey 39 | mitten 40 | spiderman 41 | parthenon 42 | bat 43 | chess_king 44 | sulphur_butterfly 45 | quail_egg 46 | oriole 47 | iron_man 48 | wooden_boat 49 | anise 50 | steering_wheel 51 | groenendael 52 | dwarf_beans 53 | pteropus 54 | chalk_brush 55 | bloodhound 56 | moon 57 | english_foxhound 58 | boxing_gloves 59 | peregine_falcon 60 | pyraminx 61 | cicada 62 | screw 63 | shower_curtain 64 | tredmill 65 | bulb 66 | bell_pepper 67 | lemur_catta 68 | doughnut 69 | twin_tower 70 | astronaut 71 | nintendo_3ds 72 | fennel_bulb 73 | indri 74 | captain_america_shield 75 | kunai 76 | broom 77 | iphone 78 | earphone1 79 | flying_squirrel 80 | onion 81 | vinyl 82 | sydney_opera_house 83 | oyster 84 | harmonica 85 | egg 86 | breast_pump 87 | guitar 88 | potato_chips 89 | tunnel 90 | cuckoo 91 | rubick_cube 92 | plastic_bag 93 | phonograph 94 | net_surface_shoes 95 | goldfinch 96 | ipad 97 | mite_predator 98 | coffee_mug 99 | golden_plover 100 | f1_racing 101 | lapwing 102 | nintendo_gba 103 | pizza 104 | rally_car 105 | drilling_platform 106 | cd 107 | fly 108 | magpie_bird 109 | leaf_fan 110 | little_blue_heron 111 | carriage 112 | moist_proof_pad 113 | flying_snakes 114 | dart_target 115 | warehouse_tray 116 | nintendo_wiiu 117 | chiffon_cake 118 | bath_ball 119 | manatee 120 | cloud 121 | marimba 122 | eagle 123 | ruler 124 | soymilk_machine 125 | sled 126 | seagull 127 | glider_flyingfish 128 | doublebus 129 | transport_helicopter 130 | window_screen 131 | truss_bridge 132 | wasp 133 | snowman 134 | poached_egg 135 | strawberry 136 | spinach 137 | earphone2 138 | downy_pitch 139 | taj_mahal 140 | rocking_chair 141 | cablestayed_bridge 142 | sealion 143 | banana_boat 144 | pheasant 145 | stone_lion 146 | electronic_stove 147 | fox 148 | iguana 149 | rugby_ball 150 | hang_glider 151 | 
water_buffalo 152 | lotus 153 | paper_plane 154 | missile 155 | flamingo 156 | american_chamelon 157 | kart 158 | chinese_knot 159 | cabbage_butterfly 160 | key 161 | church 162 | tiltrotor 163 | helicopter 164 | french_fries 165 | water_heater 166 | snow_leopard 167 | goblet 168 | fan 169 | snowplow 170 | leafhopper 171 | pspgo 172 | black_bear 173 | quail 174 | condor 175 | chandelier 176 | hair_razor 177 | white_wolf 178 | toaster 179 | pidan 180 | pyramid 181 | chicken_leg 182 | letter_opener 183 | apple_icon 184 | porcupine 185 | chicken 186 | stingray 187 | warplane 188 | windmill 189 | bamboo_slip 190 | wig 191 | flying_geckos 192 | stonechat 193 | haddock 194 | australian_terrier 195 | hover_board 196 | siamang 197 | canton_tower 198 | santa_sledge 199 | arch_bridge 200 | curlew 201 | sushi 202 | beet_root 203 | accordion 204 | leaf_egg 205 | stealth_aircraft 206 | stork 207 | bucket 208 | hawk 209 | chess_queen 210 | ocarina 211 | knife 212 | whippet 213 | cantilever_bridge 214 | may_bug 215 | wagtail 216 | leather_shoes 217 | wheelchair 218 | shumai 219 | speedboat 220 | vacuum_cup 221 | chess_knight 222 | pumpkin_pie 223 | wooden_spoon 224 | bamboo_dragonfly 225 | ganeva_chair 226 | soap 227 | clearwing_flyingfish 228 | pencil_sharpener1 229 | cricket 230 | photocopier 231 | nintendo_sp 232 | samarra_mosque 233 | clam 234 | charge_battery 235 | flying_frog 236 | ferrari911 237 | polo_shirt 238 | echidna 239 | coin 240 | tower_pisa 241 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/data/splits/fss/val.txt: -------------------------------------------------------------------------------- 1 | handcuff 2 | mortar 3 | matchstick 4 | wine_bottle 5 | dowitcher 6 | triumphal_arch 7 | gyromitra 8 | hatchet 9 | airliner 10 | broccoli 11 | olive 12 | pubg_lvl3backpack 13 | calculator 14 | toucan 15 | shovel 16 | sewing_machine 17 | icecream 18 | woodpecker 19 | pig 20 | relay_stick 21 | mcdonald_sign 22 | cpu 23 | peanut 24 | pumpkin 25 | sturgeon 26 | hammer 27 | hami_melon 28 | squirrel_monkey 29 | shuriken 30 | power_drill 31 | pingpong_ball 32 | crocodile 33 | carambola 34 | monarch_butterfly 35 | drum 36 | water_tower 37 | panda 38 | toilet_brush 39 | pay_phone 40 | yonex_icon 41 | cricketball 42 | revolver 43 | chimpanzee 44 | crab 45 | corn 46 | baseball 47 | rabbit 48 | croquet_ball 49 | artichoke 50 | abacus 51 | harp 52 | bell 53 | gas_tank 54 | scissors 55 | vase 56 | upright_piano 57 | typewriter 58 | bittern 59 | impala 60 | tray 61 | fire_hydrant 62 | beer_bottle 63 | sock 64 | soup_bowl 65 | spider 66 | cherry 67 | macaw 68 | toilet_seat 69 | fire_balloon 70 | french_ball 71 | fox_squirrel 72 | volleyball 73 | cornmeal 74 | folding_chair 75 | pubg_airdrop 76 | beagle 77 | skateboard 78 | narcissus 79 | whiptail 80 | cup 81 | arabian_camel 82 | badger 83 | stopwatch 84 | ab_wheel 85 | ox 86 | lettuce 87 | monocycle 88 | redshank 89 | vulture 90 | whistle 91 | smoothing_iron 92 | mashed_potato 93 | conveyor 94 | yoga_pad 95 | tow_truck 96 | siamese_cat 97 | cigar 98 | white_stork 99 | sniper_rifle 100 | stretcher 101 | tulip 102 | handkerchief 103 | basset 104 | iceberg 105 | gibbon 106 | lacewing 107 | thrush 108 | cheetah 109 | bighorn_sheep 110 | espresso_maker 111 | pretzel 112 | english_setter 113 | sandbar 114 | cheese 115 | daisy 116 | arctic_fox 117 | briard 118 | colubus 119 | balance_beam 120 | coffeepot 121 | soap_dispenser 122 | yawl 123 | consomme 124 | parking_meter 125 | cactus 126 | turnstile 127 | 
taro 128 | fire_screen 129 | digital_clock 130 | rose 131 | pomegranate 132 | bee_eater 133 | schooner 134 | ski_mask 135 | jay_bird 136 | plaice 137 | red_fox 138 | syringe 139 | camomile 140 | pickelhaube 141 | blenheim_spaniel 142 | pear 143 | parachute 144 | common_newt 145 | bowtie 146 | cigarette 147 | oscilloscope 148 | laptop 149 | african_crocodile 150 | apron 151 | coconut 152 | sandal 153 | kwanyin 154 | lion 155 | eel 156 | balloon 157 | crepe 158 | armadillo 159 | kazoo 160 | lemon 161 | spider_monkey 162 | tape_player 163 | ipod 164 | bee 165 | sea_cucumber 166 | suitcase 167 | television 168 | pillow 169 | banjo 170 | rock_snake 171 | partridge 172 | platypus 173 | lycaenid_butterfly 174 | pinecone 175 | conversion_plug 176 | wolf 177 | frying_pan 178 | timber_wolf 179 | bluetick 180 | crayon 181 | giant_schnauzer 182 | orang 183 | scarerow 184 | kobe_logo 185 | loguat 186 | saxophone 187 | ceiling_fan 188 | cardoon 189 | equestrian_helmet 190 | louvre_pyramid 191 | hotdog 192 | ironing_board 193 | razor 194 | nagoya_castle 195 | loggerhead_turtle 196 | lipstick 197 | cradle 198 | strongbox 199 | raven 200 | kit_fox 201 | albatross 202 | flat-coated_retriever 203 | beer_glass 204 | ice_lolly 205 | sungnyemun 206 | totem_pole 207 | vacuum 208 | bolete 209 | mango 210 | ginger 211 | weasel 212 | cabbage 213 | refrigerator 214 | school_bus 215 | hippo 216 | tiger_cat 217 | saltshaker 218 | piano_keyboard 219 | windsor_tie 220 | sea_urchin 221 | microsd 222 | barbell 223 | swim_ring 224 | bulbul_bird 225 | water_ouzel 226 | ac_ground 227 | sweatshirt 228 | umbrella 229 | hair_drier 230 | hammerhead_shark 231 | tomato 232 | projector 233 | cushion 234 | dishwasher 235 | three-toed_sloth 236 | tiger_shark 237 | har_gow 238 | baby 239 | thor's_hammer 240 | nike_logo 241 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/model/base/conv4d.py: -------------------------------------------------------------------------------- 1 | r""" Implementation of center-pivot 4D convolution """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CenterPivotConv4d(nn.Module): 8 | r""" CenterPivot 4D conv""" 9 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True): 10 | super(CenterPivotConv4d, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2], 13 | bias=bias, padding=padding[:2]) 14 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:], 15 | bias=bias, padding=padding[2:]) 16 | 17 | self.stride34 = stride[2:] 18 | self.kernel_size = kernel_size 19 | self.stride = stride 20 | self.padding = padding 21 | self.idx_initialized = False 22 | 23 | def prune(self, ct): 24 | bsz, ch, ha, wa, hb, wb = ct.size() 25 | if not self.idx_initialized: 26 | idxh = torch.arange(start=0, end=hb, step=self.stride[2:][0], device=ct.device) 27 | idxw = torch.arange(start=0, end=wb, step=self.stride[2:][1], device=ct.device) 28 | self.len_h = len(idxh) 29 | self.len_w = len(idxw) 30 | self.idx = (idxw.repeat(self.len_h, 1) + idxh.repeat(self.len_w, 1).t() * wb).view(-1) 31 | self.idx_initialized = True 32 | ct_pruned = ct.view(bsz, ch, ha, wa, -1).index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w) 33 | 34 | return ct_pruned 35 | 36 | def forward(self, x): 37 | if self.stride[2:][-1] > 1: 38 | out1 = self.prune(x) 39 | else: 40 | out1 = x 41 | bsz, inch, ha, wa, hb, wb = out1.size() 42 | out1 = 
out1.permute(0, 4, 5, 1, 2, 3).contiguous().view(-1, inch, ha, wa) 43 | out1 = self.conv1(out1) 44 | outch, o_ha, o_wa = out1.size(-3), out1.size(-2), out1.size(-1) 45 | out1 = out1.view(bsz, hb, wb, outch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous() 46 | 47 | bsz, inch, ha, wa, hb, wb = x.size() 48 | out2 = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, inch, hb, wb) 49 | out2 = self.conv2(out2) 50 | outch, o_hb, o_wb = out2.size(-3), out2.size(-2), out2.size(-1) 51 | out2 = out2.view(bsz, ha, wa, outch, o_hb, o_wb).permute(0, 3, 1, 2, 4, 5).contiguous() 52 | 53 | if out1.size()[-2:] != out2.size()[-2:] and self.padding[-2:] == (0, 0): 54 | out1 = out1.view(bsz, outch, o_ha, o_wa, -1).sum(dim=-1) 55 | out2 = out2.squeeze() 56 | 57 | y = out1 + out2 58 | return y 59 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/model/base/correlation.py: -------------------------------------------------------------------------------- 1 | r""" Provides functions that builds/manipulates correlation tensors """ 2 | import torch 3 | 4 | 5 | class Correlation: 6 | 7 | @classmethod 8 | def multilayer_correlation(cls, query_feats, support_feats, stack_ids): 9 | eps = 1e-5 10 | 11 | corrs = [] 12 | for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)): 13 | bsz, ch, hb, wb = support_feat.size() 14 | support_feat = support_feat.view(bsz, ch, -1) 15 | support_feat = support_feat / (support_feat.norm(dim=1, p=2, keepdim=True) + eps) 16 | 17 | bsz, ch, ha, wa = query_feat.size() 18 | query_feat = query_feat.view(bsz, ch, -1) 19 | query_feat = query_feat / (query_feat.norm(dim=1, p=2, keepdim=True) + eps) 20 | 21 | corr = torch.bmm(query_feat.transpose(1, 2), support_feat).view(bsz, ha, wa, hb, wb) 22 | corr = corr.clamp(min=0) 23 | corrs.append(corr) 24 | 25 | corr_l4 = torch.stack(corrs[-stack_ids[0]:]).transpose(0, 1).contiguous() 26 | corr_l3 = torch.stack(corrs[-stack_ids[1]:-stack_ids[0]]).transpose(0, 1).contiguous() 27 | corr_l2 = torch.stack(corrs[-stack_ids[2]:-stack_ids[1]]).transpose(0, 1).contiguous() 28 | 29 | return [corr_l4, corr_l3, corr_l2] 30 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/model/base/feature.py: -------------------------------------------------------------------------------- 1 | r""" Extracts intermediate features from given backbone network & layer ids """ 2 | 3 | 4 | def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None): 5 | r""" Extract intermediate features from VGG """ 6 | feats = [] 7 | feat = img 8 | for lid, module in enumerate(backbone.features): 9 | feat = module(feat) 10 | if lid in feat_ids: 11 | feats.append(feat.clone()) 12 | return feats 13 | 14 | 15 | def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids): 16 | r""" Extract intermediate features from ResNet""" 17 | feats = [] 18 | 19 | # Layer 0 20 | feat = backbone.conv1.forward(img) 21 | feat = backbone.bn1.forward(feat) 22 | feat = backbone.relu.forward(feat) 23 | feat = backbone.maxpool.forward(feat) 24 | 25 | # Layer 1-4 26 | for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)): 27 | res = feat 28 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat) 29 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat) 30 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 31 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat) 32 | feat = 
backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat) 33 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 34 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat) 35 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat) 36 | 37 | if bid == 0: 38 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res) 39 | 40 | feat += res 41 | 42 | if hid + 1 in feat_ids: 43 | feats.append(feat.clone()) 44 | 45 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 46 | 47 | return feats -------------------------------------------------------------------------------- /model/lseg/fewshot_data/model/hsnet.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze Network """ 2 | from functools import reduce 3 | from operator import add 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision.models import resnet 9 | from torchvision.models import vgg 10 | 11 | from fewshot_data.model.base.feature import extract_feat_vgg, extract_feat_res 12 | from fewshot_data.model.base.correlation import Correlation 13 | from fewshot_data.model.learner import HPNLearner 14 | 15 | 16 | class HypercorrSqueezeNetwork(nn.Module): 17 | def __init__(self, backbone, use_original_imgsize): 18 | super(HypercorrSqueezeNetwork, self).__init__() 19 | 20 | # 1. Backbone network initialization 21 | self.backbone_type = backbone 22 | self.use_original_imgsize = use_original_imgsize 23 | if backbone == 'vgg16': 24 | self.backbone = vgg.vgg16(pretrained=True) 25 | self.feat_ids = [17, 19, 21, 24, 26, 28, 30] 26 | self.extract_feats = extract_feat_vgg 27 | nbottlenecks = [2, 2, 3, 3, 3, 1] 28 | elif backbone == 'resnet50': 29 | self.backbone = resnet.resnet50(pretrained=True) 30 | self.feat_ids = list(range(4, 17)) 31 | self.extract_feats = extract_feat_res 32 | nbottlenecks = [3, 4, 6, 3] 33 | elif backbone == 'resnet101': 34 | self.backbone = resnet.resnet101(pretrained=True) 35 | self.feat_ids = list(range(4, 34)) 36 | self.extract_feats = extract_feat_res 37 | nbottlenecks = [3, 4, 23, 3] 38 | else: 39 | raise Exception('Unavailable backbone: %s' % backbone) 40 | 41 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks))) 42 | self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)]) 43 | self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3] 44 | self.backbone.eval() 45 | self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:]))) 46 | self.cross_entropy_loss = nn.CrossEntropyLoss() 47 | 48 | def forward(self, query_img, support_img, support_mask): 49 | with torch.no_grad(): 50 | query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 51 | support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 52 | support_feats = self.mask_feature(support_feats, support_mask.clone()) 53 | corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids) 54 | 55 | logit_mask = self.hpn_learner(corr) 56 | if not self.use_original_imgsize: 57 | logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True) 58 | 59 | return logit_mask 60 | 61 | def mask_feature(self, features, support_mask): 62 | for idx, feature in enumerate(features): 63 | mask = F.interpolate(support_mask.unsqueeze(1).float(), 
feature.size()[2:], mode='bilinear', align_corners=True) 64 | features[idx] = features[idx] * mask 65 | return features 66 | 67 | def predict_mask_nshot(self, batch, nshot): 68 | 69 | # Perform multiple prediction given (nshot) number of different support sets 70 | logit_mask_agg = 0 71 | for s_idx in range(nshot): 72 | logit_mask = self(batch['query_img'], batch['support_imgs'][:, s_idx], batch['support_masks'][:, s_idx]) 73 | 74 | if self.use_original_imgsize: 75 | org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()]) 76 | logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True) 77 | 78 | logit_mask_agg += logit_mask.argmax(dim=1).clone() 79 | if nshot == 1: return logit_mask_agg 80 | 81 | # Average & quantize predictions given threshold (=0.5) 82 | bsz = logit_mask_agg.size(0) 83 | max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0] 84 | max_vote = torch.stack([max_vote, torch.ones_like(max_vote).long()]) 85 | max_vote = max_vote.max(dim=0)[0].view(bsz, 1, 1) 86 | pred_mask = logit_mask_agg.float() / max_vote 87 | pred_mask[pred_mask < 0.5] = 0 88 | pred_mask[pred_mask >= 0.5] = 1 89 | 90 | return pred_mask 91 | 92 | def compute_objective(self, logit_mask, gt_mask): 93 | bsz = logit_mask.size(0) 94 | logit_mask = logit_mask.view(bsz, 2, -1) 95 | gt_mask = gt_mask.view(bsz, -1).long() 96 | 97 | return self.cross_entropy_loss(logit_mask, gt_mask) 98 | 99 | def train_mode(self): 100 | self.train() 101 | self.backbone.eval() # to prevent BN from learning data statistics with exponential averaging 102 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/model/learner.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from fewshot_data.model.base.conv4d import CenterPivotConv4d as Conv4d 6 | 7 | 8 | class HPNLearner(nn.Module): 9 | def __init__(self, inch): 10 | super(HPNLearner, self).__init__() 11 | 12 | def make_building_block(in_channel, out_channels, kernel_sizes, spt_strides, group=4): 13 | assert len(out_channels) == len(kernel_sizes) == len(spt_strides) 14 | 15 | building_block_layers = [] 16 | for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)): 17 | inch = in_channel if idx == 0 else out_channels[idx - 1] 18 | ksz4d = (ksz,) * 4 19 | str4d = (1, 1) + (stride,) * 2 20 | pad4d = (ksz // 2,) * 4 21 | 22 | building_block_layers.append(Conv4d(inch, outch, ksz4d, str4d, pad4d)) 23 | building_block_layers.append(nn.GroupNorm(group, outch)) 24 | building_block_layers.append(nn.ReLU(inplace=True)) 25 | 26 | return nn.Sequential(*building_block_layers) 27 | 28 | outch1, outch2, outch3 = 16, 64, 128 29 | 30 | # Squeezing building blocks 31 | self.encoder_layer4 = make_building_block(inch[0], [outch1, outch2, outch3], [3, 3, 3], [2, 2, 2]) 32 | self.encoder_layer3 = make_building_block(inch[1], [outch1, outch2, outch3], [5, 3, 3], [4, 2, 2]) 33 | self.encoder_layer2 = make_building_block(inch[2], [outch1, outch2, outch3], [5, 5, 3], [4, 4, 2]) 34 | 35 | # Mixing building blocks 36 | self.encoder_layer4to3 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) 37 | self.encoder_layer3to2 = make_building_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) 38 | 39 | # Decoder layers 40 | self.decoder1 = nn.Sequential(nn.Conv2d(outch3, outch3, (3, 3), padding=(1, 1), bias=True), 41 | 
nn.ReLU(), 42 | nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True), 43 | nn.ReLU()) 44 | 45 | self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True), 46 | nn.ReLU(), 47 | nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True)) 48 | 49 | def interpolate_support_dims(self, hypercorr, spatial_size=None): 50 | bsz, ch, ha, wa, hb, wb = hypercorr.size() 51 | hypercorr = hypercorr.permute(0, 4, 5, 1, 2, 3).contiguous().view(bsz * hb * wb, ch, ha, wa) 52 | hypercorr = F.interpolate(hypercorr, spatial_size, mode='bilinear', align_corners=True) 53 | o_hb, o_wb = spatial_size 54 | hypercorr = hypercorr.view(bsz, hb, wb, ch, o_hb, o_wb).permute(0, 3, 4, 5, 1, 2).contiguous() 55 | return hypercorr 56 | 57 | def forward(self, hypercorr_pyramid): 58 | 59 | # Encode hypercorrelations from each layer (Squeezing building blocks) 60 | hypercorr_sqz4 = self.encoder_layer4(hypercorr_pyramid[0]) 61 | hypercorr_sqz3 = self.encoder_layer3(hypercorr_pyramid[1]) 62 | hypercorr_sqz2 = self.encoder_layer2(hypercorr_pyramid[2]) 63 | 64 | # Propagate encoded 4D-tensor (Mixing building blocks) 65 | hypercorr_sqz4 = self.interpolate_support_dims(hypercorr_sqz4, hypercorr_sqz3.size()[-4:-2]) 66 | hypercorr_mix43 = hypercorr_sqz4 + hypercorr_sqz3 67 | hypercorr_mix43 = self.encoder_layer4to3(hypercorr_mix43) 68 | 69 | hypercorr_mix43 = self.interpolate_support_dims(hypercorr_mix43, hypercorr_sqz2.size()[-4:-2]) 70 | hypercorr_mix432 = hypercorr_mix43 + hypercorr_sqz2 71 | hypercorr_mix432 = self.encoder_layer3to2(hypercorr_mix432) 72 | 73 | bsz, ch, ha, wa, hb, wb = hypercorr_mix432.size() 74 | hypercorr_encoded = hypercorr_mix432.view(bsz, ch, ha, wa, -1).mean(dim=-1) 75 | 76 | # Decode the encoded 4D-tensor 77 | hypercorr_decoded = self.decoder1(hypercorr_encoded) 78 | upsample_size = (hypercorr_decoded.size(-1) * 2,) * 2 79 | hypercorr_decoded = F.interpolate(hypercorr_decoded, upsample_size, mode='bilinear', align_corners=True) 80 | logit_mask = self.decoder2(hypercorr_decoded) 81 | 82 | return logit_mask 83 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/test.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze testing code """ 2 | import argparse 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import torch 7 | 8 | from fewshot_data.model.hsnet import HypercorrSqueezeNetwork 9 | from fewshot_data.common.logger import Logger, AverageMeter 10 | from fewshot_data.common.vis import Visualizer 11 | from fewshot_data.common.evaluation import Evaluator 12 | from fewshot_data.common import utils 13 | from fewshot_data.data.dataset import FSSDataset 14 | 15 | 16 | def test(model, dataloader, nshot): 17 | r""" Test HSNet """ 18 | 19 | # Freeze randomness during testing for reproducibility 20 | utils.fix_randseed(0) 21 | average_meter = AverageMeter(dataloader.dataset) 22 | 23 | for idx, batch in enumerate(dataloader): 24 | 25 | # 1. Hypercorrelation Squeeze Networks forward pass 26 | batch = utils.to_cuda(batch) 27 | pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot) 28 | 29 | assert pred_mask.size() == batch['query_mask'].size() 30 | 31 | # 2. 
Evaluate prediction 32 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch) 33 | average_meter.update(area_inter, area_union, batch['class_id'], loss=None) 34 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1) 35 | 36 | # Visualize predictions 37 | if Visualizer.visualize: 38 | Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'], 39 | batch['query_img'], batch['query_mask'], 40 | pred_mask, batch['class_id'], idx, 41 | area_inter[1].float() / area_union[1].float()) 42 | # Write evaluation results 43 | average_meter.write_result('Test', 0) 44 | miou, fb_iou = average_meter.compute_iou() 45 | 46 | return miou, fb_iou 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | # Arguments parsing 52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation') 53 | parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN') 54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 55 | parser.add_argument('--logpath', type=str, default='') 56 | parser.add_argument('--bsz', type=int, default=1) 57 | parser.add_argument('--nworker', type=int, default=0) 58 | parser.add_argument('--load', type=str, default='') 59 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 60 | parser.add_argument('--nshot', type=int, default=1) 61 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101']) 62 | parser.add_argument('--visualize', action='store_true') 63 | parser.add_argument('--use_original_imgsize', action='store_true') 64 | args = parser.parse_args() 65 | Logger.initialize(args, training=False) 66 | 67 | # Model initialization 68 | model = HypercorrSqueezeNetwork(args.backbone, args.use_original_imgsize) 69 | model.eval() 70 | Logger.log_params(model) 71 | 72 | # Device setup 73 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 74 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 75 | model = nn.DataParallel(model) 76 | model.to(device) 77 | 78 | # Load trained model 79 | if args.load == '': raise Exception('Pretrained model not specified.') 80 | model.load_state_dict(torch.load(args.load)) 81 | 82 | # Helper classes (for testing) initialization 83 | Evaluator.initialize() 84 | Visualizer.initialize(args.visualize) 85 | 86 | # Dataset initialization 87 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize) 88 | dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot) 89 | 90 | # Test HSNet 91 | with torch.no_grad(): 92 | test_miou, test_fb_iou = test(model, dataloader_test, args.nshot) 93 | Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item())) 94 | Logger.info('==================== Finished Testing ====================') 95 | -------------------------------------------------------------------------------- /model/lseg/fewshot_data/train.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze training (validation) code """ 2 | import argparse 3 | 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torch 7 | 8 | from fewshot_data.model.hsnet import HypercorrSqueezeNetwork 9 | from fewshot_data.common.logger import Logger, AverageMeter 10 | from 
fewshot_data.common.evaluation import Evaluator 11 | from fewshot_data.common import utils 12 | from fewshot_data.data.dataset import FSSDataset 13 | 14 | 15 | def train(epoch, model, dataloader, optimizer, training): 16 | r""" Train HSNet """ 17 | 18 | # Force randomness during training / freeze randomness during testing 19 | utils.fix_randseed(None) if training else utils.fix_randseed(0) 20 | model.module.train_mode() if training else model.module.eval() 21 | average_meter = AverageMeter(dataloader.dataset) 22 | 23 | for idx, batch in enumerate(dataloader): 24 | # 1. Hypercorrelation Squeeze Networks forward pass 25 | batch = utils.to_cuda(batch) 26 | logit_mask = model(batch['query_img'], batch['support_imgs'].squeeze(1), batch['support_masks'].squeeze(1)) 27 | pred_mask = logit_mask.argmax(dim=1) 28 | 29 | # 2. Compute loss & update model parameters 30 | loss = model.module.compute_objective(logit_mask, batch['query_mask']) 31 | if training: 32 | optimizer.zero_grad() 33 | loss.backward() 34 | optimizer.step() 35 | 36 | # 3. Evaluate prediction 37 | area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch) 38 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone()) 39 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50) 40 | 41 | # Write evaluation results 42 | average_meter.write_result('Training' if training else 'Validation', epoch) 43 | avg_loss = utils.mean(average_meter.loss_buf) 44 | miou, fb_iou = average_meter.compute_iou() 45 | 46 | return avg_loss, miou, fb_iou 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | # Arguments parsing 52 | parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze Pytorch Implementation') 53 | parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN') 54 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 55 | parser.add_argument('--logpath', type=str, default='') 56 | parser.add_argument('--bsz', type=int, default=20) 57 | parser.add_argument('--lr', type=float, default=1e-3) 58 | parser.add_argument('--niter', type=int, default=2000) 59 | parser.add_argument('--nworker', type=int, default=8) 60 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 61 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101']) 62 | args = parser.parse_args() 63 | Logger.initialize(args, training=True) 64 | 65 | # Model initialization 66 | model = HypercorrSqueezeNetwork(args.backbone, False) 67 | Logger.log_params(model) 68 | 69 | # Device setup 70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 71 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 72 | model = nn.DataParallel(model) 73 | model.to(device) 74 | 75 | # Helper classes (for training) initialization 76 | optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}]) 77 | Evaluator.initialize() 78 | 79 | # Dataset initialization 80 | FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False) 81 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn') 82 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val') 83 | 84 | # Train HSNet 85 | best_val_miou = float('-inf') 86 | best_val_loss = float('inf') 87 | for epoch in range(args.niter): 88 | 89 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, 
dataloader_trn, optimizer, training=True) 90 | with torch.no_grad(): 91 | val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False) 92 | 93 | # Save the best model 94 | if val_miou > best_val_miou: 95 | best_val_miou = val_miou 96 | Logger.save_model_miou(model, epoch, val_miou) 97 | 98 | Logger.tbd_writer.add_scalars('fewshot_data/data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch) 99 | Logger.tbd_writer.add_scalars('fewshot_data/data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch) 100 | Logger.tbd_writer.add_scalars('fewshot_data/data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch) 101 | Logger.tbd_writer.flush() 102 | Logger.tbd_writer.close() 103 | Logger.info('==================== Finished Training ====================') 104 | -------------------------------------------------------------------------------- /model/lseg/inputs/cat1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/lseg/inputs/cat1.jpeg -------------------------------------------------------------------------------- /model/lseg/label_files/fewshot_coco.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | trafficlight 11 | firehydrant 12 | stopsign 13 | parkingmeter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sportsball 34 | kite 35 | baseballbat 36 | baseballglove 37 | skateboard 38 | surfboard 39 | tennisracket 40 | bottle 41 | wineglass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hotdog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cellphone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddybear 79 | hairdrier 80 | toothbrush -------------------------------------------------------------------------------- /model/lseg/label_files/fewshot_pascal.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- /model/lseg/lseg/lseg_module_zs.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | from argparse import ArgumentParser 6 | import pytorch_lightning as pl 7 | from .lsegmentation_module_zs import LSegmentationModuleZS 8 | from .models.lseg_net_zs import LSegNetZS, LSegRNNetZS 9 | from encoding.models.sseg.base import up_kwargs 10 | import os 11 | import clip 12 | import numpy as np 13 | from scipy import signal 14 | import glob 15 | from PIL import Image 16 | import matplotlib.pyplot as plt 17 | import pandas as pd 18 | 19 | 20 | 
class LSegModuleZS(LSegmentationModuleZS): 21 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs): 22 | super(LSegModuleZS, self).__init__( 23 | data_path, dataset, batch_size, base_lr, max_epochs, **kwargs 24 | ) 25 | label_list = self.get_labels(dataset) 26 | self.len_dataloader = len(label_list) 27 | 28 | # print(kwargs) 29 | if kwargs["use_pretrained"] in ['False', False]: 30 | use_pretrained = False 31 | elif kwargs["use_pretrained"] in ['True', True]: 32 | use_pretrained = True 33 | 34 | if kwargs["backbone"] in ["clip_resnet101"]: 35 | self.net = LSegRNNetZS( 36 | label_list=label_list, 37 | backbone=kwargs["backbone"], 38 | features=kwargs["num_features"], 39 | aux=kwargs["aux"], 40 | use_pretrained=use_pretrained, 41 | arch_option=kwargs["arch_option"], 42 | block_depth=kwargs["block_depth"], 43 | activation=kwargs["activation"], 44 | ) 45 | else: 46 | self.net = LSegNetZS( 47 | label_list=label_list, 48 | backbone=kwargs["backbone"], 49 | features=kwargs["num_features"], 50 | aux=kwargs["aux"], 51 | use_pretrained=use_pretrained, 52 | arch_option=kwargs["arch_option"], 53 | block_depth=kwargs["block_depth"], 54 | activation=kwargs["activation"], 55 | ) 56 | 57 | def get_labels(self, dataset): 58 | labels = [] 59 | path = 'label_files/fewshot_{}.txt'.format(dataset) 60 | assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path) 61 | f = open(path, 'r') 62 | lines = f.readlines() 63 | for line in lines: 64 | label = line.strip() 65 | labels.append(label) 66 | f.close() 67 | print(labels) 68 | return labels 69 | 70 | @staticmethod 71 | def add_model_specific_args(parent_parser): 72 | parser = LSegmentationModuleZS.add_model_specific_args(parent_parser) 73 | parser = ArgumentParser(parents=[parser]) 74 | 75 | parser.add_argument( 76 | "--backbone", 77 | type=str, 78 | default="vitb16_384", 79 | help="backbone network", 80 | ) 81 | 82 | parser.add_argument( 83 | "--num_features", 84 | type=int, 85 | default=256, 86 | help="number of featurs that go from encoder to decoder", 87 | ) 88 | 89 | parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate") 90 | 91 | parser.add_argument( 92 | "--finetune_weights", type=str, help="load weights to finetune from" 93 | ) 94 | 95 | parser.add_argument( 96 | "--no-scaleinv", 97 | default=True, 98 | action="store_false", 99 | help="turn off scaleinv layers", 100 | ) 101 | 102 | parser.add_argument( 103 | "--no-batchnorm", 104 | default=False, 105 | action="store_true", 106 | help="turn off batchnorm", 107 | ) 108 | 109 | parser.add_argument( 110 | "--widehead", default=False, action="store_true", help="wider output head" 111 | ) 112 | 113 | parser.add_argument( 114 | "--widehead_hr", 115 | default=False, 116 | action="store_true", 117 | help="wider output head", 118 | ) 119 | 120 | parser.add_argument( 121 | "--use_pretrained", 122 | type=str, 123 | default="True", 124 | help="whether use the default model to intialize the model", 125 | ) 126 | 127 | parser.add_argument( 128 | "--arch_option", 129 | type=int, 130 | default=0, 131 | help="which kind of architecture to be used", 132 | ) 133 | 134 | parser.add_argument( 135 | "--block_depth", 136 | type=int, 137 | default=0, 138 | help="how many blocks should be used", 139 | ) 140 | 141 | parser.add_argument( 142 | "--activation", 143 | choices=['relu', 'lrelu', 'tanh'], 144 | default="relu", 145 | help="use which activation to activate the block", 146 | ) 147 | 148 | return parser 149 | 
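Note: a minimal usage sketch (not part of lseg_module_zs.py), assuming the CLIP package and a CUDA device are available. It mirrors what LSegModuleZS.get_labels() does with the label_files/fewshot_*.txt lists shown above, and then encodes the class names with CLIP the way the zero-shot predictors later in this repo do; the prompt template ("a " + class name) and the normalization step are illustrative assumptions, not code taken from this module.

import os
import torch
import clip

def load_fewshot_labels(dataset, root="label_files"):
    # Same behaviour as LSegModuleZS.get_labels(): one class name per line.
    path = os.path.join(root, "fewshot_{}.txt".format(dataset))
    assert os.path.exists(path), "*** Error : {} not exist !!!".format(path)
    with open(path, "r") as f:
        return [line.strip() for line in f if line.strip()]

# Illustrative only: build CLIP text features for the label set,
# as the zero-shot segmentation head does internally.
labels = load_fewshot_labels("pascal")  # e.g. ['aeroplane', 'bicycle', ...]
model, _ = clip.load("ViT-B/32", device="cuda", jit=False)
tokens = clip.tokenize(["a " + c.replace("_", " ") for c in labels]).cuda()
with torch.no_grad():
    text_feat = model.encode_text(tokens)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)  # (num_classes, 512)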
-------------------------------------------------------------------------------- /model/lseg/modules/lseg_module_zs.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | from argparse import ArgumentParser 6 | import pytorch_lightning as pl 7 | from .lsegmentation_module_zs import LSegmentationModuleZS 8 | from .models.lseg_net_zs import LSegNetZS, LSegRNNetZS 9 | from encoding.models.sseg.base import up_kwargs 10 | import os 11 | import clip 12 | import numpy as np 13 | from scipy import signal 14 | import glob 15 | from PIL import Image 16 | import matplotlib.pyplot as plt 17 | import pandas as pd 18 | 19 | 20 | class LSegModuleZS(LSegmentationModuleZS): 21 | def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs): 22 | super(LSegModuleZS, self).__init__( 23 | data_path, dataset, batch_size, base_lr, max_epochs, **kwargs 24 | ) 25 | label_list = self.get_labels(dataset) 26 | self.len_dataloader = len(label_list) 27 | 28 | # print(kwargs) 29 | if kwargs["use_pretrained"] in ['False', False]: 30 | use_pretrained = False 31 | elif kwargs["use_pretrained"] in ['True', True]: 32 | use_pretrained = True 33 | 34 | if kwargs["backbone"] in ["clip_resnet101"]: 35 | self.net = LSegRNNetZS( 36 | label_list=label_list, 37 | backbone=kwargs["backbone"], 38 | features=kwargs["num_features"], 39 | aux=kwargs["aux"], 40 | use_pretrained=use_pretrained, 41 | arch_option=kwargs["arch_option"], 42 | block_depth=kwargs["block_depth"], 43 | activation=kwargs["activation"], 44 | ) 45 | else: 46 | self.net = LSegNetZS( 47 | label_list=label_list, 48 | backbone=kwargs["backbone"], 49 | features=kwargs["num_features"], 50 | aux=kwargs["aux"], 51 | use_pretrained=use_pretrained, 52 | arch_option=kwargs["arch_option"], 53 | block_depth=kwargs["block_depth"], 54 | activation=kwargs["activation"], 55 | ) 56 | 57 | def get_labels(self, dataset): 58 | labels = [] 59 | path = 'label_files/fewshot_{}.txt'.format(dataset) 60 | assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path) 61 | f = open(path, 'r') 62 | lines = f.readlines() 63 | for line in lines: 64 | label = line.strip() 65 | labels.append(label) 66 | f.close() 67 | print(labels) 68 | return labels 69 | 70 | @staticmethod 71 | def add_model_specific_args(parent_parser): 72 | parser = LSegmentationModuleZS.add_model_specific_args(parent_parser) 73 | parser = ArgumentParser(parents=[parser]) 74 | 75 | parser.add_argument( 76 | "--backbone", 77 | type=str, 78 | default="vitb16_384", 79 | help="backbone network", 80 | ) 81 | 82 | parser.add_argument( 83 | "--num_features", 84 | type=int, 85 | default=256, 86 | help="number of featurs that go from encoder to decoder", 87 | ) 88 | 89 | parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate") 90 | 91 | parser.add_argument( 92 | "--finetune_weights", type=str, help="load weights to finetune from" 93 | ) 94 | 95 | parser.add_argument( 96 | "--no-scaleinv", 97 | default=True, 98 | action="store_false", 99 | help="turn off scaleinv layers", 100 | ) 101 | 102 | parser.add_argument( 103 | "--no-batchnorm", 104 | default=False, 105 | action="store_true", 106 | help="turn off batchnorm", 107 | ) 108 | 109 | parser.add_argument( 110 | "--widehead", default=False, action="store_true", help="wider output head" 111 | ) 112 | 113 | parser.add_argument( 114 | "--widehead_hr", 115 | default=False, 116 | action="store_true", 117 | 
help="wider output head", 118 | ) 119 | 120 | parser.add_argument( 121 | "--use_pretrained", 122 | type=str, 123 | default="True", 124 | help="whether use the default model to intialize the model", 125 | ) 126 | 127 | parser.add_argument( 128 | "--arch_option", 129 | type=int, 130 | default=0, 131 | help="which kind of architecture to be used", 132 | ) 133 | 134 | parser.add_argument( 135 | "--block_depth", 136 | type=int, 137 | default=0, 138 | help="how many blocks should be used", 139 | ) 140 | 141 | parser.add_argument( 142 | "--activation", 143 | choices=['relu', 'lrelu', 'tanh'], 144 | default="relu", 145 | help="use which activation to activate the block", 146 | ) 147 | 148 | return parser 149 | -------------------------------------------------------------------------------- /model/lseg/prepare_ade20k.py: -------------------------------------------------------------------------------- 1 | # + 2 | # revised from https://github.com/zhanghang1989/PyTorch-Encoding/blob/331ecdd5306104614cb414b16fbcd9d1a8d40e1e/scripts/prepare_ade20k.py 3 | 4 | """Prepare ADE20K dataset""" 5 | import os 6 | import shutil 7 | import argparse 8 | import zipfile 9 | from encoding.utils import download, mkdir 10 | # - 11 | 12 | _TARGET_DIR = os.path.expanduser('../datasets/') 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Initialize ADE20K dataset.', 17 | epilog='Example: python prepare_ade20k.py', 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 20 | args = parser.parse_args() 21 | return args 22 | 23 | def download_ade(path, overwrite=False): 24 | _AUG_DOWNLOAD_URLS = [ 25 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'), 26 | ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),] 27 | download_dir = path 28 | mkdir(download_dir) 29 | for url, checksum in _AUG_DOWNLOAD_URLS: 30 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum) 31 | # extract 32 | with zipfile.ZipFile(filename,"r") as zip_ref: 33 | zip_ref.extractall(path=path) 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parse_args() 38 | mkdir(os.path.expanduser('../datasets/')) 39 | if args.download_dir is not None: 40 | if os.path.isdir(_TARGET_DIR): 41 | os.remove(_TARGET_DIR) 42 | # make symlink 43 | os.symlink(args.download_dir, _TARGET_DIR) 44 | else: 45 | download_ade(_TARGET_DIR, overwrite=False) 46 | -------------------------------------------------------------------------------- /model/lseg/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = lseg 3 | version = 1.0 4 | license = MIT License 5 | 6 | [options] 7 | packages = lseg 8 | 9 | -------------------------------------------------------------------------------- /model/lseg/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | 5 | -------------------------------------------------------------------------------- /model/lseg/train_lseg.py: -------------------------------------------------------------------------------- 1 | from modules.lseg_module import LSegModule 2 | from utils import do_training, get_default_argument_parser 3 | 4 | if __name__ == "__main__": 5 | parser = 
LSegModule.add_model_specific_args(get_default_argument_parser()) 6 | args = parser.parse_args() 7 | do_training(args, LSegModule) 8 | -------------------------------------------------------------------------------- /model/lseg_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | import cv2 4 | from torch.nn import functional as F 5 | from model.lseg.modules.lseg_module import LSegModule 6 | from model.lseg.additional_utils.models import LSeg_MultiEvalModule 7 | from torchvision import transforms 8 | 9 | 10 | class LSeg: 11 | embedding_dim = 512 12 | 13 | def __init__(self, weight_path=None): 14 | # set memory growth to avoid out of memory 15 | if weight_path is not None: 16 | module = LSegModule.load_from_checkpoint( 17 | checkpoint_path=weight_path, 18 | backbone="clip_vitl16_384", 19 | data_path=None, 20 | num_features=256, 21 | batch_size=1, 22 | base_lr=1e-3, 23 | max_epochs=100, 24 | augment=False, 25 | aux=True, 26 | aux_weight=0, 27 | ignore_index=255, 28 | dataset="ade20k", 29 | se_loss=False, 30 | se_weight=0, 31 | arch_option=0, 32 | block_depth=0, 33 | activation="lrelu", 34 | ) 35 | self.transform = transforms.Compose(module.val_transform.transforms) 36 | net = module.net.cuda() 37 | scales = [1.0] 38 | self.evaluator = LSeg_MultiEvalModule(module, scales=scales, flip=False).cuda().eval() 39 | self.text_model = module.net.clip_pretrained.to(torch.float32).cuda() 40 | else: 41 | self.text_model, _ = clip.load("ViT-B/32", device='cuda', jit=False) 42 | self.text_model = self.text_model.to(torch.float32).cuda() 43 | 44 | def set_predefined_cls(self, cls): 45 | self.classes = ".".join(cls) 46 | print(self.classes) 47 | 48 | def set_predefined_part(self, cls, parts): 49 | self.classes = ".".join([f"{cls}:{e}" for e in parts]) 50 | print(self.classes) 51 | 52 | def get_text(self, vocabulary, prefix_prompt="a "): 53 | vocabulary = vocabulary.split(".") 54 | texts = [prefix_prompt + x.lower().replace(":", " ").replace("_", " ") for x in vocabulary] 55 | return texts 56 | 57 | def extract_image_feature(self, img_dir, img_size=None, regional_pool=True): 58 | """Extract per-pixel LSeg features. 59 | Only receives image path as input. 
60 | """ 61 | 62 | # load RGB image 63 | image = cv2.imread(str(img_dir)) 64 | image = cv2.resize(image, (img_size[1], img_size[0])) 65 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 66 | 67 | # run LSeg 68 | x = self.transform(image).cuda() 69 | feat_2d = self.evaluator.compute_features(x.unsqueeze(0)) 70 | 71 | feat_2d = feat_2d[0].cpu() # [512, h, w] 72 | 73 | return feat_2d 74 | 75 | def extract_text_feature(self, labelset): 76 | # "ViT-B/32" # the model that LSeg uses 77 | if isinstance(labelset, str): 78 | lines = labelset.split(",") 79 | elif isinstance(labelset, list): 80 | lines = labelset 81 | else: 82 | raise NotImplementedError 83 | 84 | labels = [] 85 | for line in lines: 86 | label = line 87 | labels.append(label) 88 | text = clip.tokenize(labels) 89 | text = text.cuda() 90 | text_features = self.text_model.encode_text(text) 91 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 92 | 93 | return text_features 94 | -------------------------------------------------------------------------------- /model/openseg_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from tensorflow import io 4 | import tensorflow as tf2 5 | import tensorflow.compat.v1 as tf 6 | 7 | 8 | def read_bytes(path): 9 | """Read bytes for OpenSeg model running.""" 10 | 11 | with io.gfile.GFile(path, "rb") as f: 12 | file_bytes = f.read() 13 | return file_bytes 14 | 15 | 16 | class OpenSeg: 17 | embedding_dim = 768 18 | 19 | def __init__(self, weight_path, text_model_name, set_memory_growth=True): 20 | # set memory growth to avoid out of memory 21 | if weight_path is not None: 22 | print("Load Tensorflow OpenSeg model...") 23 | gpus = tf.config.experimental.list_physical_devices("GPU") 24 | for gpu in gpus: 25 | tf2.config.experimental.set_memory_growth(gpu, set_memory_growth) 26 | self.model = tf2.saved_model.load( 27 | weight_path, 28 | tags=[tf.saved_model.tag_constants.SERVING], 29 | ) 30 | self.text_emb = tf.zeros([1, 1, 768]) 31 | 32 | if text_model_name is not None: 33 | print("Loading CLIP {} model...".format(text_model_name)) 34 | self.text_model, _ = clip.load(text_model_name, device="cuda", jit=False) 35 | 36 | def set_predefined_cls(self, cls): 37 | self.classes = ".".join(cls) 38 | print(self.classes) 39 | 40 | def set_predefined_part(self, cls, parts): 41 | self.classes = ".".join([f"{cls}:{e}" for e in parts]) 42 | print(self.classes) 43 | 44 | def get_text(self, vocabulary, prefix_prompt="a "): 45 | vocabulary = vocabulary.split(".") 46 | texts = [prefix_prompt + x.lower().replace(":", " ").replace("_", " ") for x in vocabulary] 47 | return texts 48 | 49 | def extract_image_feature(self, img_dir, img_size=None, regional_pool=True): 50 | """Extract per-pixel OpenSeg features. 51 | Only receives image path as input. 
52 | """ 53 | 54 | # load RGB image 55 | np_image_string = read_bytes(img_dir) 56 | # run OpenSeg 57 | results = self.model.signatures["serving_default"]( 58 | inp_image_bytes=tf.convert_to_tensor(np_image_string), inp_text_emb=self.text_emb 59 | ) 60 | img_info = results["image_info"] 61 | crop_sz = [ 62 | int(img_info[0, 0] * img_info[2, 0]), 63 | int(img_info[0, 1] * img_info[2, 1]), 64 | ] 65 | if regional_pool: 66 | image_embedding_feat = results["ppixel_ave_feat"][:, : crop_sz[0], : crop_sz[1]] 67 | else: 68 | image_embedding_feat = results["image_embedding_feat"][:, : crop_sz[0], : crop_sz[1]] 69 | if img_size is not None: 70 | feat_2d = tf.cast( 71 | tf.image.resize_nearest_neighbor(image_embedding_feat, img_size, align_corners=True)[0], 72 | dtype=tf.float16, 73 | ).numpy() 74 | else: 75 | feat_2d = tf.cast(image_embedding_feat[[0]], dtype=tf.float16).numpy() 76 | 77 | feat_2d = torch.from_numpy(feat_2d).permute(2, 0, 1) 78 | 79 | return feat_2d 80 | 81 | def extract_text_feature(self, labelset): 82 | # "ViT-L/14@336px" # the big model that OpenSeg uses 83 | if isinstance(labelset, str): 84 | lines = labelset.split(",") 85 | elif isinstance(labelset, list): 86 | lines = labelset 87 | else: 88 | raise NotImplementedError 89 | 90 | labels = [] 91 | for line in lines: 92 | label = line 93 | labels.append(label) 94 | text = clip.tokenize(labels) 95 | text = text.cuda() 96 | text_features = self.text_model.encode_text(text) 97 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 98 | 99 | return text_features 100 | -------------------------------------------------------------------------------- /model/render_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | import skimage.transform as sktf 6 | from dataset.scannet.scannet_constants import SCANNET20_CLASS_LABELS, COCOMAP_CLASS_LABELS, COLORMAP 7 | 8 | 9 | def get_text_features(model_2d, dataset_name="scannet20"): 10 | if isinstance(dataset_name, list): 11 | labelset = dataset_name 12 | elif dataset_name == "scannet20": 13 | labelset = list(SCANNET20_CLASS_LABELS) 14 | elif dataset_name == "cocomap": 15 | labelset = list(COCOMAP_CLASS_LABELS) 16 | 17 | # add unlabeled label and palette 18 | labelset = ["other"] + labelset 19 | 20 | palette = torch.tensor(COLORMAP[:len(labelset)+1]).cuda().flatten() 21 | text_features = model_2d.extract_text_feature(labelset).float() 22 | 23 | return palette, text_features 24 | 25 | 26 | def render_palette(label, palette): 27 | shape = label.shape 28 | label = label.reshape(-1) 29 | new_3d = torch.zeros((label.shape[0], 3)).cuda() 30 | u_index = torch.unique(label) 31 | for index in u_index: 32 | new_3d[label == index] = torch.tensor( 33 | [ 34 | palette[index * 3] / 255.0, 35 | palette[index * 3 + 1] / 255.0, 36 | palette[index * 3 + 2] / 255.0, 37 | ] 38 | ).cuda() 39 | 40 | return new_3d.reshape(*shape, 3).permute(2, 0, 1) 41 | 42 | 43 | def get_mapped_label(config, image_path, label_mapping): 44 | label_path = str(image_path).replace("color", "label-filt").replace(".jpg", ".png") 45 | if not os.path.exists(label_path): 46 | return None 47 | 48 | label_img = np.array(imageio.imread(label_path)) 49 | label_img = sktf.resize(label_img, [config.eval.height, config.eval.width], order=0, preserve_range=True) 50 | mapped = np.copy(label_img) 51 | for k, v in label_mapping.items(): 52 | mapped[label_img == k] = v 53 | label_img = mapped.astype(np.uint8) 54 | 55 | return 
label_img 56 | -------------------------------------------------------------------------------- /model/samclip_predictor.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import clip 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | import time 7 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 8 | from segment_anything.automask import SamAutomaticMaskGenerator as MultiScaleMaskGenerator 9 | 10 | 11 | class SAMCLIP: 12 | embedding_dim = 768 13 | 14 | def __init__(self, sam_path, clip_model_name): 15 | if sam_path is not None: 16 | print("Load SAM model...") 17 | sam = sam_model_registry["vit_h"](checkpoint=sam_path) 18 | sam.cuda() 19 | self.sam = SamPredictor(sam) 20 | self.mask_generator = MultiScaleMaskGenerator( 21 | model=sam, 22 | points_per_side=32, 23 | pred_iou_thresh=0.7, 24 | box_nms_thresh=0.7, 25 | stability_score_thresh=0.85, 26 | # crop_n_layers=1, 27 | # crop_n_points_downscale_factor=1, 28 | min_mask_region_area=100, 29 | ) 30 | 31 | if clip_model_name is not None: 32 | print("Loading CLIP {} model...".format(clip_model_name)) 33 | self.clip_model, self.preprocess = clip.load(clip_model_name, device="cuda", jit=False) 34 | 35 | def set_predefined_cls(self, cls): 36 | self.classes = ".".join(cls) 37 | print(self.classes) 38 | 39 | def set_predefined_part(self, cls, parts): 40 | self.classes = ".".join([f"{cls}:{e}" for e in parts]) 41 | print(self.classes) 42 | 43 | def get_text(self, vocabulary, prefix_prompt="a "): 44 | vocabulary = vocabulary.split(".") 45 | texts = [prefix_prompt + x.lower().replace(":", " ").replace("_", " ") for x in vocabulary] 46 | return texts 47 | 48 | def extract_image_feature(self, img_dir, img_size=None): 49 | """Extract per-pixel OpenSeg features. 50 | Only receives image path as input. 
51 | """ 52 | image = cv2.imread(str(img_dir)) 53 | image = cv2.resize(image, (img_size[1], img_size[0])) 54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 55 | masks, masks_s, masks_m, masks_l = self.mask_generator.generate(image) 56 | 57 | sorted_masks = sorted(masks, key=lambda x: x["area"], reverse=True) 58 | pad_imgs = [] 59 | segs = [] 60 | scores = [] 61 | for mask in sorted_masks: 62 | bbox = mask["bbox"] 63 | seg_mask = mask["segmentation"] 64 | score = mask["stability_score"] * mask["predicted_iou"] 65 | x1, y1 = int(bbox[0]), int(bbox[1]) 66 | x2, y2 = int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]) 67 | # h_thresh = int(image.shape[0] * 0.1) 68 | # w_thresh = int(image.shape[1] * 0.1) 69 | # if x2 < w_thresh or x1 > image.shape[1] - w_thresh or y2 < h_thresh or y1 > image.shape[0] - h_thresh: 70 | # continue 71 | 72 | crop = (image * seg_mask[:, :, np.newaxis])[y1:y2, x1:x2] 73 | h, w, _ = crop.shape 74 | 75 | l = max(h, w) 76 | pad = np.zeros((l, l, 3), dtype=np.uint8) 77 | if h > w: 78 | pad[:, (h - w) // 2 : (h - w) // 2 + w, :] = crop 79 | else: 80 | pad[(w - h) // 2 : (w - h) // 2 + h, :, :] = crop 81 | pad_imgs.append(cv2.resize(pad, (336, 336))) 82 | segs.append(seg_mask) 83 | scores.append(score) 84 | 85 | if len(pad_imgs) == 0: 86 | print("Error: no mask detected!") 87 | return torch.zeros((768, image.shape[0], image.shape[1]), dtype=torch.half) 88 | 89 | pad_imgs = np.stack(pad_imgs, axis=0) # B, H, W, 3 90 | pad_imgs = torch.from_numpy(pad_imgs.astype("float32")).permute(0, 3, 1, 2) / 255.0 91 | pad_imgs = torchvision.transforms.Normalize( 92 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 93 | )(pad_imgs).cuda() 94 | 95 | crop_features = self.clip_model.encode_image(pad_imgs).cpu() 96 | features = torch.zeros((768, image.shape[0], image.shape[1]), dtype=torch.half) 97 | for idx, seg_mask in enumerate(segs): 98 | features[:, seg_mask] += crop_features[idx].unsqueeze(1) # * scores[idx] 99 | 100 | features = features / (features.norm(dim=0, keepdim=True) + 1e-8) 101 | 102 | return features 103 | 104 | def extract_text_feature(self, labelset): 105 | # "ViT-L/14@336px" # the big model that OpenSeg uses 106 | if isinstance(labelset, str): 107 | lines = labelset.split(",") 108 | elif isinstance(labelset, list): 109 | lines = labelset 110 | else: 111 | raise NotImplementedError 112 | 113 | labels = [] 114 | for line in lines: 115 | label = line 116 | labels.append(label) 117 | text = clip.tokenize(labels) 118 | text = text.cuda() 119 | text_features = self.clip_model.encode_text(text) 120 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 121 | 122 | return text_features 123 | -------------------------------------------------------------------------------- /model/vlpart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/model/vlpart/__init__.py -------------------------------------------------------------------------------- /model/vlpart/vlpart_fast_rcnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # VLPart: Going denser with open-vocabulary part segmentation 3 | # Written by Peize Sun and Shoufa Chen 4 | import logging 5 | from typing import Callable, Dict, List, Optional, Tuple, Union 6 | import math 7 | import copy 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from fvcore.nn import sigmoid_focal_loss_jit, giou_loss, smooth_l1_loss 14 | import fvcore.nn.weight_init as weight_init 15 | 16 | from detectron2.config import configurable 17 | from detectron2.data.detection_utils import get_fed_loss_cls_weights 18 | from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple 19 | from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss 20 | from detectron2.structures import Boxes, Instances, BitMasks, pairwise_iou, pairwise_ioa 21 | from detectron2.utils.events import get_event_storage 22 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers 23 | 24 | 25 | class TexeEmbedClassifier(nn.Module): 26 | def __init__( 27 | self, 28 | input_shape: ShapeSpec, 29 | zs_weight_dim: int = 1024, 30 | norm_weight: bool = True, 31 | norm_temperature: float = 50.0, 32 | ): 33 | super().__init__() 34 | if isinstance(input_shape, int): # some backward compatibility 35 | input_shape = ShapeSpec(channels=input_shape) 36 | input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1) 37 | self.norm_weight = norm_weight 38 | self.norm_temperature = norm_temperature 39 | 40 | self.linear = nn.Linear(input_size, zs_weight_dim) 41 | 42 | def forward(self, x, text_embed): 43 | 44 | x = self.linear(x) 45 | if self.norm_weight: 46 | x = self.norm_temperature * F.normalize(x, p=2, dim=1) 47 | x = torch.mm(x, text_embed) 48 | return x 49 | 50 | 51 | class VLMFastRCNNOutputLayers(nn.Module): 52 | def __init__( 53 | self, 54 | input_shape: ShapeSpec, 55 | box2box_transform, 56 | use_sigmoid_ce: bool = True, 57 | test_score_thresh: float = 0.0, 58 | test_nms_thresh: float = 0.5, 59 | test_topk_per_image: int = 100, 60 | ): 61 | super().__init__() 62 | if isinstance(input_shape, int): # some backward compatibility 63 | input_shape = ShapeSpec(channels=input_shape) 64 | 65 | self.box2box_transform = box2box_transform 66 | self.use_sigmoid_ce = use_sigmoid_ce 67 | self.test_score_thresh = test_score_thresh 68 | self.test_nms_thresh = test_nms_thresh 69 | self.test_topk_per_image = test_topk_per_image 70 | 71 | input_size = input_shape.channels * \ 72 | (input_shape.width or 1) * (input_shape.height or 1) 73 | 74 | # bbox_pred 75 | self.bbox_pred = nn.Sequential( 76 | nn.Linear(input_size, input_size), 77 | nn.ReLU(inplace=True), 78 | nn.Linear(input_size, 4) 79 | ) 80 | # cls_score 81 | self.cls_score = TexeEmbedClassifier(input_shape) 82 | 83 | def forward(self, x, text_embed): 84 | if x.dim() > 2: 85 | x = torch.flatten(x, start_dim=1) 86 | cls_scores = self.cls_score(x, text_embed) 87 | proposal_deltas = self.bbox_pred(x) 88 | 89 | return cls_scores, proposal_deltas 90 | 91 | def predict_boxes(self, predictions, proposals): 92 | if not len(proposals): 93 | return [] 94 | _, proposal_deltas = predictions 95 | num_prop_per_image = [len(p) for p in proposals] 96 | proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) 97 | 98 | predict_boxes = self.box2box_transform.apply_deltas( 99 | proposal_deltas, 100 | proposal_boxes, 101 | ) # Nx(KxB) 102 | return 
predict_boxes.split(num_prop_per_image) 103 | 104 | def predict_probs(self, predictions, proposals): 105 | cls_scores, _ = predictions 106 | num_inst_per_image = [len(p) for p in proposals] 107 | cls_scores = cls_scores.split(num_inst_per_image, dim=0) 108 | 109 | final_scores = [] 110 | for cls_score in cls_scores: 111 | final_score = cls_score.sigmoid() if self.use_sigmoid_ce else F.softmax(cls_score, dim=-1) 112 | final_scores.append(final_score) 113 | return final_scores 114 | 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow[and-cuda]==2.14.0 2 | timm==0.6.13 3 | scipy==1.11.4 4 | omegaconf 5 | imageio==2.31.4 6 | scikit-image==0.22.0 7 | opencv-python 8 | ninja 9 | viser==0.1.17 10 | pytorch-lightning==2.2.4 11 | git+https://github.com/zhanghang1989/PyTorch-Encoding/ 12 | git+https://github.com/openai/CLIP.git 13 | git+https://github.com/facebookresearch/detectron2.git 14 | submodules/rgbd-rasterization 15 | submodules/channel-rasterization 16 | submodules/simple-knn 17 | submodules/segment-anything -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | from .scene import Scene 2 | -------------------------------------------------------------------------------- /scene/blender_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from PIL import Image 5 | from pathlib import Path 6 | 7 | from utils.sh_utils import SH2RGB 8 | from utils.graphics_utils import BasicPointCloud, focal2fov, fov2focal 9 | from utils.dataset_utils import SceneInfo, CameraInfo, getNerfppNorm, storePly, fetchPly 10 | 11 | 12 | def readCamerasFromTransforms(path, transformsfile, white_background, extensions=[".png", ".jpg", ""]): 13 | cam_infos = [] 14 | first_img = None 15 | 16 | with open(os.path.join(path, transformsfile)) as json_file: 17 | contents = json.load(json_file) 18 | frames = contents["frames"] 19 | for idx, frame in enumerate(frames): 20 | for extension in extensions: 21 | cam_name = os.path.join(path, frame["file_path"] + extension) 22 | if os.path.exists(cam_name): 23 | break 24 | 25 | # NeRF 'transform_matrix' is a camera-to-world transform 26 | c2w = np.array(frame["transform_matrix"]) 27 | if np.isinf(c2w).any(): 28 | continue 29 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 30 | c2w[:3, 1:3] *= -1 31 | 32 | # get the world-to-camera transform and set R, T 33 | w2c = np.linalg.inv(c2w) 34 | R = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 35 | T = w2c[:3, 3] 36 | 37 | image_path = Path(cam_name) 38 | image_name = image_path.stem 39 | if first_img is None: 40 | first_img = np.array(Image.open(image_path).convert("RGBA")) 41 | width, height = first_img.shape[1], first_img.shape[0] 42 | 43 | if "fl_x" in frame: 44 | fovx = focal2fov(frame["fl_x"], width) 45 | fovy = focal2fov(frame["fl_y"], height) 46 | else: 47 | fovx = contents["camera_angle_x"] 48 | fovy = focal2fov(fov2focal(fovx, width), height) 49 | 50 | if "intrinsics" in frame: 51 | intrinsics = np.array(frame["intrinsics"]) 52 | else: 53 | intrinsics = np.zeros((4, 4), dtype=np.float32) 54 | intrinsics[0, 0] = fov2focal(fovx, width) 55 | intrinsics[1, 1] = fov2focal(fovy, height) 56 | 
intrinsics[2, 2] = 1 57 | intrinsics[3, 3] = 1 58 | intrinsics[0, 2] = width / 2 59 | intrinsics[1, 2] = height / 2 60 | 61 | cam_infos.append( 62 | CameraInfo( 63 | uid=idx, 64 | R=R, 65 | T=T, 66 | FovY=fovy, 67 | FovX=fovx, 68 | image_path=image_path, 69 | image_name=image_name, 70 | width=width, 71 | height=height, 72 | intrinsics=intrinsics, 73 | ) 74 | ) 75 | 76 | return cam_infos 77 | 78 | 79 | def readBlenderInfo(path, white_background, eval, extensions=[".png", ".jpg", ""]): 80 | print("Reading Training Transforms") 81 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extensions) 82 | print("Reading Test Transforms") 83 | try: 84 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extensions) 85 | except: 86 | print("Reading Test Transforms Error! Skip it.") 87 | test_cam_infos = [] 88 | 89 | if not eval: 90 | train_cam_infos.extend(test_cam_infos) 91 | test_cam_infos = [] 92 | 93 | nerf_normalization = getNerfppNorm(train_cam_infos) 94 | 95 | ply_path = os.path.join(path, "points3d.ply") 96 | if not os.path.exists(ply_path): 97 | # Since this data set has no colmap data, we start with random points 98 | num_pts = 100_000 99 | print(f"Generating random point cloud ({num_pts})...") 100 | 101 | # We create random points inside the bounds of the synthetic Blender scenes 102 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 103 | shs = np.random.random((num_pts, 3)) / 255.0 104 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 105 | 106 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 107 | try: 108 | pcd = fetchPly(ply_path) 109 | except: 110 | pcd = None 111 | 112 | scene_info = SceneInfo( 113 | point_cloud=pcd, 114 | train_cameras=train_cam_infos, 115 | test_cameras=test_cam_infos, 116 | nerf_normalization=nerf_normalization, 117 | ply_path=ply_path, 118 | ) 119 | return scene_info 120 | -------------------------------------------------------------------------------- /scene/camera.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | 18 | class MiniCam: 19 | def __init__( 20 | self, 21 | width, 22 | height, 23 | fovy, 24 | fovx, 25 | znear, 26 | zfar, 27 | world_view_transform, 28 | full_proj_transform, 29 | ): 30 | self.image_width = width 31 | self.image_height = height 32 | self.FoVy = fovy 33 | self.FoVx = fovx 34 | self.znear = znear 35 | self.zfar = zfar 36 | self.world_view_transform = world_view_transform 37 | self.full_proj_transform = full_proj_transform 38 | view_inv = torch.inverse(self.world_view_transform) 39 | self.camera_center = view_inv[3][:3] 40 | 41 | 42 | class Camera(nn.Module): 43 | def __init__( 44 | self, 45 | colmap_id, 46 | R, 47 | T, 48 | FoVx, 49 | FoVy, 50 | image, 51 | gt_alpha_mask, 52 | image_name, 53 | image_path, 54 | uid, 55 | trans=np.array([0.0, 0.0, 0.0]), 56 | scale=1.0, 57 | device="cuda", 58 | ): 59 | super(Camera, self).__init__() 60 | 61 | self.uid = uid 62 | self.colmap_id = colmap_id 63 | self.R = R 64 | self.T = T 65 | self.FoVx = FoVx 66 | self.FoVy = FoVy 67 | self.image_name = image_name 68 | 69 | self.data_device = torch.device(device) 70 | 71 | self.image_path = image_path 72 | self.original_image = image.clamp(0.0, 1.0) 73 | self.image_width = self.original_image.shape[2] 74 | self.image_height = self.original_image.shape[1] 75 | 76 | if gt_alpha_mask is not None: 77 | self.original_image *= gt_alpha_mask 78 | else: 79 | self.original_image *= torch.ones((1, self.image_height, self.image_width)) 80 | 81 | self.zfar = 100.0 82 | self.znear = 0.01 83 | 84 | self.trans = trans 85 | self.scale = scale 86 | 87 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1) 88 | self.projection_matrix = getProjectionMatrix( 89 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 90 | ).transpose(0, 1) 91 | self.full_proj_transform = ( 92 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0)) 93 | ).squeeze(0) 94 | self.camera_center = self.world_view_transform.inverse()[3, :3] 95 | 96 | def cuda(self): 97 | self.original_image = self.original_image.to(self.data_device) 98 | self.world_view_transform = self.world_view_transform.to(self.data_device) 99 | self.projection_matrix = self.projection_matrix.to(self.data_device) 100 | self.full_proj_transform = self.full_proj_transform.to(self.data_device) 101 | self.camera_center = self.camera_center.to(self.data_device) 102 | -------------------------------------------------------------------------------- /scene/scannet_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from PIL import Image 5 | from pathlib import Path 6 | 7 | from utils.sh_utils import SH2RGB 8 | from utils.graphics_utils import BasicPointCloud, focal2fov, fov2focal 9 | from utils.dataset_utils import SceneInfo, CameraInfo, getNerfppNorm, storePly, fetchPly 10 | 11 | 12 | def readScanNetInfo(path, white_background, eval, llffhold=8, extensions=[".png", ".jpg"]): 13 | path = Path(path) 14 | image_dir = path / "color" 15 | pose_dir = path / "pose" 16 | image_sorted = list(sorted(image_dir.iterdir(), key=lambda x: int(x.name.split(".")[0]))) 17 | pose_sorted = list(sorted(pose_dir.iterdir(), key=lambda x: int(x.name.split(".")[0]))) 18 | 19 | cam_infos = [] 20 | K = 
np.loadtxt(os.path.join(path, "intrinsic/intrinsic_color.txt")) 21 | first_img = np.array(Image.open(image_sorted[0]).convert("RGBA")) 22 | width, height = first_img.shape[1], first_img.shape[0] 23 | 24 | fovx = focal2fov(K[0, 0], K[0, 2] * 2) 25 | fovy = focal2fov(K[1, 1], K[1, 2] * 2) 26 | 27 | i = 0 28 | for img, pose in zip(image_sorted, pose_sorted): 29 | i += 1 30 | idx = int(img.name.split(".")[0]) 31 | c2w = np.loadtxt(pose) 32 | c2w = np.array(c2w).reshape(4, 4).astype(np.float32) 33 | # ScanNet pose use COLMAP coordinates (Y down, Z forward), so no need to flip the axis 34 | # c2w[:3, 1:3] *= -1 35 | # We cannot accept files directly, as some of the poses are invalid 36 | if np.isinf(c2w).any(): 37 | continue 38 | 39 | w2c = np.linalg.inv(c2w) 40 | R = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 41 | T = w2c[:3, 3] 42 | 43 | image_path = img 44 | image_name = Path(img).stem 45 | 46 | cam_infos.append( 47 | CameraInfo( 48 | uid=idx, 49 | R=R, 50 | T=T, 51 | FovY=fovy, 52 | FovX=fovx, 53 | image_path=image_path, 54 | image_name=image_name, 55 | width=width, 56 | height=height, 57 | intrinsics=K, 58 | ) 59 | ) 60 | 61 | nerf_normalization = getNerfppNorm(cam_infos) 62 | 63 | ply_path = os.path.join(path, "points3d.ply") 64 | if not os.path.exists(ply_path): 65 | # Since this data set has no colmap data, we start with random points 66 | num_pts = 100_000 67 | print(f"Generating random point cloud ({num_pts})...") 68 | 69 | # We create random points inside the bounds of the synthetic Blender scenes 70 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 71 | shs = np.random.random((num_pts, 3)) / 255.0 72 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 73 | 74 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 75 | 76 | pcd = fetchPly(ply_path) 77 | 78 | if eval: 79 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 80 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 81 | else: 82 | train_cam_infos = cam_infos 83 | test_cam_infos = [] 84 | 85 | scene_info = SceneInfo( 86 | point_cloud=pcd, 87 | train_cameras=train_cam_infos, 88 | test_cameras=test_cam_infos, 89 | nerf_normalization=nerf_normalization, 90 | ply_path=ply_path, 91 | ) 92 | return scene_info 93 | -------------------------------------------------------------------------------- /scene/scene.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from torch.utils.data import Dataset 5 | from scene.blender_loader import readBlenderInfo 6 | from scene.colmap_loader import readColmapInfo 7 | from scene.scannet_loader import readScanNetInfo 8 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, loadCam 9 | 10 | sceneLoadTypeCallbacks = { 11 | "Colmap": readColmapInfo, 12 | "Blender": readBlenderInfo, 13 | "ScanNet": readScanNetInfo, 14 | } 15 | 16 | 17 | class SceneDataset(Dataset): 18 | def __init__(self, scene_info, args, split="train", resolution_scale=1.0): 19 | self.scene_info = scene_info 20 | self.args = args 21 | self.resolution_scale = resolution_scale 22 | if split == "train": 23 | self.camera_info = scene_info.train_cameras 24 | elif split == "test": 25 | self.camera_info = scene_info.test_cameras 26 | else: 27 | raise NotImplementedError("Undefined split") 28 | 29 | def __len__(self): 30 | return len(self.camera_info) 31 | 32 | def __getitem__(self, index): 33 | return loadCam(self.args, 
index, self.camera_info[index], self.resolution_scale) 34 | 35 | 36 | class Scene: 37 | def __init__(self, args, resolution_scales=[1.0]): 38 | self.train_cameras = {} 39 | self.test_cameras = {} 40 | 41 | # load scene 42 | if os.path.exists(os.path.join(args.scene_path, "pose")): 43 | print("Found pose directory, assuming ScanNet data set!") 44 | scene_info = sceneLoadTypeCallbacks["ScanNet"]( 45 | args.scene_path, 46 | args.colmap_images, 47 | args.test_cameras, 48 | args.colmap_eval_hold, 49 | ) 50 | elif os.path.exists(os.path.join(args.scene_path, "sparse")): 51 | print("Found sparse directory, assuming Colmap data set!") 52 | scene_info = sceneLoadTypeCallbacks["Colmap"]( 53 | args.scene_path, 54 | args.colmap_images, 55 | args.test_cameras, 56 | args.colmap_eval_hold, 57 | ) 58 | elif os.path.exists(os.path.join(args.scene_path, "transforms_train.json")): 59 | print("Found transforms_train.json file, assuming Blender synthetic data set!") 60 | scene_info = sceneLoadTypeCallbacks["Blender"](args.scene_path, args.white_background, args.test_cameras) 61 | # elif os.path.exists(os.path.join(args.scene_path, "traj_w_c.txt")): 62 | # print("Found traj_w_c.txt file, assuming Replica data set!") 63 | # scene_info = sceneLoadTypeCallbacks["Replica"](args.scene_path, args.white_background, args.test_cameras) 64 | else: 65 | assert False, "Could not recognize scene type!" 66 | 67 | self.pcd = scene_info.point_cloud 68 | self.cameras_extent = scene_info.nerf_normalization["radius"] 69 | 70 | for resolution_scale in resolution_scales: 71 | print("Loading Training Cameras...") 72 | self.train_cameras[resolution_scale] = SceneDataset(scene_info, args, "train", resolution_scale) 73 | print("Loading Test Cameras...") 74 | self.test_cameras[resolution_scale] = SceneDataset(scene_info, args, "test", resolution_scale) 75 | 76 | def getTrainCameras(self, scale=1.0): 77 | return self.train_cameras[scale] 78 | 79 | def getTestCameras(self, scale=1.0): 80 | return self.test_cameras[scale] 81 | -------------------------------------------------------------------------------- /submodules/channel-rasterization/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | diff_gaussian_rasterization.egg-info/ 3 | dist/ 4 | -------------------------------------------------------------------------------- /submodules/channel-rasterization/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(DiffRast LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(CudaRasterizer 23 | cuda_rasterizer/backward.h 24 | cuda_rasterizer/backward.cu 25 | cuda_rasterizer/forward.h 26 | cuda_rasterizer/forward.cu 27 | cuda_rasterizer/auxiliary.h 28 | cuda_rasterizer/rasterizer_impl.cu 29 | cuda_rasterizer/rasterizer_impl.h 30 | cuda_rasterizer/rasterizer.h 31 | ) 32 | 33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") 34 | 35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) 36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /submodules/channel-rasterization/LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 
50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /submodules/channel-rasterization/README.md: -------------------------------------------------------------------------------- 1 | # Differential Gaussian Rasterization 2 | 3 | Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us. 4 | 5 |
6 |
7 | BibTeX
8 | @Article{kerbl3Dgaussians,
9 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
10 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
11 |       journal      = {ACM Transactions on Graphics},
12 |       number       = {4},
13 |       volume       = {42},
14 |       month        = {July},
15 |       year         = {2023},
16 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
17 | }
18 |
19 |
-------------------------------------------------------------------------------- /submodules/channel-rasterization/cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, 28 | const float* bg_color, 29 | const float2* means2D, 30 | const float4* conic_opacity, 31 | const float* colors, 32 | const float* final_Ts, 33 | const uint32_t* n_contrib, 34 | const float* dL_dpixels, 35 | float3* dL_dmean2D, 36 | float4* dL_dconic2D, 37 | float* dL_dopacity, 38 | float* dL_dcolors); 39 | 40 | void preprocess( 41 | int P, int D, int M, 42 | const float3* means, 43 | const int* radii, 44 | const float* shs, 45 | const bool* clamped, 46 | const glm::vec3* scales, 47 | const glm::vec4* rotations, 48 | const float scale_modifier, 49 | const float* cov3Ds, 50 | const float* view, 51 | const float* proj, 52 | const float focal_x, float focal_y, 53 | const float tan_fovx, float tan_fovy, 54 | const glm::vec3* campos, 55 | const float3* dL_dmean2D, 56 | const float* dL_dconics, 57 | glm::vec3* dL_dmeans, 58 | float* dL_dcolor, 59 | float* dL_dcov3D, 60 | float* dL_dsh, 61 | glm::vec3* dL_dscale, 62 | glm::vec4* dL_drot); 63 | } 64 | 65 | #endif -------------------------------------------------------------------------------- /submodules/channel-rasterization/cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /submodules/channel-rasterization/cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 24 | void preprocess(int P, int D, int M, 25 | const float *orig_points, 26 | const glm::vec3 *scales, 27 | const float scale_modifier, 28 | const glm::vec4 *rotations, 29 | const float *opacities, 30 | const float *shs, 31 | bool *clamped, 32 | const float *cov3D_precomp, 33 | const float *colors_precomp, 34 | const float *viewmatrix, 35 | const float *projmatrix, 36 | const glm::vec3 *cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int *radii, 41 | float2 *points_xy_image, 42 | float *depths, 43 | float *cov3Ds, 44 | float *colors, 45 | float4 *conic_opacity, 46 | const dim3 grid, 47 | uint32_t *tiles_touched, 48 | bool prefiltered, 49 | const int num_channels); 50 | 51 | // Main rasterization method. 52 | void render( 53 | const dim3 grid, dim3 block, 54 | const uint2 *ranges, 55 | const uint32_t *point_list, 56 | int W, int H, 57 | const float2 *points_xy_image, 58 | const float *features, 59 | const float4 *conic_opacity, 60 | float *final_T, 61 | uint32_t *n_contrib, 62 | const float *bg_color, 63 | float *out_color, 64 | const int num_channels); 65 | } 66 | 67 | #endif -------------------------------------------------------------------------------- /submodules/channel-rasterization/cuda_rasterizer/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | static void markVisible( 24 | int P, 25 | float *means3D, 26 | float *viewmatrix, 27 | float *projmatrix, 28 | bool *present); 29 | 30 | static int forward( 31 | std::function geometryBuffer, 32 | std::function binningBuffer, 33 | std::function imageBuffer, 34 | const int P, int D, int M, 35 | const float *background, 36 | const int width, int height, 37 | const float *means3D, 38 | const float *shs, 39 | const float *colors_precomp, 40 | const float *opacities, 41 | const float *scales, 42 | const float scale_modifier, 43 | const float *rotations, 44 | const float *cov3D_precomp, 45 | const float *viewmatrix, 46 | const float *projmatrix, 47 | const float *cam_pos, 48 | const float tan_fovx, float tan_fovy, 49 | const bool prefiltered, 50 | const int num_channel, 51 | float *out_color, 52 | int *radii = nullptr, 53 | bool debug = false); 54 | 55 | static void backward( 56 | const int P, int D, int M, int R, 57 | const float *background, 58 | const int width, int height, 59 | const float *means3D, 60 | const float *shs, 61 | const float *colors_precomp, 62 | const float *scales, 63 | const float scale_modifier, 64 | const float *rotations, 65 | const float *cov3D_precomp, 66 | const float *viewmatrix, 67 | const float *projmatrix, 68 | const float *campos, 69 | const float tan_fovx, float tan_fovy, 70 | const int *radii, 71 | char *geom_buffer, 72 | char *binning_buffer, 73 | char *image_buffer, 74 | const float *dL_dpix, 75 | float *dL_dmean2D, 76 | float *dL_dconic, 77 | float *dL_dopacity, 78 | float *dL_dcolor, 79 | float *dL_dmean3D, 80 | float *dL_dcov3D, 81 | float *dL_dsh, 82 | float *dL_dscale, 83 | float *dL_drot, 84 | bool debug); 85 | }; 86 | }; 87 | 88 | #endif -------------------------------------------------------------------------------- /submodules/channel-rasterization/cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | int* internal_radii; 36 | float2* means2D; 37 | float* cov3D; 38 | float4* conic_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | float* accum_alpha; 51 | 52 | static ImageState fromChunk(char*& chunk, size_t N); 53 | }; 54 | 55 | struct BinningState 56 | { 57 | size_t sorting_size; 58 | uint64_t* point_list_keys_unsorted; 59 | uint64_t* point_list_keys; 60 | uint32_t* point_list_unsorted; 61 | uint32_t* point_list; 62 | char* list_sorting_space; 63 | 64 | static BinningState fromChunk(char*& chunk, size_t P); 65 | }; 66 | 67 | template 68 | size_t required(size_t P) 69 | { 70 | char* size = nullptr; 71 | T::fromChunk(size, P); 72 | return ((size_t)size) + 128; 73 | } 74 | 75 | CudaRasterizer::GeometryState preprocess( 76 | std::function geometryBuffer, 77 | const int P, int D, int M, 78 | const float* background, 79 | const int width, int height, 80 | const float* means3D, 81 | const float* shs, 82 | const float* colors_precomp, 83 | const float* opacities, 84 | const float* scales, 85 | const float scale_modifier, 86 | const float* rotations, 87 | const float* cov3D_precomp, 88 | const float* viewmatrix, 89 | const float* projmatrix, 90 | const float* cam_pos, 91 | const float tan_fovx, float tan_fovy, 92 | const bool prefiltered, 93 | int* radii, 94 | bool debug); 95 | }; -------------------------------------------------------------------------------- /submodules/channel-rasterization/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 16 | { 17 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 18 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 19 | m.def("mark_visible", &markVisible); 20 | } -------------------------------------------------------------------------------- /submodules/channel-rasterization/rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | std::tuple 19 | RasterizeGaussiansCUDA( 20 | const torch::Tensor &background, 21 | const torch::Tensor &means3D, 22 | const torch::Tensor &colors, 23 | const torch::Tensor &opacity, 24 | const torch::Tensor &scales, 25 | const torch::Tensor &rotations, 26 | const float scale_modifier, 27 | const torch::Tensor &cov3D_precomp, 28 | const torch::Tensor &viewmatrix, 29 | const torch::Tensor &projmatrix, 30 | const float tan_fovx, 31 | const float tan_fovy, 32 | const int image_height, 33 | const int image_width, 34 | const torch::Tensor &sh, 35 | const int degree, 36 | const torch::Tensor &campos, 37 | const bool prefiltered, 38 | const bool debug, 39 | const int num_channels); 40 | 41 | std::tuple 42 | RasterizeGaussiansBackwardCUDA( 43 | const torch::Tensor &background, 44 | const torch::Tensor &means3D, 45 | const torch::Tensor &radii, 46 | const torch::Tensor &colors, 47 | const torch::Tensor &scales, 48 | const torch::Tensor &rotations, 49 | const float scale_modifier, 50 | const torch::Tensor &cov3D_precomp, 51 | const torch::Tensor &viewmatrix, 52 | const torch::Tensor &projmatrix, 53 | const float tan_fovx, 54 | const float tan_fovy, 55 | const torch::Tensor &dL_dout_color, 56 | const torch::Tensor &sh, 57 | const int degree, 58 | const torch::Tensor &campos, 59 | const torch::Tensor &geomBuffer, 60 | const int R, 61 | const torch::Tensor &binningBuffer, 62 | const torch::Tensor &imageBuffer, 63 | const bool debug); 64 | 65 | torch::Tensor markVisible( 66 | torch::Tensor &means3D, 67 | torch::Tensor &viewmatrix, 68 | torch::Tensor &projmatrix); -------------------------------------------------------------------------------- /submodules/channel-rasterization/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | os.path.dirname(os.path.abspath(__file__)) 17 | 18 | setup( 19 | name="channel_rasterization", 20 | packages=["channel_rasterization"], 21 | ext_modules=[ 22 | CUDAExtension( 23 | name="channel_rasterization._C", 24 | sources=[ 25 | "cuda_rasterizer/rasterizer_impl.cu", 26 | "cuda_rasterizer/forward.cu", 27 | "cuda_rasterizer/backward.cu", 28 | "rasterize_points.cu", 29 | "ext.cpp", 30 | ], 31 | extra_compile_args={ 32 | "nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")] 33 | }, 34 | ) 35 | ], 36 | cmdclass={"build_ext": BuildExtension}, 37 | ) 38 | -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(DiffRast LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(CudaRasterizer 23 | cuda_rasterizer/backward.h 24 | cuda_rasterizer/backward.cu 25 | cuda_rasterizer/forward.h 26 | cuda_rasterizer/forward.cu 27 | cuda_rasterizer/auxiliary.h 28 | cuda_rasterizer/rasterizer_impl.cu 29 | cuda_rasterizer/rasterizer_impl.h 30 | cuda_rasterizer/rasterizer.h 31 | ) 32 | 33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "75;86") 34 | 35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) 36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 
50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/README.md: -------------------------------------------------------------------------------- 1 | 2 | This is a clone of https://github.com/graphdeco-inria/diff-gaussian-rasterization/tree/59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d 3 | 4 | However, it has been edited by Jonathon Luiten to also render 'depth' as well as colour. 5 | 6 | This is needed for Jonathon's Dynamic 3D Gaussians work which can be found here: http://dynamic3dgaussians.github.io 7 | 8 | By default, the depth is calculated as 'median depth', where the depth is the depth of the Gaussian center which causes the accumulated rays transmittance to drop below 0.5. 9 | If a ray doesn't reach this threshold it is given a default depth of 15. This median depth avoids the depth floaters around depth boundaries that 'mean depth' would give. 10 | If 'mean depth' is preffered, there is commented out code which also calculates 'mean depth'. 11 | See lines 307-308 and 363-372 of cuda_rasterizer/forward.cu. 
12 | 13 | Note that the backward pass for the depth has not been implemented, so it won't work for training with depth ground-truth. 14 | 15 | Note that the code in this repo follows the (non commercial) license of Inria as laid out in LICENSE.md 16 | 17 | If you're using this as part of the Dynamic 3D Gaussians code, just follow the installation instruction for that codebase. 18 | 19 | To install this stand-alone I have been doing the following (although I don't think this is necessarily the best way): 20 | ``` 21 | git clone git@github.com:git@github.com:JonathonLuiten/diff-gaussian-rasterization-w-depth.git 22 | cd diff-gaussian-rasterization-w-depth 23 | python setup.py install 24 | pip install . 25 | ``` 26 | 27 | Original readme below: 28 | 29 | # Differential Gaussian Rasterization 30 | 31 | Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us. 32 | 33 |
34 |
35 | BibTeX
36 | @Article{kerbl3Dgaussians,
37 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
38 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
39 |       journal      = {ACM Transactions on Graphics},
40 |       number       = {4},
41 |       volume       = {42},
42 |       month        = {July},
43 |       year         = {2023},
44 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
45 | }
46 |
47 |
-------------------------------------------------------------------------------- /submodules/rgbd-rasterization/cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, 28 | const float* bg_color, 29 | const float2* means2D, 30 | const float4* conic_opacity, 31 | const float* colors, 32 | const float* final_Ts, 33 | const uint32_t* n_contrib, 34 | const float* dL_dpixels, 35 | float3* dL_dmean2D, 36 | float4* dL_dconic2D, 37 | float* dL_dopacity, 38 | float* dL_dcolors); 39 | 40 | void preprocess( 41 | int P, int D, int M, 42 | const float3* means, 43 | const int* radii, 44 | const float* shs, 45 | const bool* clamped, 46 | const glm::vec3* scales, 47 | const glm::vec4* rotations, 48 | const float scale_modifier, 49 | const float* cov3Ds, 50 | const float* view, 51 | const float* proj, 52 | const float focal_x, float focal_y, 53 | const float tan_fovx, float tan_fovy, 54 | const glm::vec3* campos, 55 | const float3* dL_dmean2D, 56 | const float* dL_dconics, 57 | glm::vec3* dL_dmeans, 58 | float* dL_dcolor, 59 | float* dL_dcov3D, 60 | float* dL_dsh, 61 | glm::vec3* dL_dscale, 62 | glm::vec4* dL_drot); 63 | } 64 | 65 | #endif -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | const glm::vec3* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* cov3D_precomp, 33 | const float* colors_precomp, 34 | const float* viewmatrix, 35 | const float* projmatrix, 36 | const glm::vec3* cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int* radii, 41 | float2* points_xy_image, 42 | float* depths, 43 | float* cov3Ds, 44 | float* colors, 45 | float4* conic_opacity, 46 | const dim3 grid, 47 | uint32_t* tiles_touched, 48 | bool prefiltered); 49 | 50 | // Main rasterization method. 51 | void render( 52 | const dim3 grid, dim3 block, 53 | const uint2* ranges, 54 | const uint32_t* point_list, 55 | int W, int H, 56 | const float2* points_xy_image, 57 | const float* features, 58 | const float4* conic_opacity, 59 | float* final_T, 60 | uint32_t* n_contrib, 61 | const float* bg_color, 62 | float* out_color, 63 | const float* depth, 64 | float* out_depth); 65 | } 66 | 67 | 68 | #endif -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/cuda_rasterizer/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present); 30 | 31 | static int forward( 32 | std::function geometryBuffer, 33 | std::function binningBuffer, 34 | std::function imageBuffer, 35 | const int P, int D, int M, 36 | const float* background, 37 | const int width, int height, 38 | const float* means3D, 39 | const float* shs, 40 | const float* colors_precomp, 41 | const float* opacities, 42 | const float* scales, 43 | const float scale_modifier, 44 | const float* rotations, 45 | const float* cov3D_precomp, 46 | const float* viewmatrix, 47 | const float* projmatrix, 48 | const float* cam_pos, 49 | const float tan_fovx, float tan_fovy, 50 | const bool prefiltered, 51 | float* out_color, 52 | float* out_depth, 53 | int* radii = nullptr); 54 | 55 | static void backward( 56 | const int P, int D, int M, int R, 57 | const float* background, 58 | const int width, int height, 59 | const float* means3D, 60 | const float* shs, 61 | const float* colors_precomp, 62 | const float* scales, 63 | const float scale_modifier, 64 | const float* rotations, 65 | const float* cov3D_precomp, 66 | const float* viewmatrix, 67 | const float* projmatrix, 68 | const float* campos, 69 | const float tan_fovx, float tan_fovy, 70 | const int* radii, 71 | char* geom_buffer, 72 | char* binning_buffer, 73 | char* image_buffer, 74 | const float* dL_dpix, 75 | float* dL_dmean2D, 76 | float* dL_dconic, 77 | float* dL_dopacity, 78 | float* dL_dcolor, 79 | float* dL_dmean3D, 80 | float* dL_dcov3D, 81 | float* dL_dsh, 82 | float* dL_dscale, 83 | float* dL_drot); 84 | }; 85 | }; 86 | 87 | #endif -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | int* internal_radii; 36 | float2* means2D; 37 | float* cov3D; 38 | float4* conic_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | float* accum_alpha; 51 | 52 | static ImageState fromChunk(char*& chunk, size_t N); 53 | }; 54 | 55 | struct BinningState 56 | { 57 | size_t sorting_size; 58 | uint64_t* point_list_keys_unsorted; 59 | uint64_t* point_list_keys; 60 | uint32_t* point_list_unsorted; 61 | uint32_t* point_list; 62 | char* list_sorting_space; 63 | 64 | static BinningState fromChunk(char*& chunk, size_t P); 65 | }; 66 | 67 | template 68 | size_t required(size_t P) 69 | { 70 | char* size = nullptr; 71 | T::fromChunk(size, P); 72 | return ((size_t)size) + 128; 73 | } 74 | }; -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 18 | m.def("mark_visible", &markVisible); 19 | } -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | std::tuple 19 | RasterizeGaussiansCUDA( 20 | const torch::Tensor& background, 21 | const torch::Tensor& means3D, 22 | const torch::Tensor& colors, 23 | const torch::Tensor& opacity, 24 | const torch::Tensor& scales, 25 | const torch::Tensor& rotations, 26 | const float scale_modifier, 27 | const torch::Tensor& cov3D_precomp, 28 | const torch::Tensor& viewmatrix, 29 | const torch::Tensor& projmatrix, 30 | const float tan_fovx, 31 | const float tan_fovy, 32 | const int image_height, 33 | const int image_width, 34 | const torch::Tensor& sh, 35 | const int degree, 36 | const torch::Tensor& campos, 37 | const bool prefiltered); 38 | 39 | std::tuple 40 | RasterizeGaussiansBackwardCUDA( 41 | const torch::Tensor& background, 42 | const torch::Tensor& means3D, 43 | const torch::Tensor& radii, 44 | const torch::Tensor& colors, 45 | const torch::Tensor& scales, 46 | const torch::Tensor& rotations, 47 | const float scale_modifier, 48 | const torch::Tensor& cov3D_precomp, 49 | const torch::Tensor& viewmatrix, 50 | const torch::Tensor& projmatrix, 51 | const float tan_fovx, 52 | const float tan_fovy, 53 | const torch::Tensor& dL_dout_color, 54 | const torch::Tensor& sh, 55 | const int degree, 56 | const torch::Tensor& campos, 57 | const torch::Tensor& geomBuffer, 58 | const int R, 59 | const torch::Tensor& binningBuffer, 60 | const torch::Tensor& imageBuffer); 61 | 62 | torch::Tensor markVisible( 63 | torch::Tensor& means3D, 64 | torch::Tensor& viewmatrix, 65 | torch::Tensor& projmatrix); -------------------------------------------------------------------------------- /submodules/rgbd-rasterization/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | os.path.dirname(os.path.abspath(__file__)) 17 | 18 | setup( 19 | name="rgbd_rasterization", 20 | packages=["rgbd_rasterization"], 21 | ext_modules=[ 22 | CUDAExtension( 23 | name="rgbd_rasterization._C", 24 | sources=[ 25 | "cuda_rasterizer/rasterizer_impl.cu", 26 | "cuda_rasterizer/forward.cu", 27 | "cuda_rasterizer/backward.cu", 28 | "rasterize_points.cu", 29 | "ext.cpp", 30 | ], 31 | extra_compile_args={ 32 | "nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")] 33 | }, 34 | ) 35 | ], 36 | cmdclass={"build_ext": BuildExtension}, 37 | ) 38 | -------------------------------------------------------------------------------- /submodules/segment-anything/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002 3 | max-line-length = 100 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | per-file-ignores = 7 | **/__init__.py:F401,F403,E402 8 | -------------------------------------------------------------------------------- /submodules/segment-anything/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /submodules/segment-anything/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /submodules/segment-anything/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 33 | -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .build_sam_hq import ( 15 | build_sam_hq, 16 | build_sam_hq_vit_h, 17 | build_sam_hq_vit_l, 18 | build_sam_hq_vit_b, 19 | sam_hq_model_registry, 20 | ) 21 | from .predictor import SamPredictor 22 | from .automatic_mask_generator import SamAutomaticMaskGenerator 23 | -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam, 49 | "vit_h": build_sam, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/build_sam_hq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_hq_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam_hq = build_sam_hq_vit_h 25 | 26 | 27 | def build_sam_hq_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_hq_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_hq_model_registry = { 48 | "default": build_sam_hq_vit_h, 49 | "vit_h": build_sam_hq_vit_h, 50 | "vit_l": build_sam_hq_vit_l, 51 | "vit_b": build_sam_hq_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoderHQ( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | vit_dim=encoder_embed_dim, 99 | ), 100 | pixel_mean=[123.675, 116.28, 103.53], 101 | pixel_std=[58.395, 57.12, 57.375], 102 | ) 103 | # sam.eval() 104 | if checkpoint is not None: 105 | with open(checkpoint, "rb") as f: 106 | state_dict = torch.load(f) 107 | info = sam.load_state_dict(state_dict, strict=False) 108 | print(info) 109 | for n, p in sam.named_parameters(): 110 | if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: 111 | p.requires_grad = False 112 | 113 | return sam -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder_hq import MaskDecoderHQ 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /submodules/segment-anything/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 
21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /submodules/segment-anything/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /submodules/segment-anything/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /submodules/simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /submodules/simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == "nt": 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=["spatial.cu", "simple_knn.cu", "ext.cpp"], 27 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}, 28 | ) 29 | ], 30 | cmdclass={"build_ext": BuildExtension}, 31 | ) 32 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sharinka0715/semantic-gaussians/ae531372ee37e4016030362397ddb1c64ad4c58d/submodules/simple-knn/simple_knn/.gitkeep -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /tools/unzip_label_filt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | from tqdm import tqdm 4 | import traceback 5 | 6 | LABEL_ROOT = "/home/guojun/gaussian_splatting/datasets/scannet_12scene_sens/" 7 | OUT_ROOT = "/home/guojun/gaussian_splatting/datasets/scannet_12scene_extract/" 8 | 9 | 10 | for split in [""]: 11 | ls = os.listdir(os.path.join(OUT_ROOT, split)) 12 | ls.sort() 13 | for scene in tqdm(ls): 14 | img_path = os.path.join(OUT_ROOT, split, scene, "color") 15 | ext_imgs = os.listdir(img_path) 16 | ext_imgs.sort() 17 | 18 | out_path = os.path.join(OUT_ROOT, split, scene) 19 | os.makedirs(out_path, exist_ok=True) 20 | label_zip = os.path.join(LABEL_ROOT, split, scene, f"{scene}_2d-label-filt.zip") 21 | with zipfile.ZipFile(label_zip, "r") as zip_ref: 22 | for img in ext_imgs: 23 | try: 24 | zip_ref.extract(f"label-filt/{img}".replace("jpg", "png"), out_path) 25 | except Exception: 26 | print(traceback.format_exc()) 27 | print(scene) 28 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | from PIL import Image 14 | from scene.camera import Camera 15 | from utils.general_utils import PILtoTorch 16 | from utils.graphics_utils import fov2focal, focal2fov 17 | 18 | WARNED = False 19 | 20 | 21 | def loadCam(args, id, cam_info, resolution_scale): 22 | bg = np.array([1, 1, 1]) if args.white_background else np.array([0, 0, 0]) 23 | 24 | image = Image.open(cam_info.image_path) 25 | im_data = np.array(image.convert("RGBA")) 26 | norm_data = im_data / 255.0 27 | arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 28 | image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") 29 | 30 | orig_w, orig_h = image.size 31 | 32 | if args.downscale_ratio == -1: 33 | if orig_w > 1600: 34 | global WARNED 35 | if not WARNED: 36 | print( 37 | "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 38 | "If this is not desired, please explicitly specify '--resolution/-r' as 1" 39 | ) 40 | WARNED = True 41 | global_down = orig_w / 1600 42 | else: 43 | global_down = 1 44 | else: 45 | global_down = 1 / args.downscale_ratio 46 | 47 | scale = float(global_down) * float(resolution_scale) 48 | resolution = (int(orig_w / scale), int(orig_h / scale)) 49 | 50 | resized_image_rgb = PILtoTorch(image, resolution) 51 | 52 | gt_image = resized_image_rgb[:3, ...] 53 | loaded_mask = None 54 | 55 | if resized_image_rgb.shape[1] == 4: 56 | loaded_mask = resized_image_rgb[3:4, ...] 
57 | 58 | return Camera( 59 | colmap_id=cam_info.uid, 60 | R=cam_info.R, 61 | T=cam_info.T, 62 | FoVx=cam_info.FovX, 63 | FoVy=cam_info.FovY, 64 | image=gt_image, 65 | gt_alpha_mask=loaded_mask, 66 | image_name=cam_info.image_name, 67 | image_path=cam_info.image_path, 68 | uid=id, 69 | device=args.device, 70 | ) 71 | 72 | 73 | def get_camera_from_directions(scene_camera, R, T): 74 | return Camera( 75 | colmap_id=scene_camera.colmap_id, 76 | R=R, 77 | T=T, 78 | FoVx=scene_camera.FoVx, 79 | FoVy=scene_camera.FoVy, 80 | image=scene_camera.original_image, 81 | gt_alpha_mask=None, 82 | image_name=scene_camera.image_name, 83 | image_path=scene_camera.image_path, 84 | uid=scene_camera.uid, 85 | device=scene_camera.data_device, 86 | ) 87 | 88 | 89 | def get_camera_viser(scene_camera, R, T, fovy, wh_ratio): 90 | fovx = focal2fov(fov2focal(fovy, 1000), 1000 * wh_ratio) 91 | return Camera( 92 | colmap_id=scene_camera.colmap_id, 93 | R=R, 94 | T=T, 95 | FoVx=fovx, 96 | FoVy=fovy, 97 | image=scene_camera.original_image, 98 | gt_alpha_mask=None, 99 | image_name=scene_camera.image_name, 100 | image_path=scene_camera.image_path, 101 | uid=scene_camera.uid, 102 | device=scene_camera.data_device, 103 | ) 104 | 105 | 106 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 107 | camera_list = [] 108 | 109 | for id, c in enumerate(cam_infos): 110 | camera_list.append(loadCam(args, id, c, resolution_scale)) 111 | 112 | return camera_list 113 | 114 | 115 | def camera_to_JSON(id, camera: Camera): 116 | Rt = np.zeros((4, 4)) 117 | Rt[:3, :3] = camera.R.transpose() 118 | Rt[:3, 3] = camera.T 119 | Rt[3, 3] = 1.0 120 | 121 | W2C = np.linalg.inv(Rt) 122 | pos = W2C[:3, 3] 123 | rot = W2C[:3, :3] 124 | serializable_array_2d = [x.tolist() for x in rot] 125 | camera_entry = { 126 | "id": id, 127 | "img_name": camera.image_name, 128 | "width": camera.width, 129 | "height": camera.height, 130 | "position": pos.tolist(), 131 | "rotation": serializable_array_2d, 132 | "fy": fov2focal(camera.FovY, camera.height), 133 | "fx": fov2focal(camera.FovX, camera.width), 134 | } 135 | return camera_entry 136 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | 32 | def get_expon_lr_func(lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000): 33 | """ 34 | Copied from Plenoxels 35 | 36 | Continuous learning rate decay function. Adapted from JaxNeRF 37 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 38 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 
39 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 40 | function of lr_delay_mult, such that the initial learning rate is 41 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 42 | to the normal learning rate when steps>lr_delay_steps. 43 | :param conf: config subtree 'lr' or similar 44 | :param max_steps: int, the number of steps during optimization. 45 | :return HoF which takes step as input 46 | """ 47 | 48 | def helper(step): 49 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 50 | # Disable this parameter 51 | return 0.0 52 | if lr_delay_steps > 0: 53 | # A kind of reverse cosine decay. 54 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 55 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 56 | ) 57 | else: 58 | delay_rate = 1.0 59 | t = np.clip(step / max_steps, 0, 1) 60 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 61 | return delay_rate * log_lerp 62 | 63 | return helper 64 | 65 | 66 | def strip_lowerdiag(L): 67 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 68 | 69 | uncertainty[:, 0] = L[:, 0, 0] 70 | uncertainty[:, 1] = L[:, 0, 1] 71 | uncertainty[:, 2] = L[:, 0, 2] 72 | uncertainty[:, 3] = L[:, 1, 1] 73 | uncertainty[:, 4] = L[:, 1, 2] 74 | uncertainty[:, 5] = L[:, 2, 2] 75 | return uncertainty 76 | 77 | 78 | def strip_symmetric(sym): 79 | return strip_lowerdiag(sym) 80 | 81 | 82 | def build_rotation(r): 83 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 84 | 85 | q = r / norm[:, None] 86 | 87 | R = torch.zeros((q.size(0), 3, 3), device="cuda") 88 | 89 | r = q[:, 0] 90 | x = q[:, 1] 91 | y = q[:, 2] 92 | z = q[:, 3] 93 | 94 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 95 | R[:, 0, 1] = 2 * (x * y - r * z) 96 | R[:, 0, 2] = 2 * (x * z + r * y) 97 | R[:, 1, 0] = 2 * (x * y + r * z) 98 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 99 | R[:, 1, 2] = 2 * (y * z - r * x) 100 | R[:, 2, 0] = 2 * (x * z - r * y) 101 | R[:, 2, 1] = 2 * (y * z + r * x) 102 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 103 | return R 104 | 105 | 106 | def build_scaling_rotation(s, r): 107 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 108 | R = build_rotation(r) 109 | 110 | L[:, 0, 0] = s[:, 0] 111 | L[:, 1, 1] = s[:, 1] 112 | L[:, 2, 2] = s[:, 2] 113 | 114 | L = R @ L 115 | return L 116 | 117 | 118 | def safe_state(silent): 119 | old_f = sys.stdout 120 | 121 | class F: 122 | def __init__(self, silent): 123 | self.silent = silent 124 | 125 | def write(self, x): 126 | if not self.silent: 127 | if x.endswith("\n"): 128 | old_f.write( 129 | x.replace( 130 | "\n", 131 | " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))), 132 | ) 133 | ) 134 | else: 135 | old_f.write(x) 136 | 137 | def flush(self): 138 | old_f.flush() 139 | 140 | sys.stdout = F(silent) 141 | 142 | random.seed(0) 143 | np.random.seed(0) 144 | torch.manual_seed(0) 145 | torch.cuda.set_device(torch.device("cuda:0")) 146 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def l2_loss(network_output, gt): 23 | return ((network_output - gt) ** 2).mean() 24 | 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) 28 | return gauss / gauss.sum() 29 | 30 | 31 | def create_window(window_size, channel): 32 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 33 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 34 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 35 | return window 36 | 37 | 38 | def ssim(img1, img2, window_size=11, size_average=True): 39 | channel = img1.size(-3) 40 | window = create_window(window_size, channel) 41 | 42 | if img1.is_cuda: 43 | window = window.cuda(img1.get_device()) 44 | window = window.type_as(img1) 45 | 46 | return _ssim(img1, img2, window, window_size, channel, size_average) 47 | 48 | 49 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 50 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 51 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 52 | 53 | mu1_sq = mu1.pow(2) 54 | mu2_sq = mu2.pow(2) 55 | mu1_mu2 = mu1 * mu2 56 | 57 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 58 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 59 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 60 | 61 | C1 = 0.01**2 62 | C2 = 0.03**2 63 | 64 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 65 | 66 | if size_average: 67 | return ssim_map.mean() 68 | else: 69 | return ssim_map.mean(1).mean(1).mean(1) 70 | 71 | 72 | def mse(img1, img2): 73 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 74 | 75 | 76 | def psnr(img1, img2): 77 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 78 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 79 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | """IoU""" 2 | 3 | import numpy as np 4 | from dataset.scannet.scannet_constants import SCANNET20_CLASS_LABELS, COCOMAP_CLASS_LABELS 5 | 6 | def confusion_matrix(pred_ids, gt_ids, num_classes): 7 | """calculate the confusion matrix.""" 8 | 9 | assert pred_ids.shape == gt_ids.shape, (pred_ids.shape, gt_ids.shape) 10 | 11 | # the sum of each row (axis=1) is predicted truth, the sum of each column (axis=0) is ground truth 12 | confusion = ( 13 | np.bincount(pred_ids * (num_classes + 1) + gt_ids, minlength=(num_classes + 1) ** 2) 14 | .reshape((num_classes + 1, num_classes + 1)) 15 | .astype(np.ulonglong) 16 | ) 17 | return confusion[:, 1:] # do not calculate unlabeled points (the first column) 18 | 19 | def get_iou(label_id, confusion): 20 | """calculate IoU.""" 21 | 22 | # true positives 23 | tp = np.longlong(confusion[label_id + 1, label_id]) 24 | # false positives 25 | fp = np.longlong(confusion[label_id + 1, :].sum()) - tp 26 | # false 
negatives 27 | fn = np.longlong(confusion[:, label_id].sum()) - tp 28 | 29 | denom = tp + fp + fn 30 | if denom == 0: 31 | return float("nan") 32 | return float(tp) / denom, tp, denom 33 | 34 | 35 | def evaluate_confusion(confusion, stdout=False, dataset="scannet20"): 36 | if stdout: 37 | print("evaluating", confusion.sum(), "points...") 38 | 39 | if "scannet20" in dataset: 40 | CLASS_LABELS = SCANNET20_CLASS_LABELS 41 | elif "cocomap" in dataset: 42 | CLASS_LABELS = COCOMAP_CLASS_LABELS 43 | else: 44 | raise NotImplementedError 45 | N_CLASSES = len(CLASS_LABELS) 46 | print("num_classes:", N_CLASSES) 47 | 48 | class_ious = {} 49 | class_accs = {} 50 | mean_iou = 0 51 | mean_acc = 0 52 | 53 | count = 0 54 | for i in range(N_CLASSES): 55 | label_name = CLASS_LABELS[i] 56 | if confusion.sum(axis=0)[i] == 0: # at least 1 point needs to be in the evaluation for this class 57 | continue 58 | 59 | class_ious[label_name] = get_iou(i, confusion) 60 | class_accs[label_name] = class_ious[label_name][1] / confusion.sum(axis=0)[i] 61 | count += 1 62 | 63 | mean_iou += class_ious[label_name][0] 64 | mean_acc += class_accs[label_name] 65 | 66 | mean_iou /= count 67 | mean_acc /= count 68 | if stdout: 69 | print("classes IoU") 70 | print("----------------------------") 71 | for i in range(N_CLASSES): 72 | label_name = CLASS_LABELS[i] 73 | try: 74 | print( 75 | "{0:<14s}: {1:>5.3f} ({2:>6d}/{3:<6d})".format( 76 | label_name, 77 | class_ious[label_name][0], 78 | class_ious[label_name][1], 79 | class_ious[label_name][2], 80 | ) 81 | ) 82 | except: 83 | print(label_name + " error!") 84 | continue 85 | print("Mean IoU", mean_iou) 86 | print("Mean Acc", mean_acc) 87 | 88 | with open("eval_result.log", "a") as fp: 89 | fp.write("classes,IoU\n") 90 | for i in range(N_CLASSES): 91 | label_name = CLASS_LABELS[i] 92 | try: 93 | fp.write( 94 | "{0:<14s}: {1:>5.3f} ({2:>6d}/{3:<6d})\n".format( 95 | label_name, 96 | class_ious[label_name][0], 97 | class_ious[label_name][1], 98 | class_ious[label_name][2], 99 | ) 100 | ) 101 | except: 102 | fp.write(label_name + ",error\n") 103 | fp.write("mean IoU,{}\n".format(mean_iou)) 104 | fp.write("mean Acc,{}\n\n".format(mean_acc)) 105 | return mean_iou 106 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | 25 | C0 = 0.28209479177387814 26 | C1 = 0.4886025119029199 27 | C2 = [ 28 | 1.0925484305920792, 29 | -1.0925484305920792, 30 | 0.31539156525252005, 31 | -1.0925484305920792, 32 | 0.5462742152960396, 33 | ] 34 | C3 = [ 35 | -0.5900435899266435, 36 | 2.890611442640554, 37 | -0.4570457994644658, 38 | 0.3731763325901154, 39 | -0.4570457994644658, 40 | 1.445305721320277, 41 | -0.5900435899266435, 42 | ] 43 | C4 = [ 44 | 2.5033429417967046, 45 | -1.7701307697799304, 46 | 0.9461746957575601, 47 | -0.6690465435572892, 48 | 0.10578554691520431, 49 | -0.6690465435572892, 50 | 0.47308734787878004, 51 | -1.7701307697799304, 52 | 0.6258357354491761, 53 | ] 54 | 55 | 56 | def eval_sh(deg, sh, dirs): 57 | """ 58 | Evaluate spherical harmonics at unit directions 59 | using hardcoded SH polynomials. 60 | Works with torch/np/jnp. 61 | ... Can be 0 or more batch dimensions. 62 | Args: 63 | deg: int SH deg. Currently, 0-3 supported 64 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 65 | dirs: jnp.ndarray unit directions [..., 3] 66 | Returns: 67 | [..., C] 68 | """ 69 | assert deg <= 4 and deg >= 0 70 | coeff = (deg + 1) ** 2 71 | assert sh.shape[-1] >= coeff 72 | 73 | result = C0 * sh[..., 0] 74 | if deg > 0: 75 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 76 | result = result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 77 | 78 | if deg > 1: 79 | xx, yy, zz = x * x, y * y, z * z 80 | xy, yz, xz = x * y, y * z, x * z 81 | result = ( 82 | result 83 | + C2[0] * xy * sh[..., 4] 84 | + C2[1] * yz * sh[..., 5] 85 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 86 | + C2[3] * xz * sh[..., 7] 87 | + C2[4] * (xx - yy) * sh[..., 8] 88 | ) 89 | 90 | if deg > 2: 91 | result = ( 92 | result 93 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 94 | + C3[1] * xy * z * sh[..., 10] 95 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 96 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 97 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 98 | + C3[5] * z * (xx - yy) * sh[..., 14] 99 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 100 | ) 101 | 102 | if deg > 3: 103 | result = ( 104 | result 105 | + C4[0] * xy * (xx - yy) * sh[..., 16] 106 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 107 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 108 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 109 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 110 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 111 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 112 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 113 | + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24] 114 | ) 115 | return result 116 | 117 | 118 | def RGB2SH(rgb): 119 | return (rgb - 0.5) / C0 120 | 121 | 122 | def SH2RGB(sh): 123 | return sh * C0 + 0.5 124 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 
| # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import random 15 | import numpy as np 16 | from errno import EEXIST 17 | from os import makedirs, path 18 | 19 | 20 | def mkdir_p(folder_path): 21 | # Creates a directory. equivalent to using mkdir -p on the command line 22 | try: 23 | makedirs(folder_path) 24 | except OSError as exc: # Python >2.5 25 | if exc.errno == EEXIST and path.isdir(folder_path): 26 | pass 27 | else: 28 | raise 29 | 30 | 31 | def searchForMaxIteration(folder): 32 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 33 | return max(saved_iters) 34 | 35 | 36 | def set_seed(seed): 37 | """Seed the program.""" 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | os.environ["PYTHONHASHSEED"] = str(seed) 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | --------------------------------------------------------------------------------