├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── assets ├── main_performance.png └── teaser.png ├── environment.yml ├── gaussian_renderer ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── network_gui.cpython-37.pyc └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── results ├── DeepBlending │ ├── drjohnson.csv │ └── playroom.csv ├── MipNeRF360 │ ├── bicycle.csv │ ├── bonsai.csv │ ├── counter.csv │ ├── flowers.csv │ ├── garden.csv │ ├── kitchen.csv │ ├── room.csv │ ├── stump.csv │ └── treehill.csv ├── README.md ├── SyntheticNeRF │ ├── chair.csv │ ├── drums.csv │ ├── ficus.csv │ ├── hotdog.csv │ ├── lego.csv │ ├── materials.csv │ ├── mic.csv │ └── ship.csv └── TanksAndTemples │ ├── train.csv │ └── truck.csv ├── run_shell_blender.py ├── run_shell_bungee.py ├── run_shell_db.py ├── run_shell_mip360.py ├── run_shell_tnt.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── submodules ├── arithmetic.zip ├── diff-gaussian-rasterization.zip ├── gridencoder.zip └── simple-knn.zip ├── train.py └── utils ├── camera_utils.py ├── encodings.py ├── encodings_cuda.py ├── entropy_models.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py ├── system_utils.py └── visualize_utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. 
Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ECCV'24] HAC 2 | Official Pytorch implementation of **HAC: Hash-grid Assisted Context for 3D Gaussian Splatting Compression**. 3 | ## Compress 3D Gaussian Splatting for 75X without fidelity drop! 4 | ## 💪 An enhanced compression method, [HAC++](https://github.com/YihangChen-ee/HAC-plus), has been released! 
5 | 6 | [Yihang Chen](https://yihangchen-ee.github.io), 7 | [Qianyi Wu](https://qianyiwu.github.io), 8 | [Weiyao Lin](https://weiyaolin.github.io), 9 | [Mehrtash Harandi](https://sites.google.com/site/mehrtashharandi/), 10 | [Jianfei Cai](http://jianfei-cai.github.io) 11 | 12 | [[`Paper`](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/01178.pdf)] [[`Arxiv`](https://arxiv.org/pdf/2403.14530)] [[`Project`](https://yihangchen-ee.github.io/project_hac/)] [[`Github`](https://github.com/YihangChen-ee/HAC)] 13 | 14 | ## Links 15 | You are welcomed to check a series of works from our group on 3D radiance field representation compression as listed below: 16 | - 🎉 [CNC](https://github.com/yihangchen-ee/cnc/) [CVPR'24]: efficient NeRF compression! [[`Paper`](https://openaccess.thecvf.com/content/CVPR2024/papers/Chen_How_Far_Can_We_Compress_Instant-NGP-Based_NeRF_CVPR_2024_paper.pdf)] [[`Arxiv`](https://arxiv.org/pdf/2406.04101)] [[`Project`](https://yihangchen-ee.github.io/project_cnc/)] 17 | - 🏠 [HAC](https://github.com/yihangchen-ee/hac/) [ECCV'24]: efficient 3DGS compression! [[`Paper`](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/01178.pdf)] [[`Arxiv`](https://arxiv.org/pdf/2403.14530)] [[`Project`](https://yihangchen-ee.github.io/project_hac/)] 18 | - 💪 [HAC++](https://github.com/yihangchen-ee/hac-plus/) [ARXIV'25]: an enhanced compression method over HAC! [[`Arxiv`](https://arxiv.org/pdf/2501.12255)] [[`Project`](https://yihangchen-ee.github.io/project_hac++/)] 19 | - 🚀 [FCGS](https://github.com/yihangchen-ee/fcgs/) [ICLR'25]: fast optimization-free 3DGS compression! [[`Paper`](https://openreview.net/pdf?id=DCandSZ2F1)] [[`Arxiv`](https://arxiv.org/pdf/2410.08017)] [[`Project`](https://yihangchen-ee.github.io/project_fcgs/)] 20 | - 🪜 [PCGS](https://github.com/yihangchen-ee/pcgs/) [ARXIV'25]: progressive 3DGS compression! [[`Arxiv`](https://arxiv.org/pdf/2503.08511)] [[`Project`](https://yihangchen-ee.github.io/project_pcgs/)] 21 | 22 | ## Updates 23 | 🔥8-Aug-2024: HAC now utilizes a ```cuda-based codec``` instead of the original ```torchac```, which significantly reduces the codec runtime by over ```10``` times compared to that reported in the paper! 24 | 25 | ## Overview 26 |

27 | ![HAC teaser](./assets/teaser.png) 28 | 29 |

30 | 31 | Our approach introduces a binary hash grid to establish continuous spatial consistency, 32 | allowing us to unveil the inherent spatial relations of anchors through a carefully designed context model. 33 | To facilitate entropy coding, we use Gaussian distributions to accurately estimate the probability of each quantized attribute, 34 | where an adaptive quantization module enables high-precision quantization of these attributes for improved fidelity. 35 | Additionally, we incorporate an adaptive masking strategy to eliminate invalid Gaussians and anchors. 36 | Importantly, our work is the first to explore context-based compression for the 3DGS representation, achieving a remarkable size reduction. 37 | 38 | ## Performance 39 |

40 | ![Main performance of HAC](./assets/main_performance.png) 41 | 42 |

43 | 44 | 45 | ## Installation 46 | 47 | We tested our code on a server with Ubuntu 20.04.1, cuda 11.8, gcc 9.4.0 48 | 1. Unzip files 49 | ``` 50 | cd submodules 51 | unzip diff-gaussian-rasterization.zip 52 | unzip gridencoder.zip 53 | unzip simple-knn.zip 54 | unzip arithmetic.zip 55 | cd .. 56 | ``` 57 | 2. Install environment 58 | ``` 59 | conda env create --file environment.yml 60 | conda activate HAC_env 61 | ``` 62 | 63 | ## Data 64 | 65 | First, create a ```data/``` folder inside the project path by 66 | ``` 67 | mkdir data 68 | ``` 69 | 70 | The data structure will be organised as follows: 71 | 72 | ``` 73 | data/ 74 | ├── dataset_name 75 | │   ├── scene1/ 76 | │   │   ├── images 77 | │   │   │   ├── IMG_0.jpg 78 | │   │   │   ├── IMG_1.jpg 79 | │   │   │   ├── ... 80 | │   │   ├── sparse/ 81 | │   │   └──0/ 82 | │   ├── scene2/ 83 | │   │   ├── images 84 | │   │   │   ├── IMG_0.jpg 85 | │   │   │   ├── IMG_1.jpg 86 | │   │   │   ├── ... 87 | │   │   ├── sparse/ 88 | │   │   └──0/ 89 | ... 90 | ``` 91 | 92 | - For instance: `./data/blending/drjohnson/` 93 | - For instance: `./data/bungeenerf/amsterdam/` 94 | - For instance: `./data/mipnerf360/bicycle/` 95 | - For instance: `./data/nerf_synthetic/chair/` 96 | - For instance: `./data/tandt/train/` 97 | 98 | 99 | ### Public Data (We follow suggestions from [Scaffold-GS](https://github.com/city-super/Scaffold-GS)) 100 | 101 | - The **BungeeNeRF** dataset is available in [Google Drive](https://drive.google.com/file/d/1nBLcf9Jrr6sdxKa1Hbd47IArQQ_X8lww/view?usp=sharing)/[百度网盘[提取码:4whv]](https://pan.baidu.com/s/1AUYUJojhhICSKO2JrmOnCA). 102 | - The **MipNeRF360** scenes are provided by the paper author [here](https://jonbarron.info/mipnerf360/). And we test on its entire 9 scenes ```bicycle, bonsai, counter, garden, kitchen, room, stump, flowers, treehill```. 103 | - The SfM datasets for **Tanks&Temples** and **Deep Blending** are hosted by 3D-Gaussian-Splatting [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip). Download and uncompress them into the ```data/``` folder. 104 | 105 | ### Custom Data 106 | 107 | For custom data, you should process the image sequences with [Colmap](https://colmap.github.io/) to obtain the SfM points and camera poses. Then, place the results into ```data/``` folder. 108 | 109 | ## Training 110 | 111 | To train scenes, we provide the following training scripts: 112 | - Tanks&Temples: ```run_shell_tnt.py``` 113 | - MipNeRF360: ```run_shell_mip360.py``` 114 | - BungeeNeRF: ```run_shell_bungee.py``` 115 | - Deep Blending: ```run_shell_db.py``` 116 | - Nerf Synthetic: ```run_shell_blender.py``` 117 | 118 | run them with 119 | ``` 120 | python run_shell_xxx.py 121 | ``` 122 | 123 | The code will automatically run the entire process of: **training, encoding, decoding, testing**. 124 | - Training log will be recorded in `output.log` of the output directory. Results of **detailed fidelity, detailed size, detailed time** will all be recorded 125 | - Encoded bitstreams will be stored in `./bitstreams` of the output directory. 126 | - Evaluated output images will be saved in `./test/ours_30000/renders` of the output directory. 127 | - Optionally, you can change `lmbda` in these `run_shell_xxx.py` scripts to try variable bitrate. 128 | - **After training, the original model `point_cloud.ply` is losslessly compressed as `./bitstreams`. You should refer to `./bitstreams` to get the final model size, but not `point_cloud.ply`. 
You can even delete `point_cloud.ply` if you like :).** 129 | 130 | 131 | ## Contact 132 | 133 | - Yihang Chen: yhchen.ee@sjtu.edu.cn 134 | 135 | ## Citation 136 | 137 | If you find our work helpful, please consider citing: 138 | 139 | ```bibtex 140 | @article{hac++2025, 141 | title={HAC++: Towards 100X Compression of 3D Gaussian Splatting}, 142 | author={Chen, Yihang and Wu, Qianyi and Lin, Weiyao and Harandi, Mehrtash and Cai, Jianfei}, 143 | year={2025} 144 | } 145 | ``` 146 | ```bibtex 147 | @inproceedings{hac2024, 148 | title={HAC: Hash-grid Assisted Context for 3D Gaussian Splatting Compression}, 149 | author={Chen, Yihang and Wu, Qianyi and Lin, Weiyao and Harandi, Mehrtash and Cai, Jianfei}, 150 | booktitle={European Conference on Computer Vision}, 151 | year={2024} 152 | } 153 | ``` 154 | 155 | 156 | ## LICENSE 157 | 158 | Please follow the LICENSE of [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting). 159 | 160 | ## Acknowledgement 161 | 162 | - We thank all authors from [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting) for presenting such an excellent work. 163 | - We thank all authors from [Scaffold-GS](https://github.com/city-super/Scaffold-GS) for presenting such an excellent work. 164 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self.feat_dim = 50 51 | self.n_offsets = 10 52 | self.voxel_size = 0.001 # if voxel_size<=0, using 1nn dist 53 | self.update_depth = 3 54 | self.update_init_factor = 16 55 | self.update_hierachy_factor = 4 56 | 57 | self.use_feat_bank = False 58 | self._source_path = "" 59 | self._model_path = "" 60 | self._images = "images" 61 | self._resolution = -1 62 | self._white_background = False 63 | self.data_device = "cuda" 64 | self.eval = True 65 | self.lod = 0 66 | super().__init__(parser, "Loading Parameters", sentinel) 67 | 68 | def extract(self, args): 69 | g = super().extract(args) 70 | 
g.source_path = os.path.abspath(g.source_path) 71 | return g 72 | 73 | class PipelineParams(ParamGroup): 74 | def __init__(self, parser): 75 | self.convert_SHs_python = False 76 | self.compute_cov3D_python = False 77 | self.debug = False 78 | super().__init__(parser, "Pipeline Parameters") 79 | 80 | class OptimizationParams(ParamGroup): 81 | def __init__(self, parser): 82 | self.iterations = 30_000 83 | self.position_lr_init = 0.0 84 | self.position_lr_final = 0.0 85 | self.position_lr_delay_mult = 0.01 86 | self.position_lr_max_steps = 30_000 87 | 88 | self.offset_lr_init = 0.01 89 | self.offset_lr_final = 0.0001 90 | self.offset_lr_delay_mult = 0.01 91 | self.offset_lr_max_steps = 30_000 92 | 93 | self.mask_lr_init = 0.01 94 | self.mask_lr_final = 0.0001 95 | self.mask_lr_delay_mult = 0.01 96 | self.mask_lr_max_steps = 30_000 97 | 98 | self.feature_lr = 0.0075 99 | self.opacity_lr = 0.02 100 | self.scaling_lr = 0.007 101 | self.rotation_lr = 0.002 102 | 103 | self.mlp_opacity_lr_init = 0.002 104 | self.mlp_opacity_lr_final = 0.00002 105 | self.mlp_opacity_lr_delay_mult = 0.01 106 | self.mlp_opacity_lr_max_steps = 30_000 107 | 108 | self.mlp_cov_lr_init = 0.004 109 | self.mlp_cov_lr_final = 0.004 110 | self.mlp_cov_lr_delay_mult = 0.01 111 | self.mlp_cov_lr_max_steps = 30_000 112 | 113 | self.mlp_color_lr_init = 0.008 114 | self.mlp_color_lr_final = 0.00005 115 | self.mlp_color_lr_delay_mult = 0.01 116 | self.mlp_color_lr_max_steps = 30_000 117 | 118 | self.mlp_featurebank_lr_init = 0.01 119 | self.mlp_featurebank_lr_final = 0.00001 120 | self.mlp_featurebank_lr_delay_mult = 0.01 121 | self.mlp_featurebank_lr_max_steps = 30_000 122 | 123 | self.encoding_xyz_lr_init = 0.005 124 | self.encoding_xyz_lr_final = 0.00001 125 | self.encoding_xyz_lr_delay_mult = 0.33 126 | self.encoding_xyz_lr_max_steps = 30_000 127 | 128 | self.mlp_grid_lr_init = 0.005 129 | self.mlp_grid_lr_final = 0.00001 130 | self.mlp_grid_lr_delay_mult = 0.01 131 | self.mlp_grid_lr_max_steps = 30_000 132 | 133 | self.mlp_deform_lr_init = 0.005 134 | self.mlp_deform_lr_final = 0.0005 135 | self.mlp_deform_lr_delay_mult = 0.01 136 | self.mlp_deform_lr_max_steps = 30_000 137 | 138 | self.percent_dense = 0.01 139 | self.lambda_dssim = 0.2 140 | 141 | # for anchor densification 142 | self.start_stat = 500 143 | self.update_from = 1500 144 | self.update_interval = 100 145 | self.update_until = 15_000 146 | 147 | self.min_opacity = 0.005 # 0.2 148 | self.success_threshold = 0.8 149 | self.densify_grad_threshold = 0.0002 150 | 151 | super().__init__(parser, "Optimization Parameters") 152 | 153 | def get_combined_args(parser : ArgumentParser): 154 | cmdlne_string = sys.argv[1:] 155 | cfgfile_string = "Namespace()" 156 | args_cmdline = parser.parse_args(cmdlne_string) 157 | 158 | try: 159 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 160 | print("Looking for config file in", cfgfilepath) 161 | with open(cfgfilepath) as cfg_file: 162 | print("Config file found: {}".format(cfgfilepath)) 163 | cfgfile_string = cfg_file.read() 164 | except TypeError: 165 | print("Config file not found at") 166 | pass 167 | args_cfgfile = eval(cfgfile_string) 168 | 169 | merged_dict = vars(args_cfgfile).copy() 170 | for k,v in vars(args_cmdline).items(): 171 | if v != None: 172 | merged_dict[k] = v 173 | return Namespace(**merged_dict) 174 | -------------------------------------------------------------------------------- /assets/main_performance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/assets/main_performance.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/assets/teaser.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: HAC_env 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - cudatoolkit=11.6 9 | - plyfile=0.8.1 10 | - python=3.7.13 11 | - pip=22.3.1 12 | - pytorch=1.12.1 13 | - torchaudio=0.12.1 14 | - torchvision=0.13.1 15 | - pytorch-scatter 16 | - tqdm 17 | - pip: 18 | - einops 19 | - wandb 20 | - lpips 21 | - submodules/diff-gaussian-rasterization 22 | - submodules/simple-knn 23 | - submodules/gridencoder 24 | - submodules/arithmetic 25 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import os.path 12 | import time 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as nnf 17 | from einops import repeat 18 | 19 | import math 20 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 21 | from scene.gaussian_model import GaussianModel 22 | from utils.encodings import STE_binary, STE_multistep 23 | 24 | 25 | def generate_neural_gaussians(viewpoint_camera, pc : GaussianModel, visible_mask=None, is_training=False, step=0): 26 | ## view frustum filtering for acceleration 27 | 28 | time_sub = 0 29 | 30 | if visible_mask is None: 31 | visible_mask = torch.ones(pc.get_anchor.shape[0], dtype=torch.bool, device = pc.get_anchor.device) 32 | 33 | anchor = pc.get_anchor[visible_mask] 34 | # 35 | feat = pc._anchor_feat[visible_mask] 36 | grid_offsets = pc._offset[visible_mask] 37 | grid_scaling = pc.get_scaling[visible_mask] 38 | binary_grid_masks = pc.get_mask[visible_mask] 39 | mask_anchor = pc.get_mask_anchor[visible_mask] 40 | mask_anchor_bool = mask_anchor.to(torch.bool) 41 | mask_anchor_rate = (mask_anchor.sum() / mask_anchor.numel()).detach() 42 | 43 | bit_per_param = None 44 | bit_per_feat_param = None 45 | bit_per_scaling_param = None 46 | bit_per_offsets_param = None 47 | Q_feat = 1 48 | Q_scaling = 0.001 49 | Q_offsets = 0.2 50 | if is_training: 51 | if step > 3000 and step <= 10000: 52 | # quantization 53 | feat = feat + torch.empty_like(feat).uniform_(-0.5, 0.5) * Q_feat 54 | grid_scaling = grid_scaling + torch.empty_like(grid_scaling).uniform_(-0.5, 0.5) * Q_scaling 55 | grid_offsets = grid_offsets + torch.empty_like(grid_offsets).uniform_(-0.5, 0.5) * Q_offsets 56 | 57 | if step == 10000: 58 | pc.update_anchor_bound() 59 | 60 | if step > 10000: 61 | feat_context = pc.calc_interp_feat(anchor) 62 | feat_context = pc.get_grid_mlp(feat_context) 63 | mean, scale, mean_scaling, scale_scaling, mean_offsets, scale_offsets, Q_feat_adj, 
Q_scaling_adj, Q_offsets_adj = \ 64 | torch.split(feat_context, split_size_or_sections=[pc.feat_dim, pc.feat_dim, 6, 6, 3*pc.n_offsets, 3*pc.n_offsets, 1, 1, 1], dim=-1) 65 | 66 | Q_feat = Q_feat * (1 + torch.tanh(Q_feat_adj)) 67 | Q_scaling = Q_scaling * (1 + torch.tanh(Q_scaling_adj)) 68 | Q_offsets = Q_offsets * (1 + torch.tanh(Q_offsets_adj)) 69 | feat = feat + torch.empty_like(feat).uniform_(-0.5, 0.5) * Q_feat 70 | grid_scaling = grid_scaling + torch.empty_like(grid_scaling).uniform_(-0.5, 0.5) * Q_scaling 71 | grid_offsets = grid_offsets + torch.empty_like(grid_offsets).uniform_(-0.5, 0.5) * Q_offsets.unsqueeze(1) 72 | 73 | choose_idx = torch.rand_like(anchor[:, 0]) <= 0.05 74 | choose_idx = choose_idx & mask_anchor_bool 75 | feat_chosen = feat[choose_idx] 76 | grid_scaling_chosen = grid_scaling[choose_idx] 77 | grid_offsets_chosen = grid_offsets[choose_idx].view(-1, 3*pc.n_offsets) 78 | mean = mean[choose_idx] 79 | scale = scale[choose_idx] 80 | mean_scaling = mean_scaling[choose_idx] 81 | scale_scaling = scale_scaling[choose_idx] 82 | mean_offsets = mean_offsets[choose_idx] 83 | scale_offsets = scale_offsets[choose_idx] 84 | Q_feat = Q_feat[choose_idx] 85 | Q_scaling = Q_scaling[choose_idx] 86 | Q_offsets = Q_offsets[choose_idx] 87 | binary_grid_masks_chosen = binary_grid_masks[choose_idx].repeat(1, 1, 3).view(-1, 3*pc.n_offsets) 88 | bit_feat = pc.entropy_gaussian.forward(feat_chosen, mean, scale, Q_feat, pc._anchor_feat.mean()) 89 | bit_scaling = pc.entropy_gaussian.forward(grid_scaling_chosen, mean_scaling, scale_scaling, Q_scaling, pc.get_scaling.mean()) 90 | bit_offsets = pc.entropy_gaussian.forward(grid_offsets_chosen, mean_offsets, scale_offsets, Q_offsets, pc._offset.mean()) 91 | bit_offsets = bit_offsets * binary_grid_masks_chosen 92 | bit_per_feat_param = torch.sum(bit_feat) / bit_feat.numel() * mask_anchor_rate 93 | bit_per_scaling_param = torch.sum(bit_scaling) / bit_scaling.numel() * mask_anchor_rate 94 | bit_per_offsets_param = torch.sum(bit_offsets) / bit_offsets.numel() * mask_anchor_rate 95 | bit_per_param = (torch.sum(bit_feat) + torch.sum(bit_scaling) + torch.sum(bit_offsets)) / \ 96 | (bit_feat.numel() + bit_scaling.numel() + bit_offsets.numel()) * mask_anchor_rate 97 | 98 | elif not pc.decoded_version: 99 | torch.cuda.synchronize(); t1 = time.time() 100 | feat_context = pc.calc_interp_feat(anchor) 101 | mean, scale, mean_scaling, scale_scaling, mean_offsets, scale_offsets, Q_feat_adj, Q_scaling_adj, Q_offsets_adj = \ 102 | torch.split(pc.get_grid_mlp(feat_context), split_size_or_sections=[pc.feat_dim, pc.feat_dim, 6, 6, 3*pc.n_offsets, 3*pc.n_offsets, 1, 1, 1], dim=-1) 103 | 104 | Q_feat = Q_feat * (1 + torch.tanh(Q_feat_adj)) 105 | Q_scaling = Q_scaling * (1 + torch.tanh(Q_scaling_adj)) 106 | Q_offsets = Q_offsets * (1 + torch.tanh(Q_offsets_adj)) # [N_visible_anchor, 1] 107 | feat = (STE_multistep.apply(feat, Q_feat, pc._anchor_feat.mean())).detach() 108 | grid_scaling = (STE_multistep.apply(grid_scaling, Q_scaling, pc.get_scaling.mean())).detach() 109 | grid_offsets = (STE_multistep.apply(grid_offsets, Q_offsets.unsqueeze(1), pc._offset.mean())).detach() 110 | torch.cuda.synchronize(); time_sub = time.time() - t1 111 | 112 | else: 113 | pass 114 | 115 | ob_view = anchor - viewpoint_camera.camera_center 116 | ob_dist = ob_view.norm(dim=1, keepdim=True) 117 | ob_view = ob_view / ob_dist 118 | 119 | ## view-adaptive feature 120 | if pc.use_feat_bank: 121 | cat_view = torch.cat([ob_view, ob_dist], dim=1) # [3+1] 122 | 123 | bank_weight = 
pc.get_featurebank_mlp(cat_view).unsqueeze(dim=1) # [N_visible_anchor, 1, 3] 124 | 125 | feat = feat.unsqueeze(dim=-1) # feat: [N_visible_anchor, 32] 126 | feat = \ 127 | feat[:, ::4, :1].repeat([1, 4, 1])*bank_weight[:, :, :1] + \ 128 | feat[:, ::2, :1].repeat([1, 2, 1])*bank_weight[:, :, 1:2] + \ 129 | feat[:, ::1, :1]*bank_weight[:, :, 2:] 130 | feat = feat.squeeze(dim=-1) # [N_visible_anchor, 32] 131 | 132 | cat_local_view = torch.cat([feat, ob_view, ob_dist], dim=1) # [N_visible_anchor, 32+3+1] 133 | 134 | neural_opacity = pc.get_opacity_mlp(cat_local_view) # [N_visible_anchor, K] 135 | neural_opacity = neural_opacity.reshape([-1, 1]) # [N_visible_anchor*K, 1] 136 | neural_opacity = neural_opacity * binary_grid_masks.view(-1, 1) 137 | mask = (neural_opacity > 0.0) 138 | mask = mask.view(-1) # [N_visible_anchor*K] 139 | 140 | # select opacity 141 | opacity = neural_opacity[mask] # [N_opacity_pos_gaussian, 1] 142 | 143 | # get offset's color 144 | color = pc.get_color_mlp(cat_local_view) # [N_visible_anchor, K*3] 145 | color = color.reshape([anchor.shape[0] * pc.n_offsets, 3]) # [N_visible_anchor*K, 3] 146 | 147 | # get offset's cov 148 | scale_rot = pc.get_cov_mlp(cat_local_view) # [N_visible_anchor, K*7] 149 | scale_rot = scale_rot.reshape([anchor.shape[0] * pc.n_offsets, 7]) # [N_visible_anchor*K, 7] 150 | 151 | offsets = grid_offsets.view([-1, 3]) # [N_visible_anchor*K, 3] 152 | 153 | # combine for parallel masking 154 | concatenated = torch.cat([grid_scaling, anchor], dim=-1) # [N_visible_anchor, 6+3] 155 | concatenated_repeated = repeat(concatenated, 'n (c) -> (n k) (c)', k=pc.n_offsets) # [N_visible_anchor*K, 6+3] 156 | concatenated_all = torch.cat([concatenated_repeated, color, scale_rot, offsets], 157 | dim=-1) # [N_visible_anchor*K, (6+3)+3+7+3] 158 | masked = concatenated_all[mask] # [N_opacity_pos_gaussian, (6+3)+3+7+3] 159 | scaling_repeat, repeat_anchor, color, scale_rot, offsets = masked.split([6, 3, 3, 7, 3], dim=-1) 160 | 161 | # post-process cov 162 | scaling = scaling_repeat[:, 3:] * torch.sigmoid( 163 | scale_rot[:, :3]) 164 | rot = pc.rotation_activation(scale_rot[:, 3:7]) # [N_opacity_pos_gaussian, 4] 165 | 166 | offsets = offsets * scaling_repeat[:, :3] # [N_opacity_pos_gaussian, 3] 167 | xyz = repeat_anchor + offsets # [N_opacity_pos_gaussian, 3] 168 | 169 | if is_training: 170 | return xyz, color, opacity, scaling, rot, neural_opacity, mask, bit_per_param, bit_per_feat_param, bit_per_scaling_param, bit_per_offsets_param 171 | else: 172 | return xyz, color, opacity, scaling, rot, time_sub 173 | 174 | 175 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, visible_mask=None, retain_grad=False, step=0): 176 | """ 177 | Render the scene. 178 | 179 | Background tensor (bg_color) must be on GPU! 
180 | """ 181 | is_training = pc.get_color_mlp.training 182 | 183 | if is_training: 184 | xyz, color, opacity, scaling, rot, neural_opacity, mask, bit_per_param, bit_per_feat_param, bit_per_scaling_param, bit_per_offsets_param = generate_neural_gaussians(viewpoint_camera, pc, visible_mask, is_training=is_training, step=step) 185 | else: 186 | xyz, color, opacity, scaling, rot, time_sub = generate_neural_gaussians(viewpoint_camera, pc, visible_mask, is_training=is_training, step=step) 187 | 188 | screenspace_points = torch.zeros_like(xyz, dtype=pc.get_anchor.dtype, requires_grad=True, device="cuda") + 0 189 | if retain_grad: 190 | try: 191 | screenspace_points.retain_grad() 192 | except: 193 | pass 194 | 195 | # Set up rasterization configuration 196 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 197 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 198 | 199 | raster_settings = GaussianRasterizationSettings( 200 | image_height=int(viewpoint_camera.image_height), 201 | image_width=int(viewpoint_camera.image_width), 202 | tanfovx=tanfovx, 203 | tanfovy=tanfovy, 204 | bg=bg_color, 205 | scale_modifier=scaling_modifier, 206 | viewmatrix=viewpoint_camera.world_view_transform, 207 | projmatrix=viewpoint_camera.full_proj_transform, 208 | sh_degree=1, 209 | campos=viewpoint_camera.camera_center, 210 | prefiltered=False, 211 | debug=pipe.debug 212 | ) 213 | 214 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 215 | 216 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 217 | rendered_image, radii = rasterizer( 218 | means3D = xyz, 219 | means2D = screenspace_points, 220 | shs = None, 221 | colors_precomp = color, 222 | opacities = opacity, 223 | scales = scaling, 224 | rotations = rot, 225 | cov3D_precomp = None) 226 | 227 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 228 | if is_training: 229 | return {"render": rendered_image, 230 | "viewspace_points": screenspace_points, 231 | "visibility_filter" : radii > 0, 232 | "radii": radii, 233 | "selection_mask": mask, 234 | "neural_opacity": neural_opacity, 235 | "scaling": scaling, 236 | "bit_per_param": bit_per_param, 237 | "bit_per_feat_param": bit_per_feat_param, 238 | "bit_per_scaling_param": bit_per_scaling_param, 239 | "bit_per_offsets_param": bit_per_offsets_param, 240 | } 241 | else: 242 | return {"render": rendered_image, 243 | "viewspace_points": screenspace_points, 244 | "visibility_filter" : radii > 0, 245 | "radii": radii, 246 | "time_sub": time_sub, 247 | } 248 | 249 | 250 | def prefilter_voxel(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, 251 | override_color=None): 252 | """ 253 | Render the scene. 254 | 255 | Background tensor (bg_color) must be on GPU! 256 | """ 257 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 258 | screenspace_points = torch.zeros_like(pc.get_anchor, dtype=pc.get_anchor.dtype, requires_grad=True, 259 | device="cuda") + 0 260 | try: 261 | screenspace_points.retain_grad() 262 | except: 263 | pass 264 | 265 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 266 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 267 | 268 | raster_settings = GaussianRasterizationSettings( 269 | image_height=int(viewpoint_camera.image_height), 270 | image_width=int(viewpoint_camera.image_width), 271 | tanfovx=tanfovx, 272 | tanfovy=tanfovy, 273 | bg=bg_color, 274 | scale_modifier=scaling_modifier, 275 | viewmatrix=viewpoint_camera.world_view_transform, 276 | projmatrix=viewpoint_camera.full_proj_transform, 277 | sh_degree=1, 278 | campos=viewpoint_camera.camera_center, 279 | prefiltered=False, 280 | debug=pipe.debug 281 | ) 282 | 283 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 284 | 285 | means3D = pc.get_anchor 286 | 287 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 288 | # scaling / rotation by the rasterizer. 289 | scales = None 290 | rotations = None 291 | cov3D_precomp = None 292 | if pipe.compute_cov3D_python: # False 293 | cov3D_precomp = pc.get_covariance(scaling_modifier) 294 | else: # into here 295 | scales = pc.get_scaling # requires_grad = True 296 | rotations = pc.get_rotation # requires_grad = True 297 | 298 | radii_pure = rasterizer.visible_filter( 299 | means3D=means3D, 300 | scales=scales[:, :3], 301 | rotations=rotations, 302 | cov3D_precomp=cov3D_precomp, # None 303 | ) 304 | 305 | return radii_pure > 0 306 | -------------------------------------------------------------------------------- /gaussian_renderer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/gaussian_renderer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /gaussian_renderer/__pycache__/network_gui.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/gaussian_renderer/__pycache__/network_gui.cpython-37.pyc -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /results/DeepBlending/drjohnson.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,29.8491783,0.9062452,0.2545068,8044150.784,800442 3 | HAC-lowrate,29.5277023,0.9025214,0.2650948,5815821.9264,757083 4 | -------------------------------------------------------------------------------- /results/DeepBlending/playroom.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,30.8383617,0.905672,0.2616119,5277483.008,448521 3 | HAC-lowrate,30.4393959,0.902317,0.2723858,3306474.7008,380598 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/bicycle.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,25.1073151,0.7424726,0.2588055,41048080.384,4462216 3 | HAC-lowrate,25.0472488,0.7415894,0.2642056,28876944.1792,4337034 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/bonsai.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,32.9719124,0.9480926,0.1801753,13334216.704,1110454 3 | HAC-lowrate,32.2754822,0.9419372,0.1894576,8976754.2784,1016354 4 | 
-------------------------------------------------------------------------------- /results/MipNeRF360/counter.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,29.7447491,0.917532,0.1838127,10949964.5952,910833 3 | HAC-lowrate,29.347311,0.9105615,0.1951433,7614024.9088,786790 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/flowers.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,21.2706013,0.5747946,0.3767022,28892463.104,3111616 3 | HAC-lowrate,21.2626057,0.572156,0.3812912,20542232.9856,3030281 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/garden.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,27.4615631,0.8488841,0.1394164,33734157.9264,3159104 3 | HAC-lowrate,27.2805805,0.8419042,0.1514979,23796278.8864,2783590 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/kitchen.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,31.6308594,0.9295087,0.1216748,12658094.8992,1160861 3 | HAC-lowrate,31.1647701,0.9230514,0.1306984,8445755.392,955774 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/room.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,31.9047031,0.9261015,0.1975124,8234886.7584,638887 3 | HAC-lowrate,31.5470905,0.9211181,0.2078761,5796528.128,613008 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/stump.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,26.5885735,0.7628535,0.2635286,26482520.8832,2811861 3 | HAC-lowrate,26.5819645,0.7618363,0.2692434,18989711.36,3139864 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/treehill.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,23.2636204,0.6480887,0.3448641,31086503.5264,3016897 3 | HAC-lowrate,23.3030071,0.6446206,0.3564054,21010422.1696,2835570 4 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | Results formatted following [3DGS.zip](https://github.com/w-m/3dgs-compression-survey?tab=readme-ov-file#including-your-own-results). 
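For a quick sanity check of these tables, the per-scene CSVs (columns `Submethod, PSNR, SSIM, LPIPS, Size [Bytes], #Gaussians`) can be averaged per dataset. The snippet below is a minimal sketch, not part of the repository; it assumes `pandas` is available (it is not listed in `environment.yml`) and that it is run from the repository root.

```python
# Hypothetical helper (not shipped with HAC): average the per-scene results per dataset.
import glob
import os
import pandas as pd

for dataset_dir in sorted(glob.glob("results/*/")):
    csv_paths = sorted(glob.glob(os.path.join(dataset_dir, "*.csv")))
    if not csv_paths:
        continue
    table = pd.concat([pd.read_csv(p) for p in csv_paths], ignore_index=True)
    # One row per submethod (HAC-highrate / HAC-lowrate), averaged over the scenes.
    summary = table.groupby("Submethod").agg(
        PSNR=("PSNR", "mean"),
        SSIM=("SSIM", "mean"),
        LPIPS=("LPIPS", "mean"),
        size_MB=("Size [Bytes]", lambda s: s.mean() / 1e6),
    )
    print(os.path.basename(dataset_dir.rstrip("/")))
    print(summary.round(3))
```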
2 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/chair.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,35.4915924,0.9860504,0.0126323,1753219.072,127611 3 | HAC-lowrate,34.7256927,0.9837951,0.0156229,1080662.4256,82370 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/drums.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,26.4458008,0.9520242,0.040858,2432906.0352,188193 3 | HAC-lowrate,26.3249664,0.9518869,0.0425592,1524944.0768,130122 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/ficus.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,35.3031387,0.9861924,0.0125956,1608830.1568,111752 3 | HAC-lowrate,34.8955841,0.9849692,0.0141863,990170.3168,70576 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/hotdog.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,37.8689346,0.983838,0.0237683,1020054.7328,62248 3 | HAC-lowrate,37.1095505,0.9814867,0.0286784,669201.2032,39299 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/lego.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,35.6688728,0.9810752,0.0185425,1997012.992,128854 3 | HAC-lowrate,35.0378571,0.9789792,0.0217885,1313656.0128,106010 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/materials.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,30.7035522,0.9620085,0.0386389,2171391.1808,204113 3 | HAC-lowrate,30.5300407,0.9608711,0.0413285,1518967.1936,152503 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/mic.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,36.7093277,0.9918905,0.0075859,1063256.064,83890 3 | HAC-lowrate,35.9191742,0.9903358,0.0095429,700763.3408,54533 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/ship.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,31.5231514,0.9036886,0.1153168,3558866.944,238351 3 | HAC-lowrate,31.3790035,0.903459,0.1189712,2089707.1104,163975 4 | -------------------------------------------------------------------------------- /results/TanksAndTemples/train.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,22.7802544,0.8230955,0.2065995,10554232.0128,888307 3 | HAC-lowrate,22.1893082,0.8151628,0.2155594,7277641.728,746258 4 | 
-------------------------------------------------------------------------------- /results/TanksAndTemples/truck.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | HAC-highrate,26.0207634,0.8831319,0.1471665,13024677.0688,1163467 3 | HAC-lowrate,25.8835316,0.8775318,0.1575988,9709813.76,951290 4 | -------------------------------------------------------------------------------- /run_shell_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005 4 | for cuda, scene in enumerate(['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']): 5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/nerf_synthetic/{scene} --eval --lod 0 --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 -m outputs/nerf_synthetic/{scene}/{lmbda} --lmbda {lmbda}' 6 | os.system(one_cmd) 7 | -------------------------------------------------------------------------------- /run_shell_bungee.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005 4 | for cuda, scene in enumerate(['amsterdam', 'bilbao', 'hollywood', 'pompidou', 'quebec', 'rome']): 5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/bungeenerf/{scene} --eval --lod 30 --voxel_size 0 --update_init_factor 128 --iterations 30_000 -m outputs/bungeenerf/{scene}/{lmbda} --lmbda {lmbda}' 6 | os.system(one_cmd) 7 | -------------------------------------------------------------------------------- /run_shell_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005 4 | for cuda, scene in enumerate(['playroom', 'drjohnson']): 5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/blending/{scene} --eval --lod 0 --voxel_size 0.005 --update_init_factor 16 --iterations 30_000 -m outputs/blending/{scene}/{lmbda} --lmbda {lmbda}' 6 | os.system(one_cmd) 7 | -------------------------------------------------------------------------------- /run_shell_mip360.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005 4 | for cuda, scene in enumerate(['bicycle', 'garden', 'stump', 'room', 'counter', 'kitchen', 'bonsai', 'flowers', 'treehill']): 5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/mipnerf360/{scene} --eval --lod 0 --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -m outputs/mipnerf360/{scene}/{lmbda} --lmbda {lmbda}' 6 | os.system(one_cmd) -------------------------------------------------------------------------------- /run_shell_tnt.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005 4 | for cuda, scene in enumerate(['truck', 'train']): 5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/tandt/{scene} --eval --lod 0 --voxel_size 0.01 --update_init_factor 16 --iterations 30_000 -m outputs/tandt/{scene}/{lmbda} --lmbda {lmbda}' 6 | os.system(one_cmd) 7 | -------------------------------------------------------------------------------- 
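Note that each of the `run_shell_*.py` scripts above enumerates a `cuda` index but still pins every scene to `CUDA_VISIBLE_DEVICES=0`, so the scenes are trained sequentially on a single GPU. If several GPUs are available, a hypothetical variant (not part of the repository) could launch one scene per device while reusing the same `train.py` arguments as `run_shell_tnt.py`, for example:

```python
# Hypothetical multi-GPU launcher (sketch, not shipped with HAC):
# reuses the exact train.py command line from run_shell_tnt.py,
# but assigns each scene to a different visible GPU.
import subprocess

NUM_GPUS = 2      # assumption: number of GPUs on the machine
lmbda = 0.004     # same rate-distortion trade-off as the shipped scripts

procs = []
for idx, scene in enumerate(['truck', 'train']):
    one_cmd = (
        f'CUDA_VISIBLE_DEVICES={idx % NUM_GPUS} python train.py '
        f'-s data/tandt/{scene} --eval --lod 0 --voxel_size 0.01 '
        f'--update_init_factor 16 --iterations 30_000 '
        f'-m outputs/tandt/{scene}/{lmbda} --lmbda {lmbda}'
    )
    procs.append(subprocess.Popen(one_cmd, shell=True))

for p in procs:
    p.wait()  # block until every scene has finished training
```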
/scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | class Scene: 22 | 23 | gaussians : GaussianModel 24 | 25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], ply_path=None): 26 | """b 27 | :param path: Path to colmap scene main folder. 28 | """ 29 | self.model_path = args.model_path 30 | self.loaded_iter = None 31 | self.gaussians = gaussians 32 | 33 | if load_iteration: 34 | if load_iteration == -1: 35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 36 | else: 37 | self.loaded_iter = load_iteration 38 | 39 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 40 | 41 | self.train_cameras = {} 42 | self.test_cameras = {} 43 | 44 | self.x_bound = None 45 | if os.path.exists(os.path.join(args.source_path, "sparse")): 46 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args.lod) 47 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 48 | print("Found transforms_train.json file, assuming Blender data set!") 49 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, ply_path=ply_path) 50 | self.x_bound = 1.3 51 | else: 52 | assert False, "Could not recognize scene type!" 
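        # The scene type above is inferred from the directory layout: a COLMAP
        # reconstruction is recognized by its `sparse/` folder, while a Blender-style
        # synthetic scene is recognized by `transforms_train.json` (which also sets
        # `self.x_bound = 1.3`); any other layout fails the assertion above.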
53 | 54 | if not self.loaded_iter: 55 | if ply_path is not None: 56 | with open(ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 57 | dest_file.write(src_file.read()) 58 | else: 59 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 60 | dest_file.write(src_file.read()) 61 | json_cams = [] 62 | camlist = [] 63 | if scene_info.test_cameras: 64 | camlist.extend(scene_info.test_cameras) 65 | if scene_info.train_cameras: 66 | camlist.extend(scene_info.train_cameras) 67 | for id, cam in enumerate(camlist): 68 | json_cams.append(camera_to_JSON(id, cam)) 69 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 70 | json.dump(json_cams, file) 71 | 72 | if shuffle: 73 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 74 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 75 | 76 | self.cameras_extent = scene_info.nerf_normalization["radius"] 77 | 78 | # print(f'self.cameras_extent: {self.cameras_extent}') 79 | 80 | for resolution_scale in resolution_scales: 81 | print("Loading Training Cameras") 82 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 83 | print("Loading Test Cameras") 84 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 85 | 86 | if self.loaded_iter: 87 | self.gaussians.load_ply_sparse_gaussian(os.path.join(self.model_path, 88 | "point_cloud", 89 | "iteration_" + str(self.loaded_iter), 90 | "point_cloud.ply")) 91 | self.gaussians.load_mlp_checkpoints(os.path.join(self.model_path, 92 | "point_cloud", 93 | "iteration_" + str(self.loaded_iter), 94 | "checkpoint.pth")) 95 | else: 96 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 97 | 98 | def save(self, iteration): 99 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 100 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 101 | self.gaussians.save_mlp_checkpoints(os.path.join(point_cloud_path, "checkpoint.pth")) 102 | 103 | def getTrainCameras(self, scale=1.0): 104 | return self.train_cameras[scale] 105 | 106 | def getTestCameras(self, scale=1.0): 107 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | 32 | try: 33 | self.data_device = torch.device(data_device) 34 | except Exception as e: 35 | print(e) 36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 37 | self.data_device = torch.device("cuda") 38 | 39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | 43 | if gt_alpha_mask is not None: 44 | self.original_image *= gt_alpha_mask.to(self.data_device) 45 | else: 46 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 47 | 48 | self.zfar = 100.0 49 | self.znear = 0.01 50 | 51 | self.trans = trans 52 | self.scale = scale 53 | 54 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 55 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 56 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 57 | self.camera_center = self.world_view_transform.inverse()[3, :3] 58 | 59 | class MiniCam: 60 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 61 | self.image_width = width 62 | self.image_height = height 63 | self.FoVy = fovy 64 | self.FoVx = fovx 65 | self.znear = znear 66 | self.zfar = zfar 67 | self.world_view_transform = world_view_transform 68 | self.full_proj_transform = full_proj_transform 69 | view_inv = torch.inverse(self.world_view_transform) 70 | self.camera_center = view_inv[3][:3] 71 | 72 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from tqdm import tqdm 16 | from typing import NamedTuple 17 | from colorama import Fore, init, Style 18 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 19 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 20 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 21 | import numpy as np 22 | import json 23 | from pathlib import Path 24 | from plyfile import PlyData, PlyElement 25 | from utils.sh_utils import SH2RGB 26 | from scene.gaussian_model import BasicPointCloud 27 | 28 | class CameraInfo(NamedTuple): 29 | uid: int 30 | R: np.array 31 | T: np.array 32 | FovY: np.array 33 | FovX: np.array 34 | image: np.array 35 | image_path: str 36 | image_name: str 37 | width: int 38 | height: int 39 | 40 | class SceneInfo(NamedTuple): 41 | point_cloud: BasicPointCloud 42 | train_cameras: list 43 | test_cameras: list 44 | nerf_normalization: dict 45 | ply_path: str 46 | 47 | def getNerfppNorm(cam_info): 48 | def get_center_and_diag(cam_centers): 49 | cam_centers = np.hstack(cam_centers) 50 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 51 | center = avg_cam_center 52 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 53 | diagonal = np.max(dist) 54 | return center.flatten(), diagonal 55 | 56 | cam_centers = [] 57 | 58 | for cam in cam_info: 59 | W2C = getWorld2View2(cam.R, cam.T) 60 | C2W = np.linalg.inv(W2C) 61 | cam_centers.append(C2W[:3, 3:4]) 62 | 63 | center, diagonal = get_center_and_diag(cam_centers) 64 | radius = diagonal * 1.1 65 | 66 | translate = -center 67 | 68 | return {"translate": translate, "radius": radius} 69 | 70 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 71 | cam_infos = [] 72 | for idx, key in enumerate(cam_extrinsics): 73 | sys.stdout.write('\r') 74 | # the exact output you're looking for: 75 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 76 | sys.stdout.flush() 77 | 78 | extr = cam_extrinsics[key] 79 | intr = cam_intrinsics[extr.camera_id] 80 | height = intr.height 81 | width = intr.width 82 | 83 | uid = intr.id 84 | R = np.transpose(qvec2rotmat(extr.qvec)) 85 | T = np.array(extr.tvec) 86 | 87 | # if intr.model=="SIMPLE_PINHOLE": 88 | if intr.model=="SIMPLE_PINHOLE" or intr.model == "SIMPLE_RADIAL": 89 | focal_length_x = 
intr.params[0] 90 | FovY = focal2fov(focal_length_x, height) 91 | FovX = focal2fov(focal_length_x, width) 92 | elif intr.model=="PINHOLE": 93 | focal_length_x = intr.params[0] 94 | focal_length_y = intr.params[1] 95 | FovY = focal2fov(focal_length_y, height) 96 | FovX = focal2fov(focal_length_x, width) 97 | else: 98 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 99 | 100 | # print(f'FovX: {FovX}, FovY: {FovY}') 101 | 102 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 103 | image_name = os.path.basename(image_path).split(".")[0] 104 | image = Image.open(image_path) 105 | 106 | # print(f'image: {image.size}') 107 | 108 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 109 | image_path=image_path, image_name=image_name, width=width, height=height) 110 | cam_infos.append(cam_info) 111 | sys.stdout.write('\n') 112 | return cam_infos 113 | 114 | def fetchPly(path): 115 | plydata = PlyData.read(path) 116 | vertices = plydata['vertex'] 117 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 118 | try: 119 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 120 | except: 121 | colors = np.random.rand(positions.shape[0], positions.shape[1]) 122 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 123 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 124 | 125 | def storePly(path, xyz, rgb): 126 | # Define the dtype for the structured array 127 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 128 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 129 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 130 | 131 | normals = np.zeros_like(xyz) 132 | 133 | elements = np.empty(xyz.shape[0], dtype=dtype) 134 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 135 | elements[:] = list(map(tuple, attributes)) 136 | 137 | # Create the PlyData object and write to file 138 | vertex_element = PlyElement.describe(elements, 'vertex') 139 | ply_data = PlyData([vertex_element]) 140 | ply_data.write(path) 141 | 142 | def readColmapSceneInfo(path, images, eval, lod, llffhold=8): 143 | try: 144 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 145 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 146 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 147 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 148 | except: 149 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 150 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 151 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 152 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 153 | 154 | reading_dir = "images" if images == None else images 155 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 156 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 157 | 158 | if eval: 159 | if lod>0: 160 | print(f'using lod, using eval') 161 | if lod < 50: 162 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx > lod] 163 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx <= lod] 164 | print(f'test_cam_infos: {len(test_cam_infos)}') 165 | else: 166 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx <= lod] 167 | test_cam_infos = [c for idx, c in 
enumerate(cam_infos) if idx > lod] 168 | 169 | else: 170 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 171 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 172 | 173 | else: 174 | train_cam_infos = cam_infos 175 | test_cam_infos = [] 176 | 177 | nerf_normalization = getNerfppNorm(train_cam_infos) 178 | 179 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 180 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 181 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 182 | if not os.path.exists(ply_path): 183 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 184 | try: 185 | xyz, rgb, _ = read_points3D_binary(bin_path) 186 | except: 187 | xyz, rgb, _ = read_points3D_text(txt_path) 188 | storePly(ply_path, xyz, rgb) 189 | # try: 190 | print(f'start fetching data from ply file') 191 | pcd = fetchPly(ply_path) 192 | # except: 193 | # pcd = None 194 | 195 | scene_info = SceneInfo(point_cloud=pcd, 196 | train_cameras=train_cam_infos, 197 | test_cameras=test_cam_infos, 198 | nerf_normalization=nerf_normalization, 199 | ply_path=ply_path) 200 | return scene_info 201 | 202 | # def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 203 | # cam_infos = [] 204 | 205 | # with open(os.path.join(path, transformsfile)) as json_file: 206 | # contents = json.load(json_file) 207 | # fovx = contents["camera_angle_x"] 208 | 209 | # frames = contents["frames"] 210 | # for idx, frame in enumerate(frames): 211 | # cam_name = os.path.join(path, frame["file_path"] + extension) 212 | 213 | # # NeRF 'transform_matrix' is a camera-to-world transform 214 | # c2w = np.array(frame["transform_matrix"]) 215 | # # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 216 | # c2w[:3, 1:3] *= -1 217 | 218 | # # get the world-to-camera transform and set R, T 219 | # w2c = np.linalg.inv(c2w) 220 | # R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 221 | # T = w2c[:3, 3] 222 | 223 | # image_path = os.path.join(path, cam_name) 224 | # image_name = Path(cam_name).stem 225 | # image = Image.open(image_path) 226 | 227 | # im_data = np.array(image.convert("RGBA")) 228 | 229 | # bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 230 | 231 | # norm_data = im_data / 255.0 232 | # arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 233 | # image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 234 | 235 | # fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 236 | # FovY = fovy 237 | # FovX = fovx 238 | 239 | # cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 240 | # image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 241 | 242 | # return cam_infos 243 | 244 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png", is_debug=False): 245 | cam_infos = [] 246 | with open(os.path.join(path, transformsfile)) as json_file: 247 | contents = json.load(json_file) 248 | try: 249 | fovx = contents["camera_angle_x"] 250 | except: 251 | fovx = None 252 | 253 | frames = contents["frames"] 254 | # check if filename already contain postfix 255 | if frames[0]["file_path"].split('.')[-1] in ['jpg', 'jpeg', 'JPG', 'png']: 256 | extension = "" 257 | 258 | c2ws = np.array([frame["transform_matrix"] for frame in frames]) 259 | 260 | Ts = c2ws[:,:3,3] 261 | 262 | ct = 0 
263 | 264 | progress_bar = tqdm(frames, desc="Loading dataset") 265 | 266 | for idx, frame in enumerate(frames): 267 | cam_name = os.path.join(path, frame["file_path"] + extension) 268 | if not os.path.exists(cam_name): 269 | continue 270 | # NeRF 'transform_matrix' is a camera-to-world transform 271 | c2w = np.array(frame["transform_matrix"]) 272 | 273 | if idx % 10 == 0: 274 | progress_bar.set_postfix({"num": Fore.YELLOW+f"{ct}/{len(frames)}"+Style.RESET_ALL}) 275 | progress_bar.update(10) 276 | if idx == len(frames) - 1: 277 | progress_bar.close() 278 | 279 | ct += 1 280 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 281 | c2w[:3, 1:3] *= -1 282 | if "small_city_img" in path: 283 | c2w[-1,-1] = 1 284 | 285 | # get the world-to-camera transform and set R, T 286 | w2c = np.linalg.inv(c2w) 287 | 288 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 289 | T = w2c[:3, 3] 290 | 291 | image_path = os.path.join(path, cam_name) 292 | image_name = Path(cam_name).stem 293 | image = Image.open(image_path) 294 | 295 | im_data = np.array(image.convert("RGBA")) 296 | 297 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 298 | 299 | norm_data = im_data / 255.0 300 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 301 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 302 | 303 | if fovx is not None: 304 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 305 | FovY = fovy 306 | FovX = fovx 307 | else: 308 | # given focal in pixel unit 309 | FovY = focal2fov(frame["fl_y"], image.size[1]) 310 | FovX = focal2fov(frame["fl_x"], image.size[0]) 311 | 312 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 313 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 314 | 315 | if is_debug and idx > 50: 316 | break 317 | return cam_infos 318 | 319 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png", ply_path=None): 320 | print("Reading Training Transforms") 321 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 322 | print("Reading Test Transforms") 323 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 324 | 325 | if not eval: 326 | train_cam_infos.extend(test_cam_infos) 327 | test_cam_infos = [] 328 | 329 | nerf_normalization = getNerfppNorm(train_cam_infos) 330 | if ply_path is None: 331 | ply_path = os.path.join(path, "points3d.ply") 332 | if not os.path.exists(ply_path): 333 | # Since this data set has no colmap data, we start with random points 334 | num_pts = 10_000 335 | print(f"Generating random point cloud ({num_pts})...") 336 | 337 | # We create random points inside the bounds of the synthetic Blender scenes 338 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 339 | shs = np.random.random((num_pts, 3)) / 255.0 340 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 341 | 342 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 343 | try: 344 | pcd = fetchPly(ply_path) 345 | except: 346 | pcd = None 347 | 348 | scene_info = SceneInfo(point_cloud=pcd, 349 | train_cameras=train_cam_infos, 350 | test_cameras=test_cam_infos, 351 | nerf_normalization=nerf_normalization, 352 | ply_path=ply_path) 353 | return scene_info 354 | 355 | 356 | sceneLoadTypeCallbacks = { 357 | "Colmap": readColmapSceneInfo, 358 | "Blender": 
readNerfSyntheticInfo, 359 | } -------------------------------------------------------------------------------- /submodules/arithmetic.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/arithmetic.zip -------------------------------------------------------------------------------- /submodules/diff-gaussian-rasterization.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/diff-gaussian-rasterization.zip -------------------------------------------------------------------------------- /submodules/gridencoder.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/gridencoder.zip -------------------------------------------------------------------------------- /submodules/simple-knn.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/simple-knn.zip -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import numpy as np 14 | 15 | import subprocess 16 | # cmd = 'nvidia-smi -q -d Memory |grep -A4 GPU|grep Used' 17 | # result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode().split('\n') 18 | # os.environ['CUDA_VISIBLE_DEVICES']=str(np.argmin([int(x.split()[2]) for x in result[:-1]])) 19 | 20 | # os.system('echo $CUDA_VISIBLE_DEVICES') 21 | 22 | 23 | import torch 24 | import torchvision 25 | import json 26 | import wandb 27 | import time 28 | from os import makedirs 29 | import shutil, pathlib 30 | from pathlib import Path 31 | from PIL import Image 32 | import torchvision.transforms.functional as tf 33 | # from lpipsPyTorch import lpips 34 | # import lpips 35 | from random import randint 36 | from utils.loss_utils import l1_loss, ssim 37 | from gaussian_renderer import prefilter_voxel, render, network_gui 38 | import sys 39 | from scene import Scene, GaussianModel 40 | from utils.general_utils import safe_state 41 | import uuid 42 | from tqdm import tqdm 43 | from utils.image_utils import psnr 44 | from argparse import ArgumentParser, Namespace 45 | from arguments import ModelParams, PipelineParams, OptimizationParams 46 | from utils.encodings import get_binary_vxl_size 47 | 48 | # torch.set_num_threads(32) 49 | # lpips_fn = lpips.LPIPS(net='vgg').to('cuda') 50 | 51 | from lpipsPyTorch import lpips 52 | 53 | bit2MB_scale = 8 * 1024 * 1024 54 | run_codec = True 55 | 56 | try: 57 | from torch.utils.tensorboard import SummaryWriter 58 | TENSORBOARD_FOUND = True 59 | print("found tf board") 60 | except ImportError: 61 | TENSORBOARD_FOUND = False 62 | print("not found tf board") 63 | 64 | def saveRuntimeCode(dst: str) -> None: 65 | additionalIgnorePatterns = ['.git', '.gitignore'] 
66 | ignorePatterns = set() 67 | ROOT = '.' 68 | with open(os.path.join(ROOT, '.gitignore')) as gitIgnoreFile: 69 | for line in gitIgnoreFile: 70 | if not line.startswith('#'): 71 | if line.endswith('\n'): 72 | line = line[:-1] 73 | if line.endswith('/'): 74 | line = line[:-1] 75 | ignorePatterns.add(line) 76 | ignorePatterns = list(ignorePatterns) 77 | for additionalPattern in additionalIgnorePatterns: 78 | ignorePatterns.append(additionalPattern) 79 | 80 | log_dir = pathlib.Path(__file__).parent.resolve() 81 | 82 | 83 | shutil.copytree(log_dir, dst, ignore=shutil.ignore_patterns(*ignorePatterns)) 84 | 85 | print('Backup Finished!') 86 | 87 | 88 | def training(args_param, dataset, opt, pipe, dataset_name, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, wandb=None, logger=None, ply_path=None): 89 | first_iter = 0 90 | tb_writer = prepare_output_and_logger(dataset) 91 | 92 | gaussians = GaussianModel( 93 | dataset.feat_dim, 94 | dataset.n_offsets, 95 | dataset.voxel_size, 96 | dataset.update_depth, 97 | dataset.update_init_factor, 98 | dataset.update_hierachy_factor, 99 | dataset.use_feat_bank, 100 | n_features_per_level=args_param.n_features, 101 | log2_hashmap_size=args_param.log2, 102 | log2_hashmap_size_2D=args_param.log2_2D, 103 | ) 104 | scene = Scene(dataset, gaussians, ply_path=ply_path) 105 | gaussians.update_anchor_bound() 106 | 107 | gaussians.training_setup(opt) 108 | if checkpoint: 109 | (model_params, first_iter) = torch.load(checkpoint) 110 | gaussians.restore(model_params, opt) 111 | 112 | iter_start = torch.cuda.Event(enable_timing = True) 113 | iter_end = torch.cuda.Event(enable_timing = True) 114 | 115 | viewpoint_stack = None 116 | ema_loss_for_log = 0.0 117 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 118 | first_iter += 1 119 | torch.cuda.synchronize(); t_start = time.time() 120 | log_time_sub = 0 121 | for iteration in range(first_iter, opt.iterations + 1): 122 | # network gui not available in scaffold-gs yet 123 | if network_gui.conn == None: 124 | network_gui.try_connect() 125 | while network_gui.conn != None: 126 | try: 127 | net_image_bytes = None 128 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 129 | if custom_cam != None: 130 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 131 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 132 | network_gui.send(net_image_bytes, dataset.source_path) 133 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 134 | break 135 | except Exception as e: 136 | network_gui.conn = None 137 | 138 | iter_start.record() 139 | 140 | gaussians.update_learning_rate(iteration) 141 | 142 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 143 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 144 | 145 | # Pick a random Camera 146 | if not viewpoint_stack: 147 | viewpoint_stack = scene.getTrainCameras().copy() 148 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 149 | 150 | # Render 151 | if (iteration - 1) == debug_from: 152 | pipe.debug = True 153 | 154 | voxel_visible_mask = prefilter_voxel(viewpoint_cam, gaussians, pipe, background) 155 | # voxel_visible_mask:bool = radii_pure > 0: 应该是[N_anchor]? 
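# Overview of the objective assembled in the statements below: besides the image,
# render() returns per-parameter bit estimates from the entropy model
# (bit_per_feat_param, bit_per_scaling_param, bit_per_offsets_param, bit_per_param).
# The training loss combines
#   (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM)      # fidelity
#   + 0.01 * scaling_reg                                      # scale regularization
#   + lmbda * (bit_per_param + bit_hash_grid / denom)         # estimated bitrate,
#     with denom = num_anchors * (feat_dim + 6 + 3 * n_offsets)
#   + 5e-4 * mean(sigmoid(_mask))                             # mask sparsity
# so the --lmbda flag set in the run_shell_*.py scripts trades rendering quality
# against the estimated compressed size.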
156 | retain_grad = (iteration < opt.update_until and iteration >= 0) 157 | render_pkg = render(viewpoint_cam, gaussians, pipe, background, visible_mask=voxel_visible_mask, retain_grad=retain_grad, step=iteration) 158 | image, viewspace_point_tensor, visibility_filter, offset_selection_mask, radii, scaling, opacity = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["selection_mask"], render_pkg["radii"], render_pkg["scaling"], render_pkg["neural_opacity"] 159 | # image: [3, H, W]. inited as: torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); 160 | # viewspace_point_tensor=screenspace_points: [N_opacity_pos_gaussian, 3] 161 | # visibility_filter: radii > 0, where radii is inited as: torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); with P=N_opacity_pos_gaussian 162 | # offset_selection_mask: [N_visible_anchor*k]. Indicates which Gaussians of the visible anchors are valid, obtained from opacity>0.0 163 | # radii: [N_opacity_pos_gaussian]. inited as: torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); with P=N_opacity_pos_gaussian 164 | # scaling: [N_opacity_pos_gaussian, 3] 165 | # opacity: [N_visible_anchor*K, 1] 166 | 167 | bit_per_param = render_pkg["bit_per_param"] 168 | bit_per_feat_param = render_pkg["bit_per_feat_param"] 169 | bit_per_scaling_param = render_pkg["bit_per_scaling_param"] 170 | bit_per_offsets_param = render_pkg["bit_per_offsets_param"] 171 | 172 | if iteration % 2000 == 0 and bit_per_param is not None: 173 | 174 | ttl_size_feat_MB = bit_per_feat_param.item() * gaussians.get_anchor.shape[0] * gaussians.feat_dim / bit2MB_scale 175 | ttl_size_scaling_MB = bit_per_scaling_param.item() * gaussians.get_anchor.shape[0] * 6 / bit2MB_scale 176 | ttl_size_offsets_MB = bit_per_offsets_param.item() * gaussians.get_anchor.shape[0] * 3 * gaussians.n_offsets / bit2MB_scale 177 | ttl_size_MB = ttl_size_feat_MB + ttl_size_scaling_MB + ttl_size_offsets_MB 178 | 179 | logger.info("\n----------------------------------------------------------------------------------------") 180 | logger.info("\n-----[ITER {}] bits info: bit_per_feat_param={}, anchor_num={}, ttl_size_feat_MB={}-----".format(iteration, bit_per_feat_param.item(), gaussians.get_anchor.shape[0], ttl_size_feat_MB)) 181 | logger.info("\n-----[ITER {}] bits info: bit_per_scaling_param={}, anchor_num={}, ttl_size_scaling_MB={}-----".format(iteration, bit_per_scaling_param.item(), gaussians.get_anchor.shape[0], ttl_size_scaling_MB)) 182 | logger.info("\n-----[ITER {}] bits info: bit_per_offsets_param={}, anchor_num={}, ttl_size_offsets_MB={}-----".format(iteration, bit_per_offsets_param.item(), gaussians.get_anchor.shape[0], ttl_size_offsets_MB)) 183 | logger.info("\n-----[ITER {}] bits info: bit_per_param={}, anchor_num={}, ttl_size_MB={}-----".format(iteration, bit_per_param.item(), gaussians.get_anchor.shape[0], ttl_size_MB)) 184 | with torch.no_grad(): 185 | grid_masks = gaussians._mask.data 186 | binary_grid_masks = (torch.sigmoid(grid_masks) > 0.01).float() 187 | mask_1_rate, mask_size_bit, mask_size_MB, mask_numel = get_binary_vxl_size(binary_grid_masks + 0.0) # [0, 1] -> [-1, 1] 188 | logger.info("\n-----[ITER {}] bits info: 1_rate_mask={}, mask_numel={}, mask_size_MB={}-----".format(iteration, mask_1_rate, mask_numel, mask_size_MB)) 189 | 190 | gt_image = viewpoint_cam.original_image.cuda() 191 | Ll1 = l1_loss(image, gt_image) 192 | 193 | ssim_loss = (1.0 - ssim(image, gt_image)) 194 | scaling_reg = scaling.prod(dim=1).mean() 195 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ssim_loss +
0.01*scaling_reg 196 | 197 | if bit_per_param is not None: 198 | _, bit_hash_grid, MB_hash_grid, _ = get_binary_vxl_size((gaussians.get_encoding_params()+1)/2) 199 | denom = gaussians._anchor.shape[0]*(gaussians.feat_dim+6+3*gaussians.n_offsets) 200 | loss = loss + args_param.lmbda * (bit_per_param + bit_hash_grid / denom) 201 | 202 | loss = loss + 5e-4 * torch.mean(torch.sigmoid(gaussians._mask)) 203 | 204 | loss.backward() 205 | 206 | iter_end.record() 207 | 208 | with torch.no_grad(): 209 | # Progress bar 210 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 211 | 212 | if iteration % 10 == 0: 213 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 214 | progress_bar.update(10) 215 | if iteration == opt.iterations: 216 | progress_bar.close() 217 | 218 | # Log and save 219 | torch.cuda.synchronize(); t_start_log = time.time() 220 | training_report(tb_writer, dataset_name, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), wandb, logger, args_param.model_path) 221 | if (iteration in saving_iterations): 222 | logger.info("\n[ITER {}] Saving Gaussians".format(iteration)) 223 | scene.save(iteration) 224 | torch.cuda.synchronize(); t_end_log = time.time() 225 | t_log = t_end_log - t_start_log 226 | log_time_sub += t_log 227 | 228 | # densification 229 | if iteration < opt.update_until and iteration > opt.start_stat: 230 | # add statistics 231 | # viewspace_point_tensor=screenspace_points: [N_opacity_pos_gaussian, 3] 232 | # opacity: [N_visible_anchor*K, 1] 233 | # visibility_filter: radii > 0, where radii is inited as: torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); with P=N_opacity_pos_gaussian 234 | # offset_selection_mask: [N_visible_anchor*k]. Indicates which Gaussians of the visible anchors are valid, obtained from opacity>0.0 235 | # voxel_visible_mask:bool = radii_pure > 0: should be [N_anchor]?
voxel_visible_mask.sum()=N_visible_anchor 236 | gaussians.training_statis(viewspace_point_tensor, opacity, visibility_filter, offset_selection_mask, voxel_visible_mask) 237 | if iteration not in range(3000, 4000): # let the model get fit to quantization 238 | # densification 239 | if iteration > opt.update_from and iteration % opt.update_interval == 0: 240 | gaussians.adjust_anchor(check_interval=opt.update_interval, success_threshold=opt.success_threshold, grad_threshold=opt.densify_grad_threshold, min_opacity=opt.min_opacity) 241 | elif iteration == opt.update_until: 242 | del gaussians.opacity_accum 243 | del gaussians.offset_gradient_accum 244 | del gaussians.offset_denom 245 | torch.cuda.empty_cache() 246 | 247 | if iteration < opt.iterations: 248 | gaussians.optimizer.step() 249 | gaussians.optimizer.zero_grad(set_to_none = True) 250 | if (iteration in checkpoint_iterations): 251 | logger.info("\n[ITER {}] Saving Checkpoint".format(iteration)) 252 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 253 | 254 | torch.cuda.synchronize(); t_end = time.time() 255 | logger.info("\n Total Training time: {}".format(t_end-t_start-log_time_sub)) 256 | 257 | return gaussians.x_bound_min, gaussians.x_bound_max 258 | 259 | def prepare_output_and_logger(args): 260 | if not args.model_path: 261 | if os.getenv('OAR_JOB_ID'): 262 | unique_str=os.getenv('OAR_JOB_ID') 263 | else: 264 | unique_str = str(uuid.uuid4()) 265 | args.model_path = os.path.join("./output/", unique_str[0:10]) 266 | 267 | # Set up output folder 268 | print("Output folder: {}".format(args.model_path)) 269 | os.makedirs(args.model_path, exist_ok = True) 270 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 271 | cfg_log_f.write(str(Namespace(**vars(args)))) 272 | 273 | # Create Tensorboard writer 274 | tb_writer = None 275 | if TENSORBOARD_FOUND: 276 | tb_writer = SummaryWriter(args.model_path) 277 | else: 278 | print("Tensorboard not available: not logging progress") 279 | return tb_writer 280 | 281 | 282 | def training_report(tb_writer, dataset_name, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, wandb=None, logger=None, pre_path_name=''): 283 | if tb_writer: 284 | tb_writer.add_scalar(f'{dataset_name}/train_loss_patches/l1_loss', Ll1.item(), iteration) 285 | tb_writer.add_scalar(f'{dataset_name}/train_loss_patches/total_loss', loss.item(), iteration) 286 | tb_writer.add_scalar(f'{dataset_name}/iter_time', elapsed, iteration) 287 | 288 | if wandb is not None: 289 | wandb.log({"train_l1_loss":Ll1, 'train_total_loss':loss, }) 290 | # Report test and samples of training set 291 | if iteration in testing_iterations: 292 | scene.gaussians.eval() 293 | 294 | if 1: 295 | if iteration == testing_iterations[-1]: 296 | with torch.no_grad(): 297 | log_info = scene.gaussians.estimate_final_bits() 298 | logger.info(log_info) 299 | if run_codec: # conduct encoding and decoding 300 | with torch.no_grad(): 301 | bit_stream_path = os.path.join(pre_path_name, 'bitstreams') 302 | os.makedirs(bit_stream_path, exist_ok=True) 303 | # conduct encoding 304 | patched_infos, log_info = scene.gaussians.conduct_encoding(pre_path_name=bit_stream_path) 305 | logger.info(log_info) 306 | # conduct decoding 307 | log_info = scene.gaussians.conduct_decoding(pre_path_name=bit_stream_path, patched_infos=patched_infos) 308 | logger.info(log_info) 309 | torch.cuda.empty_cache() 310 | validation_configs = ({'name': 'test', 'cameras' : 
scene.getTestCameras()}, 311 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 312 | 313 | for config in validation_configs: 314 | # if config['name'] == 'test': assert len(config['cameras']) == 200 315 | if config['cameras'] and len(config['cameras']) > 0: 316 | l1_test = 0.0 317 | psnr_test = 0.0 318 | ssim_test = 0.0 319 | lpips_test = 0.0 320 | 321 | if wandb is not None: 322 | gt_image_list = [] 323 | render_image_list = [] 324 | errormap_list = [] 325 | 326 | t_list = [] 327 | 328 | for idx, viewpoint in enumerate(config['cameras']): 329 | torch.cuda.synchronize(); t_start = time.time() 330 | voxel_visible_mask = prefilter_voxel(viewpoint, scene.gaussians, *renderArgs) 331 | # image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs, visible_mask=voxel_visible_mask)["render"], 0.0, 1.0) 332 | render_output = renderFunc(viewpoint, scene.gaussians, *renderArgs, visible_mask=voxel_visible_mask) 333 | image = torch.clamp(render_output["render"], 0.0, 1.0) 334 | time_sub = render_output["time_sub"] 335 | torch.cuda.synchronize(); t_end = time.time() 336 | t_list.append(t_end - t_start - time_sub) 337 | 338 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 339 | if tb_writer and (idx < 30): 340 | tb_writer.add_images(f'{dataset_name}/'+config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 341 | tb_writer.add_images(f'{dataset_name}/'+config['name'] + "_view_{}/errormap".format(viewpoint.image_name), (gt_image[None]-image[None]).abs(), global_step=iteration) 342 | 343 | if wandb: 344 | render_image_list.append(image[None]) 345 | errormap_list.append((gt_image[None]-image[None]).abs()) 346 | 347 | if iteration == testing_iterations[0]: 348 | tb_writer.add_images(f'{dataset_name}/'+config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 349 | if wandb: 350 | gt_image_list.append(gt_image[None]) 351 | l1_test += l1_loss(image, gt_image).mean().double() 352 | psnr_test += psnr(image, gt_image).mean().double() 353 | ssim_test += ssim(image, gt_image).mean().double() 354 | # lpips_test += lpips_fn(image, gt_image, normalize=True).detach().mean().double() 355 | lpips_test += lpips(image, gt_image, net_type='vgg').detach().mean().double() 356 | 357 | psnr_test /= len(config['cameras']) 358 | ssim_test /= len(config['cameras']) 359 | lpips_test /= len(config['cameras']) 360 | l1_test /= len(config['cameras']) 361 | logger.info("\n[ITER {}] Evaluating {}: L1 {} PSNR {} ssim {} lpips {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test)) 362 | test_fps = 1.0 / torch.tensor(t_list[0:]).mean() 363 | logger.info(f'Test FPS: {test_fps.item():.5f}') 364 | if tb_writer: 365 | tb_writer.add_scalar(f'{dataset_name}/test_FPS', test_fps.item(), 0) 366 | if wandb is not None: 367 | wandb.log({"test_fps": test_fps, }) 368 | 369 | if tb_writer: 370 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 371 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 372 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration) 373 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - lpips', lpips_test, iteration) 374 | if wandb is not None: 375 | wandb.log({f"{config['name']}_loss_viewpoint_l1_loss":l1_test, 
f"{config['name']}_PSNR":psnr_test}, f"ssim{ssim_test}", f"lpips{lpips_test}") 376 | 377 | if tb_writer: 378 | # tb_writer.add_histogram(f'{dataset_name}/'+"scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 379 | tb_writer.add_scalar(f'{dataset_name}/'+'total_points', scene.gaussians.get_anchor.shape[0], iteration) 380 | torch.cuda.empty_cache() 381 | 382 | scene.gaussians.train() 383 | 384 | 385 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 386 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 387 | error_path = os.path.join(model_path, name, "ours_{}".format(iteration), "errors") 388 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 389 | 390 | makedirs(render_path, exist_ok=True) 391 | makedirs(error_path, exist_ok=True) 392 | makedirs(gts_path, exist_ok=True) 393 | 394 | t_list = [] 395 | visible_count_list = [] 396 | name_list = [] 397 | per_view_dict = {} 398 | psnr_list = [] 399 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 400 | 401 | torch.cuda.synchronize(); t_start = time.time() 402 | voxel_visible_mask = prefilter_voxel(view, gaussians, pipeline, background) 403 | render_pkg = render(view, gaussians, pipeline, background, visible_mask=voxel_visible_mask) 404 | torch.cuda.synchronize(); t_end = time.time() 405 | 406 | t_list.append(t_end - t_start) 407 | 408 | # renders 409 | rendering = torch.clamp(render_pkg["render"], 0.0, 1.0) 410 | visible_count = (render_pkg["radii"] > 0).sum() 411 | visible_count_list.append(visible_count) 412 | 413 | # gts 414 | gt = view.original_image[0:3, :, :] 415 | 416 | # 417 | gt_image = torch.clamp(view.original_image.to("cuda"), 0.0, 1.0) 418 | render_image = torch.clamp(rendering.to("cuda"), 0.0, 1.0) 419 | psnr_view = psnr(render_image, gt_image).mean().double() 420 | psnr_list.append(psnr_view) 421 | 422 | # error maps 423 | errormap = (rendering - gt).abs() 424 | 425 | 426 | name_list.append('{0:05d}'.format(idx) + ".png") 427 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 428 | torchvision.utils.save_image(errormap, os.path.join(error_path, '{0:05d}'.format(idx) + ".png")) 429 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 430 | per_view_dict['{0:05d}'.format(idx) + ".png"] = visible_count.item() 431 | 432 | with open(os.path.join(model_path, name, "ours_{}".format(iteration), "per_view_count.json"), 'w') as fp: 433 | json.dump(per_view_dict, fp, indent=True) 434 | 435 | print('testing_float_psnr=:', sum(psnr_list) / len(psnr_list)) 436 | 437 | return t_list, visible_count_list 438 | 439 | 440 | def render_sets(args_param, dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train=True, skip_test=False, wandb=None, tb_writer=None, dataset_name=None, logger=None, x_bound_min=None, x_bound_max=None): 441 | with torch.no_grad(): 442 | gaussians = GaussianModel( 443 | dataset.feat_dim, 444 | dataset.n_offsets, 445 | dataset.voxel_size, 446 | dataset.update_depth, 447 | dataset.update_init_factor, 448 | dataset.update_hierachy_factor, 449 | dataset.use_feat_bank, 450 | n_features_per_level=args_param.n_features, 451 | log2_hashmap_size=args_param.log2, 452 | log2_hashmap_size_2D=args_param.log2_2D, 453 | decoded_version=run_codec, 454 | ) 455 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 456 | gaussians.eval() 457 | if x_bound_min is not None: 458 | 
gaussians.x_bound_min = x_bound_min 459 | gaussians.x_bound_max = x_bound_max 460 | 461 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 462 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 463 | 464 | if not skip_train: 465 | t_train_list, _ = render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 466 | train_fps = 1.0 / torch.tensor(t_train_list[5:]).mean() 467 | logger.info(f'Train FPS: \033[1;35m{train_fps.item():.5f}\033[0m') 468 | if wandb is not None: 469 | wandb.log({"train_fps":train_fps.item(), }) 470 | 471 | if not skip_test: 472 | t_test_list, visible_count = render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 473 | test_fps = 1.0 / torch.tensor(t_test_list[5:]).mean() 474 | logger.info(f'Test FPS: \033[1;35m{test_fps.item():.5f}\033[0m') 475 | if tb_writer: 476 | tb_writer.add_scalar(f'{dataset_name}/test_FPS', test_fps.item(), 0) 477 | if wandb is not None: 478 | wandb.log({"test_fps":test_fps, }) 479 | 480 | return visible_count 481 | 482 | 483 | def readImages(renders_dir, gt_dir): 484 | renders = [] 485 | gts = [] 486 | image_names = [] 487 | for fname in os.listdir(renders_dir): 488 | render = Image.open(renders_dir / fname) 489 | gt = Image.open(gt_dir / fname) 490 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 491 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 492 | image_names.append(fname) 493 | return renders, gts, image_names 494 | 495 | 496 | def evaluate(model_paths, visible_count=None, wandb=None, tb_writer=None, dataset_name=None, logger=None): 497 | 498 | full_dict = {} 499 | per_view_dict = {} 500 | full_dict_polytopeonly = {} 501 | per_view_dict_polytopeonly = {} 502 | print("") 503 | 504 | scene_dir = model_paths 505 | full_dict[scene_dir] = {} 506 | per_view_dict[scene_dir] = {} 507 | full_dict_polytopeonly[scene_dir] = {} 508 | per_view_dict_polytopeonly[scene_dir] = {} 509 | 510 | test_dir = Path(scene_dir) / "test" 511 | 512 | for method in os.listdir(test_dir): 513 | 514 | full_dict[scene_dir][method] = {} 515 | per_view_dict[scene_dir][method] = {} 516 | full_dict_polytopeonly[scene_dir][method] = {} 517 | per_view_dict_polytopeonly[scene_dir][method] = {} 518 | 519 | method_dir = test_dir / method 520 | gt_dir = method_dir/ "gt" 521 | renders_dir = method_dir / "renders" 522 | renders, gts, image_names = readImages(renders_dir, gt_dir) 523 | 524 | ssims = [] 525 | psnrs = [] 526 | lpipss = [] 527 | 528 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 529 | ssims.append(ssim(renders[idx], gts[idx])) 530 | psnrs.append(psnr(renders[idx], gts[idx])) 531 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 532 | 533 | if wandb is not None: 534 | wandb.log({"test_SSIMS":torch.stack(ssims).mean().item(), }) 535 | wandb.log({"test_PSNR_final":torch.stack(psnrs).mean().item(), }) 536 | wandb.log({"test_LPIPS":torch.stack(lpipss).mean().item(), }) 537 | 538 | logger.info(f"model_paths: \033[1;35m{model_paths}\033[0m") 539 | logger.info(" SSIM : \033[1;35m{:>12.7f}\033[0m".format(torch.tensor(ssims).mean(), ".5")) 540 | logger.info(" PSNR : \033[1;35m{:>12.7f}\033[0m".format(torch.tensor(psnrs).mean(), ".5")) 541 | logger.info(" LPIPS: \033[1;35m{:>12.7f}\033[0m".format(torch.tensor(lpipss).mean(), ".5")) 542 | print("") 543 | 544 | 545 | if tb_writer: 546 | tb_writer.add_scalar(f'{dataset_name}/SSIM', 
torch.tensor(ssims).mean().item(), 0) 547 | tb_writer.add_scalar(f'{dataset_name}/PSNR', torch.tensor(psnrs).mean().item(), 0) 548 | tb_writer.add_scalar(f'{dataset_name}/LPIPS', torch.tensor(lpipss).mean().item(), 0) 549 | 550 | tb_writer.add_scalar(f'{dataset_name}/VISIBLE_NUMS', torch.tensor(visible_count).mean().item(), 0) 551 | 552 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 553 | "PSNR": torch.tensor(psnrs).mean().item(), 554 | "LPIPS": torch.tensor(lpipss).mean().item()}) 555 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 556 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 557 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}, 558 | "VISIBLE_COUNT": {name: vc for vc, name in zip(torch.tensor(visible_count).tolist(), image_names)}}) 559 | 560 | with open(scene_dir + "/results.json", 'w') as fp: 561 | json.dump(full_dict[scene_dir], fp, indent=True) 562 | with open(scene_dir + "/per_view.json", 'w') as fp: 563 | json.dump(per_view_dict[scene_dir], fp, indent=True) 564 | 565 | def get_logger(path): 566 | import logging 567 | 568 | logger = logging.getLogger() 569 | logger.setLevel(logging.INFO) 570 | fileinfo = logging.FileHandler(os.path.join(path, "outputs.log")) 571 | fileinfo.setLevel(logging.INFO) 572 | controlshow = logging.StreamHandler() 573 | controlshow.setLevel(logging.INFO) 574 | formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s") 575 | fileinfo.setFormatter(formatter) 576 | controlshow.setFormatter(formatter) 577 | 578 | logger.addHandler(fileinfo) 579 | logger.addHandler(controlshow) 580 | 581 | return logger 582 | 583 | if __name__ == "__main__": 584 | # Set up command line argument parser 585 | parser = ArgumentParser(description="Training script parameters") 586 | lp = ModelParams(parser) 587 | op = OptimizationParams(parser) 588 | pp = PipelineParams(parser) 589 | parser.add_argument('--ip', type=str, default="127.0.0.1") 590 | parser.add_argument('--port', type=int, default=6009) 591 | parser.add_argument('--debug_from', type=int, default=-1) 592 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 593 | parser.add_argument('--warmup', action='store_true', default=False) 594 | parser.add_argument('--use_wandb', action='store_true', default=False) 595 | # parser.add_argument("--test_iterations", nargs="+", type=int, default=[11_000, 15_000, 20_000, 25_000, 29_000, 30_000]) 596 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[30_000]) 597 | # parser.add_argument("--save_iterations", nargs="+", type=int, default=[11_000, 15_000, 20_000, 25_000, 29_000, 30_000]) 598 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[30_000]) 599 | parser.add_argument("--quiet", action="store_true") 600 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 601 | parser.add_argument("--start_checkpoint", type=str, default = None) 602 | parser.add_argument("--gpu", type=str, default = '-1') 603 | parser.add_argument("--log2", type=int, default = 13) 604 | parser.add_argument("--log2_2D", type=int, default = 15) 605 | parser.add_argument("--n_features", type=int, default = 4) 606 | parser.add_argument("--lmbda", type=float, default = 0.001) 607 | args = parser.parse_args(sys.argv[1:]) 608 | args.save_iterations.append(args.iterations) 609 | 610 | 611 | # enable logging 612 | 613 
| model_path = args.model_path 614 | os.makedirs(model_path, exist_ok=True) 615 | 616 | logger = get_logger(model_path) 617 | 618 | 619 | logger.info(f'args: {args}') 620 | 621 | if args.gpu != '-1': 622 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 623 | os.system("echo $CUDA_VISIBLE_DEVICES") 624 | logger.info(f'using GPU {args.gpu}') 625 | 626 | '''try: 627 | saveRuntimeCode(os.path.join(args.model_path, 'backup')) 628 | except: 629 | logger.info(f'save code failed~')''' 630 | 631 | dataset = args.source_path.split('/')[-1] 632 | exp_name = args.model_path.split('/')[-2] 633 | 634 | if args.use_wandb: 635 | wandb.login() 636 | run = wandb.init( 637 | # Set the project where this run will be logged 638 | project=f"Scaffold-GS-{dataset}", 639 | name=exp_name, 640 | # Track hyperparameters and run metadata 641 | settings=wandb.Settings(start_method="fork"), 642 | config=vars(args) 643 | ) 644 | else: 645 | wandb = None 646 | 647 | logger.info("Optimizing " + args.model_path) 648 | 649 | # Initialize system state (RNG) 650 | safe_state(args.quiet) 651 | 652 | # Start GUI server, configure and run training 653 | args.port = np.random.randint(10000, 20000) 654 | # network_gui.init(args.ip, args.port) 655 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 656 | 657 | # training 658 | x_bound_min, x_bound_max = training(args, lp.extract(args), op.extract(args), pp.extract(args), dataset, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, wandb, logger) 659 | if args.warmup: 660 | logger.info("\n Warmup finished! Reboot from last checkpoints") 661 | new_ply_path = os.path.join(args.model_path, f'point_cloud/iteration_{args.iterations}', 'point_cloud.ply') 662 | x_bound_min, x_bound_max = training(args, lp.extract(args), op.extract(args), pp.extract(args), dataset, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, wandb=wandb, logger=logger, ply_path=new_ply_path) 663 | 664 | # All done 665 | logger.info("\nTraining complete.") 666 | 667 | # rendering 668 | logger.info(f'\nStarting Rendering~') 669 | visible_count = render_sets(args, lp.extract(args), -1, pp.extract(args), wandb=wandb, logger=logger, x_bound_min=x_bound_min, x_bound_max=x_bound_max) 670 | logger.info("\nRendering complete.") 671 | 672 | # calc metrics 673 | logger.info("\n Starting evaluation...") 674 | evaluate(args.model_path, visible_count=visible_count, wandb=wandb, logger=logger) 675 | logger.info("\nEvaluating complete.") 676 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 24 | else: # should be a type that converts to float 25 | if args.resolution == -1: 26 | if orig_w > 1600: 27 | global WARNED 28 | if not WARNED: 29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 31 | WARNED = True 32 | global_down = orig_w / 1600 33 | else: 34 | global_down = 1 35 | else: 36 | global_down = orig_w / args.resolution 37 | 38 | scale = float(global_down) * float(resolution_scale) 39 | resolution = (int(orig_w / scale), int(orig_h / scale)) 40 | 41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 42 | 43 | gt_image = resized_image_rgb[:3, ...] 44 | loaded_mask = None 45 | 46 | # print(f'gt_image: {gt_image.shape}') 47 | if resized_image_rgb.shape[1] == 4: 48 | loaded_mask = resized_image_rgb[3:4, ...] 49 | 50 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 51 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 52 | image=gt_image, gt_alpha_mask=loaded_mask, 53 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 54 | 55 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 56 | camera_list = [] 57 | 58 | for id, c in enumerate(cam_infos): 59 | camera_list.append(loadCam(args, id, c, resolution_scale)) 60 | 61 | return camera_list 62 | 63 | def camera_to_JSON(id, camera : Camera): 64 | Rt = np.zeros((4, 4)) 65 | Rt[:3, :3] = camera.R.transpose() 66 | Rt[:3, 3] = camera.T 67 | Rt[3, 3] = 1.0 68 | 69 | W2C = np.linalg.inv(Rt) 70 | pos = W2C[:3, 3] 71 | rot = W2C[:3, :3] 72 | serializable_array_2d = [x.tolist() for x in rot] 73 | camera_entry = { 74 | 'id' : id, 75 | 'img_name' : camera.image_name, 76 | 'width' : camera.width, 77 | 'height' : camera.height, 78 | 'position': pos.tolist(), 79 | 'rotation': serializable_array_2d, 80 | 'fy' : fov2focal(camera.FovY, camera.height), 81 | 'fx' : fov2focal(camera.FovX, camera.width) 82 | } 83 | return camera_entry 84 | 85 | 86 | -------------------------------------------------------------------------------- /utils/encodings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.cuda.amp import custom_bwd, custom_fwd 5 | import numpy as np 6 | import math 7 | import multiprocessing 8 | 9 | import _gridencoder as _backend 10 | 11 | anchor_round_digits = 16 12 | Q_anchor = 1/(2 ** anchor_round_digits - 1) 13 | use_clamp = True 14 | use_multiprocessor = False # Always False plz. Not yet implemented for True. 
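# Note added for clarity (not in the original source): with anchor_round_digits = 16,
# Q_anchor = 1 / (2**16 - 1) is the normalized step used by Quantize_anchor further below.
# Illustrative sketch of that quantization, assuming an anchor tensor spanning [min_v, max_v]:
#     interval = (max_v - min_v) * Q_anchor + 1e-6      # avoid 0 when max_v == min_v
#     code     = floor((anchor - min_v) / interval)     # clamped to [0, 2**16 - 1]
#     anchor_q = code * interval + min_v                # dequantized anchor
# so each anchor coordinate is represented by a 16-bit code before entropy coding.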
15 | 16 | def get_binary_vxl_size(binary_vxl): 17 | # binary_vxl: {0, 1} 18 | # assert torch.unique(binary_vxl).mean() == 0.5 19 | ttl_num = binary_vxl.numel() 20 | 21 | pos_num = torch.sum(binary_vxl) 22 | neg_num = ttl_num - pos_num 23 | 24 | Pg = pos_num / ttl_num # + 1e-6 25 | Pg = torch.clamp(Pg, min=1e-6, max=1-1e-6) 26 | pos_prob = Pg 27 | neg_prob = (1 - Pg) 28 | pos_bit = pos_num * (-torch.log2(pos_prob)) 29 | neg_bit = neg_num * (-torch.log2(neg_prob)) 30 | ttl_bit = pos_bit + neg_bit 31 | ttl_bit += 32 # Pg 32 | # print('binary_vxl:', Pg.item(), ttl_bit.item(), ttl_num, pos_num.item(), neg_num.item()) 33 | return Pg, ttl_bit, ttl_bit.item()/8.0/1024/1024, ttl_num 34 | 35 | class STE_binary(torch.autograd.Function): 36 | @staticmethod 37 | def forward(ctx, input): 38 | ctx.save_for_backward(input) 39 | input = torch.clamp(input, min=-1, max=1) 40 | # out = torch.sign(input) 41 | p = (input >= 0) * (+1.0) 42 | n = (input < 0) * (-1.0) 43 | out = p + n 44 | return out 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | # mask: to ensure x belongs to (-1, 1) 48 | input, = ctx.saved_tensors 49 | i2 = input.clone().detach() 50 | i3 = torch.clamp(i2, -1, 1) 51 | mask = (i3 == i2) + 0.0 52 | return grad_output * mask 53 | 54 | 55 | class STE_multistep(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, input, Q, input_mean=None): 58 | if use_clamp: 59 | if input_mean is None: 60 | input_mean = input.mean() 61 | input_min = input_mean - 15_000 * Q 62 | input_max = input_mean + 15_000 * Q 63 | input = torch.clamp(input, min=input_min.detach(), max=input_max.detach()) 64 | 65 | Q_round = torch.round(input / Q) 66 | Q_q = Q_round * Q 67 | return Q_q 68 | @staticmethod 69 | def backward(ctx, grad_output): 70 | return grad_output, None 71 | 72 | 73 | class Quantize_anchor(torch.autograd.Function): 74 | @staticmethod 75 | def forward(ctx, anchors, min_v, max_v): 76 | # if anchor_round_digits == 32: 77 | # return anchors 78 | # min_v = torch.min(anchors).detach() 79 | # max_v = torch.max(anchors).detach() 80 | # scales = 2 ** anchor_round_digits - 1 81 | interval = ((max_v - min_v) * Q_anchor + 1e-6) # avoid 0, if max_v == min_v 82 | # quantized_v = (anchors - min_v) // interval 83 | quantized_v = torch.div(anchors - min_v, interval, rounding_mode='floor') 84 | quantized_v = torch.clamp(quantized_v, 0, 2 ** anchor_round_digits - 1) 85 | anchors_q = quantized_v * interval + min_v 86 | return anchors_q, quantized_v 87 | @staticmethod 88 | def backward(ctx, grad_output, tmp): # tmp is for quantized_v:) 89 | return grad_output, None, None 90 | 91 | 92 | class _grid_encode(Function): 93 | @staticmethod 94 | @custom_fwd 95 | def forward(ctx, inputs, embeddings, offsets_list, resolutions_list, calc_grad_inputs=False, min_level_id=None, n_levels_calc=1, binary_vxl=None, PV=0): 96 | # inputs: [N, num_dim], float in [0, 1] 97 | # embeddings: [sO, n_features], float. 
self.params = nn.Parameter(torch.empty(offset, n_features)) 98 | # offsets_list: [n_levels + 1], int 99 | # RETURN: [N, F], float 100 | inputs = inputs.contiguous() 101 | # embeddings_mask = torch.ones(size=[embeddings.shape[0]], dtype=torch.bool, device='cuda') 102 | # print('kkkkkkkkkk---000000000:', embeddings_mask.shape, embeddings.shape, embeddings_mask.sum()) 103 | 104 | Rb = 128 105 | if binary_vxl is not None: 106 | binary_vxl = binary_vxl.contiguous() 107 | Rb = binary_vxl.shape[-1] 108 | assert len(binary_vxl.shape) == inputs.shape[-1] 109 | 110 | N, num_dim = inputs.shape # batch size, coord dim # N_rays, 3 111 | n_levels = offsets_list.shape[0] - 1 # level # 层数=16 112 | n_features = embeddings.shape[1] # embedding dim for each level # 就是channel数=2 113 | 114 | max_level_id = min_level_id + n_levels_calc 115 | 116 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 117 | # if n_features % 2 != 0, force float, since half for atomicAdd is very slow. 118 | if torch.is_autocast_enabled() and n_features % 2 == 0: 119 | embeddings = embeddings.to(torch.half) 120 | 121 | # n_levels first, optimize cache for cuda kernel, but needs an extra permute later 122 | outputs = torch.empty(n_levels_calc, N, n_features, device=inputs.device, dtype=embeddings.dtype) # 创建一个buffer给cuda填充 123 | # outputs = [hash层数=16, N_rays, channels=2] 124 | 125 | # zero init if we only calculate partial levels 126 | # if n_levels_calc < n_levels: outputs.zero_() 127 | if calc_grad_inputs: # inputs.requires_grad 128 | dy_dx = torch.empty(N, n_levels_calc * num_dim * n_features, device=inputs.device, dtype=embeddings.dtype) 129 | else: 130 | dy_dx = None 131 | 132 | # assert embeddings.shape[0] == embeddings_mask.shape[0] 133 | # assert embeddings_mask.dtype == torch.bool 134 | 135 | if isinstance(min_level_id, int): 136 | _backend.grid_encode_forward( 137 | inputs, 138 | embeddings, 139 | # embeddings_mask, 140 | offsets_list[min_level_id:max_level_id+1], 141 | resolutions_list[min_level_id:max_level_id], 142 | outputs, 143 | N, num_dim, n_features, n_levels_calc, 0, Rb, PV, 144 | dy_dx, 145 | binary_vxl, 146 | None 147 | ) 148 | else: 149 | _backend.grid_encode_forward( 150 | inputs, 151 | embeddings, 152 | # embeddings_mask, 153 | offsets_list, 154 | resolutions_list, 155 | outputs, 156 | N, num_dim, n_features, n_levels_calc, 0, Rb, PV, 157 | dy_dx, 158 | binary_vxl, 159 | min_level_id 160 | ) 161 | 162 | # permute back to [N, n_levels * n_features] # [N_rays, hash层数=16 * channels=2] 163 | outputs = outputs.permute(1, 0, 2).reshape(N, n_levels_calc * n_features) 164 | 165 | ctx.save_for_backward(inputs, embeddings, offsets_list, resolutions_list, dy_dx, binary_vxl) 166 | ctx.dims = [N, num_dim, n_features, n_levels_calc, min_level_id, max_level_id, Rb, PV] # min_level_id是否要单独save为tensor 167 | 168 | return outputs 169 | 170 | @staticmethod 171 | #@once_differentiable 172 | @custom_bwd 173 | def backward(ctx, grad): 174 | 175 | inputs, embeddings, offsets_list, resolutions_list, dy_dx, binary_vxl = ctx.saved_tensors 176 | N, num_dim, n_features, n_levels_calc, min_level_id, max_level_id, Rb, PV = ctx.dims 177 | 178 | # grad: [N, n_levels * n_features] --> [n_levels, N, n_features] 179 | grad = grad.view(N, n_levels_calc, n_features).permute(1, 0, 2).contiguous() 180 | 181 | grad_embeddings = torch.zeros_like(embeddings) 182 | 183 | if dy_dx is not None: 184 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 185 | else: 186 | grad_inputs = None 
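        # Comment added for clarity (not in the original source): the CUDA kernel call below
        # scatters gradients into grad_embeddings and, when dy_dx was produced in the forward
        # pass, chains them into grad_inputs. The two branches differ only in whether the level
        # range is sliced here (min_level_id given as a Python int) or handed to the kernel as
        # a tensor.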
187 | 188 | if isinstance(min_level_id, int): 189 | _backend.grid_encode_backward( 190 | grad, 191 | inputs, 192 | embeddings, 193 | # embeddings_mask, 194 | offsets_list[min_level_id:max_level_id+1], 195 | resolutions_list[min_level_id:max_level_id], 196 | grad_embeddings, 197 | N, num_dim, n_features, n_levels_calc, 0, Rb, 198 | dy_dx, 199 | grad_inputs, 200 | binary_vxl, 201 | None 202 | ) 203 | else: 204 | _backend.grid_encode_backward( 205 | grad, 206 | inputs, 207 | embeddings, 208 | # embeddings_mask, 209 | offsets_list, 210 | resolutions_list, 211 | grad_embeddings, 212 | N, num_dim, n_features, n_levels_calc, 0, Rb, 213 | dy_dx, 214 | grad_inputs, 215 | binary_vxl, 216 | min_level_id 217 | ) 218 | 219 | if dy_dx is not None: 220 | grad_inputs = grad_inputs.to(inputs.dtype) 221 | 222 | return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None 223 | grid_encode = _grid_encode.apply 224 | class GridEncoder(nn.Module): 225 | def __init__(self, 226 | num_dim=3, 227 | n_features=2, 228 | resolutions_list=(16, 23, 32, 46, 64, 92, 128, 184, 256, 368, 512, 736), 229 | log2_hashmap_size=19, 230 | ste_binary = True, 231 | ste_multistep = False, 232 | add_noise = False, 233 | Q = 1 234 | ): 235 | super().__init__() 236 | 237 | resolutions_list = torch.tensor(resolutions_list).to(torch.int) 238 | n_levels = resolutions_list.numel() 239 | 240 | self.num_dim = num_dim # coord dims, 2 or 3 241 | self.n_levels = n_levels # num levels, each level multiply resolution by 2 242 | self.n_features = n_features # encode channels per level 243 | self.log2_hashmap_size = log2_hashmap_size 244 | self.output_dim = n_levels * n_features 245 | self.ste_binary = ste_binary 246 | self.ste_multistep = ste_multistep 247 | self.add_noise = add_noise 248 | self.Q = Q 249 | 250 | # allocate parameters 251 | offsets_list = [] 252 | offset = 0 253 | self.max_params = 2 ** log2_hashmap_size 254 | for i in range(n_levels): 255 | resolution = resolutions_list[i].item() 256 | params_in_level = min(self.max_params, resolution ** num_dim) # limit max number 257 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 258 | offsets_list.append(offset) 259 | offset += params_in_level 260 | offsets_list.append(offset) 261 | offsets_list = torch.from_numpy(np.array(offsets_list, dtype=np.int32)) 262 | self.register_buffer('offsets_list', offsets_list) 263 | self.register_buffer('resolutions_list', resolutions_list) 264 | 265 | self.n_params = offsets_list[-1] * n_features 266 | 267 | # parameters 268 | self.params = nn.Parameter(torch.empty(offset, n_features)) 269 | 270 | self.reset_parameters() 271 | 272 | self.n_output_dims = n_levels * n_features 273 | 274 | def reset_parameters(self): 275 | std = 1e-4 276 | self.params.data.uniform_(-std, std) 277 | 278 | def __repr__(self): 279 | return f"GridEncoder: num_dim={self.num_dim} n_levels={self.n_levels} n_features={self.n_features} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.n_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.params.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" 280 | 281 | def forward(self, inputs, min_level_id=None, max_level_id=None, test_phase=False, outspace_params=None, binary_vxl=None, PV=0): 282 | # inputs: [..., num_dim], normalized real world positions in [0, 1] 283 | # return: [..., n_levels * n_features] 284 | 285 | prefix_shape = list(inputs.shape[:-1]) 286 | inputs 
= inputs.view(-1, self.num_dim) 287 | 288 | if outspace_params is not None: 289 | params = nn.Parameter(outspace_params) 290 | else: 291 | params = self.params 292 | 293 | if self.ste_binary: 294 | embeddings = STE_binary.apply(params) 295 | # embeddings = params 296 | elif (self.add_noise and not test_phase): 297 | embeddings = params + (torch.rand_like(params) - 0.5) * (1 / self.Q) 298 | elif (self.ste_multistep) or (self.add_noise and test_phase): 299 | embeddings = STE_multistep.apply(params, self.Q) 300 | else: 301 | embeddings = params 302 | # embeddings = embeddings * 0 # for ablation 303 | 304 | min_level_id = 0 if min_level_id is None else max(min_level_id, 0) 305 | max_level_id = self.n_levels if max_level_id is None else min(max_level_id, self.n_levels) 306 | n_levels_calc = max_level_id - min_level_id 307 | 308 | outputs = grid_encode(inputs, embeddings, self.offsets_list, self.resolutions_list, inputs.requires_grad, min_level_id, n_levels_calc, binary_vxl, PV) 309 | outputs = outputs.view(prefix_shape + [n_levels_calc * self.n_features]) 310 | 311 | return outputs 312 | -------------------------------------------------------------------------------- /utils/encodings_cuda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import arithmetic 5 | 6 | chunk_size_cuda = 10000 7 | 8 | class STE_multistep(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, input, Q, input_mean=None): 11 | Q_round = torch.round(input / Q) 12 | Q_q = Q_round * Q 13 | return Q_q 14 | @staticmethod 15 | def backward(ctx, grad_output): 16 | return grad_output, None 17 | 18 | def get_binary_vxl_size(binary_vxl): 19 | # binary_vxl: {0, 1} 20 | # assert torch.unique(binary_vxl).mean() == 0.5 21 | ttl_num = binary_vxl.numel() 22 | 23 | pos_num = torch.sum(binary_vxl) 24 | neg_num = ttl_num - pos_num 25 | 26 | Pg = pos_num / ttl_num # + 1e-6 27 | Pg = torch.clamp(Pg, min=1e-6, max=1-1e-6) 28 | pos_prob = Pg 29 | neg_prob = (1 - Pg) 30 | pos_bit = pos_num * (-torch.log2(pos_prob)) 31 | neg_bit = neg_num * (-torch.log2(neg_prob)) 32 | ttl_bit = pos_bit + neg_bit 33 | ttl_bit += 32 # Pg 34 | # print('binary_vxl:', Pg.item(), ttl_bit.item(), ttl_num, pos_num.item(), neg_num.item()) 35 | return Pg, ttl_bit, ttl_bit.item()/8.0/1024/1024, ttl_num 36 | 37 | 38 | def encoder_factorized_chunk(x, lower_func, Q:float = 1, file_name='tmp.b', chunk_size=1000_0000): 39 | # should be with 2 dimensions 40 | # lower_func: xxx._logits_cumulative 41 | assert file_name.endswith('.b') 42 | assert len(x.shape) == 2 43 | N = x.shape[0] 44 | chunks = int(np.ceil(N / chunk_size)) 45 | bit_len_list = [] 46 | for c in range(chunks): 47 | bit_len = encoder_factorized( 48 | x=x[c * chunk_size:c * chunk_size + chunk_size], 49 | lower_func=lower_func, 50 | Q=Q, 51 | file_name=file_name.replace('.b', f'_{str(c)}.b'), 52 | ) 53 | bit_len_list.append(bit_len) 54 | return sum(bit_len_list) 55 | 56 | 57 | def encoder_factorized(x, lower_func, Q:float = 1, file_name='tmp.b'): 58 | ''' 59 | The reason why int(max_value.item()) + 1 or int(max_value.item()) + 1 + 1: 60 | first 1: range does not include the last value, so +1 61 | second 1: if directly calculate, we need to use samples - 0.5, in order to include the whole value space, 62 | the max bound value after -0.5 should be max_value+0.5. 
63 | 64 | Here we do not add the second 1, because we use pmf to calculate cdf, instead of directly calculate cdf 65 | 66 | example in here ("`" means sample-0.5 places, "|" means sample places): 67 | ` ` ` ` ` ` ` ` ` ` ` ` ` 68 | lkl_lower | | | | lkl_upper | | | | -> cdf_lower | | | | 69 | 70 | example in other place ("`" means sample-0.5 places, "|" means sample places): 71 | ` ` ` ` ` 72 | cdf_lower | | | | 73 | 74 | ''' 75 | # should be with 2 dimensions 76 | # lower_func: xxx._logits_cumulative 77 | assert file_name.endswith('.b') 78 | assert len(x.shape) == 2 79 | x_int_round = torch.round(x / Q) # [100] 80 | max_value = x_int_round.max() 81 | min_value = x_int_round.min() 82 | samples = torch.tensor(range(int(min_value.item()), int(max_value.item()) + 1)).to( 83 | torch.float).to(x.device) # from min_value to max_value+1. shape = [max_value+1 - min_value] 84 | samples = samples.unsqueeze(0).unsqueeze(0).repeat(x.shape[-1], 1, 1) # [256, 1, max_value+1 - min_value] 85 | # lower_func: [C, 1, N] 86 | lower = lower_func((samples - 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value] 87 | upper = lower_func((samples + 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value] 88 | sign = -torch.sign(torch.add(lower, upper)) 89 | sign = sign.detach() 90 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) # [256, 1, max_value+1 - min_value] 91 | cdf = torch.cumsum(pmf, dim=-1) 92 | lower = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [256, 1, max_value+1+1 - min_value] 93 | lower = lower.permute(1, 0, 2).contiguous().repeat(x.shape[0], 1, 1) # [100, 256, max_value+1+1 - min_value] 94 | x_int_round_idx = (x_int_round - min_value).to(torch.int16) 95 | x_int_round_idx = x_int_round_idx.view(-1) # [100*256] 96 | lower = lower.view(x_int_round_idx.shape[0], -1) # [100*256, max_value+1+1 - min_value] 97 | lower = torch.clamp(lower, min=0.0, max=1.0) 98 | assert (x_int_round_idx.to(torch.int32) == x_int_round.view(-1) - min_value).all() 99 | 100 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode( 101 | x_int_round_idx, 102 | lower, 103 | chunk_size_cuda, 104 | int(lower.shape[0]), 105 | int(lower.shape[1]) 106 | ) 107 | cnt_bytes = cnt_torch.cpu().numpy().tobytes() 108 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes() 109 | len_cnt_bytes = len(cnt_bytes) 110 | with open(file_name, 'wb') as fout: 111 | fout.write(min_value.to(torch.float32).cpu().numpy().tobytes()) 112 | fout.write(max_value.to(torch.float32).cpu().numpy().tobytes()) 113 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes()) 114 | fout.write(cnt_bytes) 115 | fout.write(byte_stream_bytes) 116 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes)) * 8 + 32 * 3 117 | return bit_len 118 | 119 | def decoder_factorized_chunk(lower_func, Q, N_len, dim, file_name='tmp.b', device='cuda', chunk_size=1000_0000): 120 | assert file_name.endswith('.b') 121 | chunks = int(np.ceil(N_len / chunk_size)) 122 | x_c_list = [] 123 | for c in range(chunks): 124 | x_c = decoder_factorized( 125 | lower_func=lower_func, 126 | Q=Q, 127 | N_len=min(chunk_size, N_len-c*chunk_size), 128 | dim=dim, 129 | file_name=file_name.replace('.b', f'_{str(c)}.b'), 130 | device=device, 131 | ) 132 | x_c_list.append(x_c) 133 | x_c_list = torch.cat(x_c_list, dim=0) 134 | return x_c_list 135 | 136 | 137 | def decoder_factorized(lower_func, Q, N_len, dim, file_name='tmp.b', device='cuda'): 138 | assert file_name.endswith('.b') 139 | 140 | with open(file_name, 'rb') as fin: 
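        # Comment added for clarity (not in the original source): the bitstream layout matches
        # what encoder_factorized writes above -- float32 min_value, float32 max_value, an
        # int32 byte length for the cnt array, the cnt bytes, and finally the arithmetic-coded
        # byte stream.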
141 | min_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda") 142 | max_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda") 143 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0] 144 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda") 145 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda") 146 | 147 | samples = torch.tensor(range(int(min_value.item()), int(max_value.item()) + 1)).to( 148 | torch.float).to(device) # from min_value to max_value+1. shape = [max_value+1 - min_value] 149 | samples = samples.unsqueeze(0).unsqueeze(0).repeat(dim, 1, 1) # [256, 1, max_value+1 - min_value] 150 | 151 | # lower_func: [C, 1, N] 152 | lower = lower_func((samples - 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value] 153 | upper = lower_func((samples + 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value] 154 | sign = -torch.sign(torch.add(lower, upper)) 155 | sign = sign.detach() 156 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) # [256, 1, max_value+1 - min_value] 157 | cdf = torch.cumsum(pmf, dim=-1) 158 | lower = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [256, 1, max_value+1+1 - min_value] 159 | lower = lower.permute(1, 0, 2).contiguous().repeat(N_len, 1, 1) # [100, 256, max_value+1+1 - min_value] 160 | lower = lower.view(N_len*dim, -1) # [100*256, max_value+1+1 - min_value] 161 | lower = torch.clamp(lower, min=0.0, max=1.0) 162 | 163 | sym_out = arithmetic.arithmetic_decode( 164 | lower, 165 | byte_stream_torch, 166 | cnt_torch, 167 | chunk_size_cuda, 168 | int(lower.shape[0]), 169 | int(lower.shape[1]) 170 | ).to(device).to(torch.float32) 171 | x = sym_out + min_value 172 | x = x * Q 173 | x = x.reshape(N_len, dim) 174 | return x 175 | 176 | 177 | def encoder_gaussian_mixed_chunk(x, mean_list, scale_list, prob_list, Q, file_name='tmp.b', chunk_size=1000_0000): 178 | assert file_name.endswith('.b') 179 | assert len(x.shape) == 1 180 | x_view = x.view(-1) 181 | mean_list_view = [mean.view(-1) for mean in mean_list] 182 | scale_list_view = [scale.view(-1) for scale in scale_list] 183 | prob_list_view = [prob.view(-1) for prob in prob_list] 184 | assert x_view.shape[0]==mean_list_view[0].shape[0]==scale_list_view[0].shape[0]==prob_list_view[0].shape[0] 185 | N = x_view.shape[0] 186 | chunks = int(np.ceil(N/chunk_size)) 187 | Is_Q_tensor = isinstance(Q, torch.Tensor) 188 | if Is_Q_tensor: Q_view = Q.view(-1) 189 | bit_len_list = [] 190 | for c in range(chunks): 191 | bit_len = encoder_gaussian_mixed( 192 | x=x_view[c*chunk_size:c*chunk_size + chunk_size], 193 | mean_list=[mean[c*chunk_size:c*chunk_size + chunk_size] for mean in mean_list_view], 194 | scale_list=[scale[c*chunk_size:c*chunk_size + chunk_size] for scale in scale_list_view], 195 | prob_list=[prob[c*chunk_size:c*chunk_size + chunk_size] for prob in prob_list_view], 196 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q, 197 | file_name=file_name.replace('.b', f'_{str(c)}.b'), 198 | ) 199 | bit_len_list.append(bit_len) 200 | return sum(bit_len_list) 201 | 202 | 203 | def encoder_gaussian_mixed(x, mean_list, scale_list, prob_list, Q, file_name='tmp.b'): 204 | # should be with single dimension 205 | assert file_name.endswith('.b') 206 | assert len(x.shape) == 1 207 | if not isinstance(Q, torch.Tensor): 208 | Q = torch.tensor([Q], dtype=x.dtype, 
device=x.device).repeat(x.shape[0]) 209 | assert x.shape == mean_list[0].shape == scale_list[0].shape == prob_list[0].shape == Q.shape, f'{x.shape}, {mean_list[0].shape}, {scale_list[0].shape}, {prob_list[0].shape}, {Q.shape}' 210 | x_int_round = torch.round(x / Q) # [100] 211 | max_value = x_int_round.max() 212 | min_value = x_int_round.min() 213 | lower_all = int(0) 214 | for (mean, scale, prob) in zip(mean_list, scale_list, prob_list): 215 | lower = arithmetic.calculate_cdf( 216 | mean, 217 | scale, 218 | Q, 219 | min_value, 220 | max_value 221 | ) * prob.unsqueeze(-1) 222 | if isinstance(lower_all, int): 223 | lower_all = lower 224 | else: 225 | lower_all += lower 226 | lower = torch.clamp(lower_all, min=0.0, max=1.0) 227 | del mean 228 | del scale 229 | del prob 230 | 231 | x_int_round_idx = (x_int_round - min_value).to(torch.int16) 232 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode( 233 | x_int_round_idx, 234 | lower, 235 | chunk_size_cuda, 236 | int(lower.shape[0]), 237 | int(lower.shape[1]) 238 | ) 239 | cnt_bytes = cnt_torch.cpu().numpy().tobytes() 240 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes() 241 | len_cnt_bytes = len(cnt_bytes) 242 | with open(file_name, 'wb') as fout: 243 | fout.write(min_value.to(torch.float32).cpu().numpy().tobytes()) 244 | fout.write(max_value.to(torch.float32).cpu().numpy().tobytes()) 245 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes()) 246 | fout.write(cnt_bytes) 247 | fout.write(byte_stream_bytes) 248 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes))*8 + 32 * 3 249 | return bit_len 250 | 251 | 252 | def decoder_gaussian_mixed_chunk(mean_list, scale_list, prob_list, Q, file_name='tmp.b', chunk_size=1000_0000): 253 | assert file_name.endswith('.b') 254 | mean_list_view = [mean.view(-1) for mean in mean_list] 255 | scale_list_view = [scale.view(-1) for scale in scale_list] 256 | prob_list_view = [prob.view(-1) for prob in prob_list] 257 | N = mean_list_view[0].shape[0] 258 | chunks = int(np.ceil(N/chunk_size)) 259 | Is_Q_tensor = isinstance(Q, torch.Tensor) 260 | if Is_Q_tensor: Q_view = Q.view(-1) 261 | x_c_list = [] 262 | for c in range(chunks): 263 | x_c = decoder_gaussian_mixed( 264 | mean_list=[mean[c*chunk_size:c*chunk_size + chunk_size] for mean in mean_list_view], 265 | scale_list=[scale[c*chunk_size:c*chunk_size + chunk_size] for scale in scale_list_view], 266 | prob_list=[prob[c*chunk_size:c*chunk_size + chunk_size] for prob in prob_list_view], 267 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q, 268 | file_name=file_name.replace('.b', f'_{str(c)}.b'), 269 | ) 270 | x_c_list.append(x_c) 271 | x_c_list = torch.cat(x_c_list, dim=0).type_as(mean_list[0]) 272 | return x_c_list 273 | 274 | 275 | def decoder_gaussian_mixed(mean_list, scale_list, prob_list, Q, file_name='tmp.b'): 276 | assert file_name.endswith('.b') 277 | m0 = mean_list[0] 278 | if not isinstance(Q, torch.Tensor): 279 | Q = torch.tensor([Q], dtype=m0.dtype, device=m0.device).repeat(m0.shape[0]) 280 | assert mean_list[0].shape == scale_list[0].shape == prob_list[0].shape == Q.shape 281 | 282 | with open(file_name, 'rb') as fin: 283 | min_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda") 284 | max_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda") 285 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0] 286 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda") 287 | 
byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda") 288 | 289 | lower_all = int(0) 290 | for (mean, scale, prob) in zip(mean_list, scale_list, prob_list): 291 | lower = arithmetic.calculate_cdf( 292 | mean, 293 | scale, 294 | Q, 295 | min_value, 296 | max_value 297 | ) * prob.unsqueeze(-1) 298 | if isinstance(lower_all, int): 299 | lower_all = lower 300 | else: 301 | lower_all += lower 302 | lower = torch.clamp(lower_all, min=0.0, max=1.0) 303 | 304 | sym_out = arithmetic.arithmetic_decode( 305 | lower, 306 | byte_stream_torch, 307 | cnt_torch, 308 | chunk_size_cuda, 309 | int(lower.shape[0]), 310 | int(lower.shape[1]) 311 | ).to(mean.device).to(torch.float32) 312 | x = sym_out + min_value 313 | x = x * Q 314 | return x 315 | 316 | 317 | def encoder_gaussian_chunk(x, mean, scale, Q, file_name='tmp.b', chunk_size=1000_0000): 318 | assert file_name.endswith('.b') 319 | assert len(x.shape) == 1 320 | x_view = x.view(-1) 321 | mean_view = mean.view(-1) 322 | scale_view = scale.view(-1) 323 | N = x_view.shape[0] 324 | chunks = int(np.ceil(N/chunk_size)) 325 | Is_Q_tensor = isinstance(Q, torch.Tensor) 326 | if Is_Q_tensor: Q_view = Q.view(-1) 327 | bit_len_list = [] 328 | for c in range(chunks): 329 | bit_len = encoder_gaussian( 330 | x=x_view[c*chunk_size:c*chunk_size + chunk_size], 331 | mean=mean_view[c*chunk_size:c*chunk_size + chunk_size], 332 | scale=scale_view[c*chunk_size:c*chunk_size + chunk_size], 333 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q, 334 | file_name=file_name.replace('.b', f'_{str(c)}.b'), 335 | ) 336 | bit_len_list.append(bit_len) 337 | return sum(bit_len_list) 338 | 339 | 340 | def encoder_gaussian(x, mean, scale, Q, file_name='tmp.b'): 341 | # should be single dimension 342 | assert file_name.endswith('.b') 343 | assert len(x.shape) == 1 344 | if not isinstance(Q, torch.Tensor): 345 | Q = torch.tensor([Q], dtype=mean.dtype, device=mean.device).repeat(mean.shape[0]) 346 | x_int_round = torch.round(x / Q) # [100] 347 | max_value = x_int_round.max() 348 | min_value = x_int_round.min() 349 | 350 | lower = arithmetic.calculate_cdf( 351 | mean, 352 | scale, 353 | Q, 354 | min_value, 355 | max_value 356 | ) 357 | 358 | x_int_round_idx = (x_int_round - min_value).to(torch.int16) 359 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode( 360 | x_int_round_idx, 361 | lower, 362 | chunk_size_cuda, 363 | int(lower.shape[0]), 364 | int(lower.shape[1]) 365 | ) 366 | cnt_bytes = cnt_torch.cpu().numpy().tobytes() 367 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes() 368 | len_cnt_bytes = len(cnt_bytes) 369 | with open(file_name, 'wb') as fout: 370 | fout.write(min_value.to(torch.float32).cpu().numpy().tobytes()) 371 | fout.write(max_value.to(torch.float32).cpu().numpy().tobytes()) 372 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes()) 373 | fout.write(cnt_bytes) 374 | fout.write(byte_stream_bytes) 375 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes))*8 + 32 * 3 376 | return bit_len 377 | 378 | def decoder_gaussian_chunk(mean, scale, Q, file_name='tmp.b', chunk_size=1000_0000): 379 | assert file_name.endswith('.b') 380 | mean_view = mean.view(-1) 381 | scale_view = scale.view(-1) 382 | N = mean_view.shape[0] 383 | chunks = int(np.ceil(N/chunk_size)) 384 | Is_Q_tensor = isinstance(Q, torch.Tensor) 385 | if Is_Q_tensor: Q_view = Q.view(-1) 386 | x_c_list = [] 387 | for c in range(chunks): 388 | x_c = decoder_gaussian( 389 | mean=mean_view[c*chunk_size:c*chunk_size + 
chunk_size], 390 | scale=scale_view[c*chunk_size:c*chunk_size + chunk_size], 391 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q, 392 | file_name=file_name.replace('.b', f'_{str(c)}.b'), 393 | ) 394 | x_c_list.append(x_c) 395 | x_c_list = torch.cat(x_c_list, dim=0).type_as(mean) 396 | return x_c_list 397 | 398 | 399 | def decoder_gaussian(mean, scale, Q, file_name='tmp.b'): 400 | # should be single dimension 401 | assert file_name.endswith('.b') 402 | assert len(mean.shape) == 1 403 | assert mean.shape == scale.shape 404 | if not isinstance(Q, torch.Tensor): 405 | Q = torch.tensor([Q], dtype=mean.dtype, device=mean.device).repeat(mean.shape[0]) 406 | 407 | with open(file_name, 'rb') as fin: 408 | min_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda") 409 | max_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda") 410 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0] 411 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda") 412 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda") 413 | 414 | lower = arithmetic.calculate_cdf( 415 | mean, 416 | scale, 417 | Q, 418 | min_value, 419 | max_value 420 | ) 421 | 422 | sym_out = arithmetic.arithmetic_decode( 423 | lower, 424 | byte_stream_torch, 425 | cnt_torch, 426 | chunk_size_cuda, 427 | int(lower.shape[0]), 428 | int(lower.shape[1]) 429 | ).to(mean.device).to(torch.float32) 430 | x = sym_out + min_value 431 | x = x * Q 432 | return x 433 | 434 | 435 | def encoder(x, file_name='tmp.b'): 436 | # x: 0 or 1 437 | assert file_name[-2:] == '.b' 438 | x = x.detach().view(-1) 439 | p = torch.zeros_like(x).to(torch.float32) 440 | prob_1 = x.sum() / x.numel() 441 | p[...] = prob_1 442 | p_u = 1 - p.unsqueeze(-1) 443 | p_0 = torch.zeros_like(p_u) 444 | p_1 = torch.ones_like(p_u) 445 | # Encode to bytestream. 446 | output_cdf = torch.cat([p_0, p_u, p_1], dim=-1) 447 | sym = torch.floor(x).to(torch.int16) 448 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode( 449 | sym, 450 | output_cdf, 451 | chunk_size_cuda, 452 | int(output_cdf.shape[0]), 453 | int(output_cdf.shape[1]) 454 | ) 455 | cnt_bytes = cnt_torch.cpu().numpy().tobytes() 456 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes() 457 | len_cnt_bytes = len(cnt_bytes) 458 | with open(file_name, 'wb') as fout: 459 | fout.write(prob_1.to(torch.float32).cpu().numpy().tobytes()) 460 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes()) 461 | fout.write(cnt_bytes) 462 | fout.write(byte_stream_bytes) 463 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes)) * 8 + 32 * 2 464 | return bit_len 465 | 466 | 467 | def decoder(N_len, file_name='tmp.b', device='cuda'): 468 | assert file_name[-2:] == '.b' 469 | 470 | with open(file_name, 'rb') as fin: 471 | prob_1 = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy()) 472 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0] 473 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda") 474 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda") 475 | p = torch.zeros(size=[N_len], dtype=torch.float32, device="cuda") 476 | p[...] = prob_1 477 | p_u = 1 - p.unsqueeze(-1) 478 | p_0 = torch.zeros_like(p_u) 479 | p_1 = torch.ones_like(p_u) 480 | # Encode to bytestream. 
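        # (The comment above is carried over from encoder(); here the same three-entry CDF row
        # [0, 1 - prob_1, 1] is rebuilt per symbol so arithmetic_decode can invert the binary
        # stream.)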
481 | output_cdf = torch.cat([p_0, p_u, p_1], dim=-1) 482 | # Read from a file. 483 | # Decode from bytestream. 484 | sym_out = arithmetic.arithmetic_decode( 485 | output_cdf, 486 | byte_stream_torch, 487 | cnt_torch, 488 | chunk_size_cuda, 489 | int(output_cdf.shape[0]), 490 | int(output_cdf.shape[1]) 491 | ) 492 | return sym_out 493 | 494 | -------------------------------------------------------------------------------- /utils/entropy_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as nnf 4 | import numpy as np 5 | from torch.distributions.uniform import Uniform 6 | from utils.encodings import use_clamp 7 | 8 | class Entropy_gaussian_clamp(nn.Module): 9 | def __init__(self, Q=1): 10 | super(Entropy_gaussian_clamp, self).__init__() 11 | self.Q = Q 12 | def forward(self, x, mean, scale, Q=None): 13 | if Q is None: 14 | Q = self.Q 15 | if use_clamp: 16 | x_mean = x.mean() 17 | x_min = x_mean - 15_000 * Q 18 | x_max = x_mean + 15_000 * Q 19 | x = torch.clamp(x, min=x_min.detach(), max=x_max.detach()) 20 | scale = torch.clamp(scale, min=1e-9) 21 | m1 = torch.distributions.normal.Normal(mean, scale) 22 | lower = m1.cdf(x - 0.5*Q) 23 | upper = m1.cdf(x + 0.5*Q) 24 | likelihood = torch.abs(upper - lower) 25 | likelihood = Low_bound.apply(likelihood) 26 | bits = -torch.log2(likelihood) 27 | return bits 28 | 29 | 30 | class Entropy_gaussian(nn.Module): 31 | def __init__(self, Q=1): 32 | super(Entropy_gaussian, self).__init__() 33 | self.Q = Q 34 | def forward(self, x, mean, scale, Q=None, x_mean=None): 35 | if Q is None: 36 | Q = self.Q 37 | if use_clamp: 38 | if x_mean is None: 39 | x_mean = x.mean() 40 | x_min = x_mean - 15_000 * Q 41 | x_max = x_mean + 15_000 * Q 42 | x = torch.clamp(x, min=x_min.detach(), max=x_max.detach()) 43 | scale = torch.clamp(scale, min=1e-9) 44 | m1 = torch.distributions.normal.Normal(mean, scale) 45 | lower = m1.cdf(x - 0.5*Q) 46 | upper = m1.cdf(x + 0.5*Q) 47 | likelihood = torch.abs(upper - lower) 48 | likelihood = Low_bound.apply(likelihood) 49 | bits = -torch.log2(likelihood) 50 | return bits 51 | 52 | 53 | class Entropy_bernoulli(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | def forward(self, x, p): 57 | # p = torch.sigmoid(p) 58 | p = torch.clamp(p, min=1e-6, max=1 - 1e-6) 59 | pos_mask = (1 + x) / 2.0 # 1 -> 1, -1 -> 0 60 | neg_mask = (1 - x) / 2.0 # -1 -> 1, 1 -> 0 61 | pos_prob = p 62 | neg_prob = 1 - p 63 | param_bit = -torch.log2(pos_prob) * pos_mask + -torch.log2(neg_prob) * neg_mask 64 | return param_bit 65 | 66 | 67 | class Entropy_factorized(nn.Module): 68 | def __init__(self, channel=32, init_scale=10, filters=(3, 3, 3), likelihood_bound=1e-6, 69 | tail_mass=1e-9, optimize_integer_offset=True, Q=1): 70 | super(Entropy_factorized, self).__init__() 71 | self.filters = tuple(int(t) for t in filters) 72 | self.init_scale = float(init_scale) 73 | self.likelihood_bound = float(likelihood_bound) 74 | self.tail_mass = float(tail_mass) 75 | self.optimize_integer_offset = bool(optimize_integer_offset) 76 | self.Q = Q 77 | if not 0 < self.tail_mass < 1: 78 | raise ValueError( 79 | "`tail_mass` must be between 0 and 1") 80 | filters = (1,) + self.filters + (1,) 81 | scale = self.init_scale ** (1.0 / (len(self.filters) + 1)) 82 | self._matrices = nn.ParameterList([]) 83 | self._bias = nn.ParameterList([]) 84 | self._factor = nn.ParameterList([]) 85 | for i in range(len(self.filters) + 1): 86 | init = np.log(np.expm1(1.0 / scale / 
filters[i + 1])) 87 | self.matrix = nn.Parameter(torch.FloatTensor( 88 | channel, filters[i + 1], filters[i])) 89 | self.matrix.data.fill_(init) 90 | self._matrices.append(self.matrix) 91 | self.bias = nn.Parameter( 92 | torch.FloatTensor(channel, filters[i + 1], 1)) 93 | noise = np.random.uniform(-0.5, 0.5, self.bias.size()) 94 | noise = torch.FloatTensor(noise) 95 | self.bias.data.copy_(noise) 96 | self._bias.append(self.bias) 97 | if i < len(self.filters): 98 | self.factor = nn.Parameter( 99 | torch.FloatTensor(channel, filters[i + 1], 1)) 100 | self.factor.data.fill_(0.0) 101 | self._factor.append(self.factor) 102 | 103 | def _logits_cumulative(self, logits, stop_gradient): 104 | for i in range(len(self.filters) + 1): 105 | matrix = nnf.softplus(self._matrices[i]) 106 | if stop_gradient: 107 | matrix = matrix.detach() 108 | # print('dqnwdnqwdqwdqwf:', matrix.shape, logits.shape) 109 | logits = torch.matmul(matrix, logits) 110 | bias = self._bias[i] 111 | if stop_gradient: 112 | bias = bias.detach() 113 | logits += bias 114 | if i < len(self._factor): 115 | factor = nnf.tanh(self._factor[i]) 116 | if stop_gradient: 117 | factor = factor.detach() 118 | logits += factor * nnf.tanh(logits) 119 | return logits 120 | 121 | def forward(self, x, Q=None): 122 | # x: [N, C], quantized 123 | if Q is None: 124 | Q = self.Q 125 | else: 126 | Q = Q.permute(1, 0).contiguous() 127 | x = x.permute(1, 0).contiguous() # [C, N] 128 | # print('dqwdqwdqwdqwfqwf:', x.shape, Q.shape) 129 | lower = self._logits_cumulative(x - 0.5*(1/Q), stop_gradient=False) 130 | upper = self._logits_cumulative(x + 0.5*(1/Q), stop_gradient=False) 131 | sign = -torch.sign(torch.add(lower, upper)) 132 | sign = sign.detach() 133 | likelihood = torch.abs( 134 | nnf.sigmoid(sign * upper) - nnf.sigmoid(sign * lower)) 135 | likelihood = Low_bound.apply(likelihood) 136 | bits = -torch.log2(likelihood) # [C, N] 137 | bits = bits.permute(1, 0).contiguous() 138 | return bits 139 | 140 | 141 | class Low_bound(torch.autograd.Function): 142 | @staticmethod 143 | def forward(ctx, x): 144 | ctx.save_for_backward(x) 145 | x = torch.clamp(x, min=1e-6) 146 | return x 147 | 148 | @staticmethod 149 | def backward(ctx, g): 150 | x, = ctx.saved_tensors 151 | grad1 = g.clone() 152 | grad1[x < 1e-6] = 0 153 | pass_through_if = np.logical_or( 154 | x.cpu().numpy() >= 1e-6, g.cpu().numpy() < 0.0) 155 | t = torch.Tensor(pass_through_if+0.0).cuda() 156 | return grad1 * t 157 | 158 | 159 | class UniverseQuant(torch.autograd.Function): 160 | @staticmethod 161 | def forward(ctx, x): 162 | #b = np.random.uniform(-1,1) 163 | b = 0 164 | uniform_distribution = Uniform(-0.5*torch.ones(x.size()) 165 | * (2**b), 0.5*torch.ones(x.size())*(2**b)).sample().cuda() 166 | return torch.round(x+uniform_distribution)-uniform_distribution 167 | 168 | @staticmethod 169 | def backward(ctx, g): 170 | 171 | return g 172 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func_mine( 30 | lr_init, lr_final, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | def helper(step): 33 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 34 | # Disable this parameter 35 | return 0.0 36 | delay_rate = 1.0 37 | 38 | if step < 10000: 39 | t = np.clip((step - 0) / (10000 - 0), 0, 1) 40 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 41 | else: 42 | t = np.clip((step - 10000) / (30000 - 10000), 0, 1) 43 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 44 | 45 | return delay_rate * log_lerp 46 | 47 | return helper 48 | 49 | def get_expon_lr_func( 50 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000, step_sub=0, 51 | ): 52 | """ 53 | Copied from Plenoxels 54 | 55 | Continuous learning rate decay function. Adapted from JaxNeRF 56 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 57 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 58 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 59 | function of lr_delay_mult, such that the initial learning rate is 60 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 61 | to the normal learning rate when steps>lr_delay_steps. 62 | :param conf: config subtree 'lr' or similar 63 | :param max_steps: int, the number of steps during optimization. 64 | :return HoF which takes step as input 65 | """ 66 | 67 | def helper(step): 68 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 69 | # Disable this parameter 70 | return 0.0 71 | if lr_delay_steps > 0: 72 | # A kind of reverse cosine decay. 
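            # Comment added for clarity (not in the original source): delay_rate ramps smoothly
            # from lr_delay_mult up to 1 over the first lr_delay_steps steps; the returned rate
            # is delay_rate * exp((1 - t) * log(lr_init) + t * log(lr_final)) with
            # t = clip((step - step_sub) / (max_steps - step_sub), 0, 1), i.e. a log-linear
            # interpolation between lr_init and lr_final.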
73 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 74 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 75 | ) 76 | else: 77 | delay_rate = 1.0 78 | t = np.clip((step-step_sub) / (max_steps-step_sub), 0, 1) 79 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 80 | return delay_rate * log_lerp 81 | 82 | return helper 83 | 84 | def strip_lowerdiag(L): 85 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 86 | 87 | uncertainty[:, 0] = L[:, 0, 0] 88 | uncertainty[:, 1] = L[:, 0, 1] 89 | uncertainty[:, 2] = L[:, 0, 2] 90 | uncertainty[:, 3] = L[:, 1, 1] 91 | uncertainty[:, 4] = L[:, 1, 2] 92 | uncertainty[:, 5] = L[:, 2, 2] 93 | return uncertainty 94 | 95 | def strip_symmetric(sym): 96 | return strip_lowerdiag(sym) 97 | 98 | def build_rotation(r): 99 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 100 | 101 | q = r / norm[:, None] 102 | 103 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 104 | 105 | r = q[:, 0] 106 | x = q[:, 1] 107 | y = q[:, 2] 108 | z = q[:, 3] 109 | 110 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 111 | R[:, 0, 1] = 2 * (x*y - r*z) 112 | R[:, 0, 2] = 2 * (x*z + r*y) 113 | R[:, 1, 0] = 2 * (x*y + r*z) 114 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 115 | R[:, 1, 2] = 2 * (y*z - r*x) 116 | R[:, 2, 0] = 2 * (x*z - r*y) 117 | R[:, 2, 1] = 2 * (y*z + r*x) 118 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 119 | return R 120 | 121 | def build_scaling_rotation(s, r): 122 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 123 | R = build_rotation(r) 124 | 125 | L[:,0,0] = s[:,0] 126 | L[:,1,1] = s[:,1] 127 | L[:,2,2] = s[:,2] 128 | 129 | L = R @ L 130 | return L 131 | 132 | def safe_state(silent): 133 | old_f = sys.stdout 134 | class F: 135 | def __init__(self, silent): 136 | self.silent = silent 137 | 138 | def write(self, x): 139 | if not self.silent: 140 | if x.endswith("\n"): 141 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 142 | else: 143 | old_f.write(x) 144 | 145 | def flush(self): 146 | old_f.flush() 147 | 148 | sys.stdout = F(silent) 149 | 150 | random.seed(0) 151 | np.random.seed(0) 152 | torch.manual_seed(0) 153 | torch.cuda.set_device(torch.device("cuda:0")) -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, 
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 |     # Creates a directory; equivalent to using mkdir -p on the command line
18 |     try:
19 |         makedirs(folder_path)
20 |     except OSError as exc: # Python >2.5
21 |         if exc.errno == EEXIST and path.isdir(folder_path):
22 |             pass
23 |         else:
24 |             raise
25 |
26 | def searchForMaxIteration(folder):
27 |     saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 |     return max(saved_iters)
--------------------------------------------------------------------------------
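A quick, hedged sketch of the naming convention these helpers rely on: every entry in the scanned folder is expected to end in "_<iteration>". The ./output path below is a made-up placeholder, not a layout mandated by this repository.

import os
from utils.system_utils import mkdir_p, searchForMaxIteration

out_dir = "./output/scene/point_cloud"            # hypothetical output layout
mkdir_p(os.path.join(out_dir, "iteration_7000"))  # no error if the folder already exists
mkdir_p(os.path.join(out_dir, "iteration_30000"))

# searchForMaxIteration parses the trailing "_<int>" of each entry name.
latest = searchForMaxIteration(out_dir)           # -> 30000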
/utils/visualize_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Tuple, Type
2 |
3 | import cv2
4 | import numpy as np
5 | import torch as th
6 | import torch.nn.functional as F
7 |
8 |
9 | def add_label_centered(
10 |     img: np.ndarray,
11 |     text: str,
12 |     font_scale: float = 1.0,
13 |     thickness: int = 2,
14 |     alignment: str = "top",
15 |     color: Tuple[int, int, int] = (0, 255, 0),
16 | ) -> np.ndarray:
17 |     font = cv2.FONT_HERSHEY_SIMPLEX
18 |     textsize = cv2.getTextSize(text, font, font_scale, thickness=thickness)[0]
19 |     img = img.astype(np.uint8).copy()
20 |
21 |     if alignment == "top":
22 |         cv2.putText(
23 |             img,
24 |             text,
25 |             ((img.shape[1] - textsize[0]) // 2, 50),
26 |             font,
27 |             font_scale,
28 |             color,
29 |             thickness=thickness,
30 |             lineType=cv2.LINE_AA,
31 |         )
32 |     elif alignment == "bottom":
33 |         cv2.putText(
34 |             img,
35 |             text,
36 |             ((img.shape[1] - textsize[0]) // 2, img.shape[0] - textsize[1]),
37 |             font,
38 |             font_scale,
39 |             color,
40 |             thickness=thickness,
41 |             lineType=cv2.LINE_AA,
42 |         )
43 |     else:
44 |         raise ValueError("Unknown text alignment")
45 |
46 |     return img
47 |
48 | def tensor2rgbjet(
49 |     tensor: th.Tensor, x_max: Optional[float] = None, x_min: Optional[float] = None
50 | ) -> np.ndarray:
51 |     return cv2.applyColorMap(tensor2rgb(tensor, x_max=x_max, x_min=x_min), cv2.COLORMAP_JET)
52 |
53 |
54 | def tensor2rgb(
55 |     tensor: th.Tensor, x_max: Optional[float] = None, x_min: Optional[float] = None
56 | ) -> np.ndarray:
57 |     x = tensor.data.cpu().numpy()
58 |     if x_min is None:
59 |         x_min = x.min()
60 |     if x_max is None:
61 |         x_max = x.max()
62 |
63 |     gain = 255 / np.clip(x_max - x_min, 1e-3, None)
64 |     x = (x - x_min) * gain
65 |     x = x.clip(0.0, 255.0)
66 |     x = x.astype(np.uint8)
67 |     return x
68 |
69 |
70 | def tensor2image(
71 |     tensor: th.Tensor,
72 |     x_max: Optional[float] = 1.0,
73 |     x_min: Optional[float] = 0.0,
74 |     mode: str = "rgb",
75 |     mask: Optional[th.Tensor] = None,
76 |     label: Optional[str] = None,
77 | ) -> np.ndarray:
78 |
79 |     tensor = tensor.detach()
80 |
81 |     # Apply mask
82 |     if mask is not None:
83 |         tensor = tensor * mask
84 |
85 |     if len(tensor.size()) == 2:
86 |         tensor = tensor[None]
87 |
88 |     # Make three channel image
89 |     assert len(tensor.size()) == 3, tensor.size()
90 |     n_channels = tensor.shape[0]
91 |     if n_channels == 1:
92 |         tensor = tensor.repeat(3, 1, 1)
93 |     elif n_channels != 3:
94 |         raise ValueError(f"Unsupported number of channels {n_channels}.")
95 |
96 |     # Convert to display format
97 |     img = tensor.permute(1, 2, 0)
98 |
99 |     if mode == "rgb":
100 |         img = tensor2rgb(img, x_max=x_max, x_min=x_min)
101 |     elif mode == "jet":
102 |         # `cv2.applyColorMap` assumes input format in BGR
103 |         img[:, :, :3] = img[:, :, [2, 1, 0]]
104 |         img = tensor2rgbjet(img, x_max=x_max, x_min=x_min)
105 |         # convert back to rgb
106 |         img[:, :, :3] = img[:, :, [2, 1, 0]]
107 |     else:
108 |         raise ValueError(f"Unsupported mode {mode}.")
109 |
110 |     if label is not None:
111 |         img = add_label_centered(img, label)
112 |
113 |     return img
114 |
115 | # d: b x 1 x H x W
116 | # screenCoords: b x 2 x H x W
117 | # focal: b x 2 x 2
118 | # princpt: b x 2
119 | # out: b x 3 x H x W
120 | def depthImgToPosCam_Batched(d, screenCoords, focal, princpt):
121 |     p = screenCoords - princpt[:, :, None, None]
122 |     x = (d * p[:, 0:1, :, :]) / focal[:, 0:1, 0, None, None]
123 |     y = (d * p[:, 1:2, :, :]) / focal[:, 1:2, 1, None, None]
124 |     return th.cat([x, y, d], dim=1)
125 |
126 | # p: b x 3 x H x W
127 | # out: b x 3 x H x W
128 | def computeNormalsFromPosCam_Batched(p):
129 |     p = F.pad(p, (1, 1, 1, 1), "replicate")
130 |     d0 = p[:, :, 2:, 1:-1] - p[:, :, :-2, 1:-1]
131 |     d1 = p[:, :, 1:-1, 2:] - p[:, :, 1:-1, :-2]
132 |     n = th.cross(d0, d1, dim=1)
133 |     norm = th.norm(n, dim=1, keepdim=True)
134 |     norm = norm + 1e-5
135 |     norm[norm < 1e-5] = 1  # Can not backprop through this
136 |     return -n / norm
137 |
138 | def visualize_normal(inputs, depth_p):
139 |     # Normals
140 |     uv = th.stack(
141 |         th.meshgrid(
142 |             th.arange(depth_p.shape[2]), th.arange(depth_p.shape[1]), indexing="xy"
143 |         ),
144 |         dim=0,
145 |     )[None].float().cuda()
146 |     position = depthImgToPosCam_Batched(
147 |         depth_p[None, ...], uv, inputs["focal"], inputs["princpt"]
148 |     )
149 |     normal = 0.5 * (computeNormalsFromPosCam_Batched(position) + 1.0)
150 |     normal = normal[0, [2, 1, 0], :, :]  # legacy code assumes BGR format
151 |     normal_p = tensor2image(normal, label="normal_p")
152 |
153 |     return normal_p
--------------------------------------------------------------------------------
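Finally, a small usage sketch for the visualization helpers above. The rendered image, depth map, focal matrix, and principal point are placeholder values chosen only to satisfy the commented shapes, and the normal-map path needs a CUDA device because visualize_normal builds its pixel grid on the GPU.

import torch as th
from utils.visualize_utils import tensor2image, visualize_normal

# 3 x H x W render in [0, 1] -> H x W x 3 uint8 array with a centered green caption.
render = th.rand(3, 256, 256)
img = tensor2image(render, x_max=1.0, x_min=0.0, mode="rgb", label="render")

# Depth -> camera-space normal visualization (focal: b x 2 x 2, princpt: b x 2).
depth = th.rand(1, 256, 256).cuda() + 1.0
inputs = {
    "focal": th.tensor([[[500.0, 0.0], [0.0, 500.0]]]).cuda(),
    "princpt": th.tensor([[128.0, 128.0]]).cuda(),
}
normal_img = visualize_normal(inputs, depth)      # H x W x 3 uint8, labelled "normal_p"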