├── LICENSE.md
├── README.md
├── arguments
│   └── __init__.py
├── assets
│   ├── main_performance.png
│   └── teaser.png
├── environment.yml
├── gaussian_renderer
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── network_gui.cpython-37.pyc
│   └── network_gui.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── results
│   ├── DeepBlending
│   │   ├── drjohnson.csv
│   │   └── playroom.csv
│   ├── MipNeRF360
│   │   ├── bicycle.csv
│   │   ├── bonsai.csv
│   │   ├── counter.csv
│   │   ├── flowers.csv
│   │   ├── garden.csv
│   │   ├── kitchen.csv
│   │   ├── room.csv
│   │   ├── stump.csv
│   │   └── treehill.csv
│   ├── README.md
│   ├── SyntheticNeRF
│   │   ├── chair.csv
│   │   ├── drums.csv
│   │   ├── ficus.csv
│   │   ├── hotdog.csv
│   │   ├── lego.csv
│   │   ├── materials.csv
│   │   ├── mic.csv
│   │   └── ship.csv
│   └── TanksAndTemples
│       ├── train.csv
│       └── truck.csv
├── run_shell_blender.py
├── run_shell_bungee.py
├── run_shell_db.py
├── run_shell_mip360.py
├── run_shell_tnt.py
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   └── gaussian_model.py
├── submodules
│   ├── arithmetic.zip
│   ├── diff-gaussian-rasterization.zip
│   ├── gridencoder.zip
│   └── simple-knn.zip
├── train.py
└── utils
    ├── camera_utils.py
    ├── encodings.py
    ├── encodings_cuda.py
    ├── entropy_models.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── sh_utils.py
    ├── system_utils.py
    └── visualize_utils.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e., Inria and MPII.
19 |
20 | *Software* means the original work of authorship made available under this
21 | License, i.e., gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing them to appreciate the adequacy of the *Software* to their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ECCV'24] HAC
2 | Official PyTorch implementation of **HAC: Hash-grid Assisted Context for 3D Gaussian Splatting Compression**.
3 | ## Compress 3D Gaussian Splatting by 75X without fidelity drop!
4 | ## 💪 An enhanced compression method, [HAC++](https://github.com/YihangChen-ee/HAC-plus), has been released!
5 |
6 | [Yihang Chen](https://yihangchen-ee.github.io),
7 | [Qianyi Wu](https://qianyiwu.github.io),
8 | [Weiyao Lin](https://weiyaolin.github.io),
9 | [Mehrtash Harandi](https://sites.google.com/site/mehrtashharandi/),
10 | [Jianfei Cai](http://jianfei-cai.github.io)
11 |
12 | [[`Paper`](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/01178.pdf)] [[`Arxiv`](https://arxiv.org/pdf/2403.14530)] [[`Project`](https://yihangchen-ee.github.io/project_hac/)] [[`Github`](https://github.com/YihangChen-ee/HAC)]
13 |
14 | ## Links
15 | You are welcome to check out a series of works from our group on 3D radiance field representation compression, listed below:
16 | - 🎉 [CNC](https://github.com/yihangchen-ee/cnc/) [CVPR'24]: efficient NeRF compression! [[`Paper`](https://openaccess.thecvf.com/content/CVPR2024/papers/Chen_How_Far_Can_We_Compress_Instant-NGP-Based_NeRF_CVPR_2024_paper.pdf)] [[`Arxiv`](https://arxiv.org/pdf/2406.04101)] [[`Project`](https://yihangchen-ee.github.io/project_cnc/)]
17 | - 🏠 [HAC](https://github.com/yihangchen-ee/hac/) [ECCV'24]: efficient 3DGS compression! [[`Paper`](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/01178.pdf)] [[`Arxiv`](https://arxiv.org/pdf/2403.14530)] [[`Project`](https://yihangchen-ee.github.io/project_hac/)]
18 | - 💪 [HAC++](https://github.com/yihangchen-ee/hac-plus/) [ARXIV'25]: an enhanced compression method over HAC! [[`Arxiv`](https://arxiv.org/pdf/2501.12255)] [[`Project`](https://yihangchen-ee.github.io/project_hac++/)]
19 | - 🚀 [FCGS](https://github.com/yihangchen-ee/fcgs/) [ICLR'25]: fast optimization-free 3DGS compression! [[`Paper`](https://openreview.net/pdf?id=DCandSZ2F1)] [[`Arxiv`](https://arxiv.org/pdf/2410.08017)] [[`Project`](https://yihangchen-ee.github.io/project_fcgs/)]
20 | - 🪜 [PCGS](https://github.com/yihangchen-ee/pcgs/) [ARXIV'25]: progressive 3DGS compression! [[`Arxiv`](https://arxiv.org/pdf/2503.08511)] [[`Project`](https://yihangchen-ee.github.io/project_pcgs/)]
21 |
22 | ## Updates
23 | 🔥 8-Aug-2024: HAC now utilizes a ```cuda-based codec``` instead of the original ```torchac```, which reduces the codec runtime by over ```10``` times compared to that reported in the paper!
24 |
25 | ## Overview
26 |
27 |
29 |
30 |
31 | Our approach introduces a binary hash grid to establish continuous spatial consistencies,
32 | allowing us to unveil the inherent spatial relations of anchors through a carefully designed context model.
33 | To facilitate entropy coding, we utilize Gaussian distributions to accurately estimate the probability of each quantized attribute,
34 | where an adaptive quantization module is proposed to enable high-precision quantization of these attributes for improved fidelity restoration.
35 | Additionally, we incorporate an adaptive masking strategy to eliminate invalid Gaussians and anchors.
36 | Importantly, our work is the first to explore context-based compression for the 3DGS representation, resulting in a remarkable size reduction.
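To make the rate term concrete: the bit cost of a quantized attribute can be estimated as the negative log2 probability mass of its quantization bin under the predicted Gaussian. A minimal sketch (illustrative only, not the exact implementation in `utils/entropy_models.py`):

```python
import torch

def estimated_bits(x, mu, sigma, Q):
    # Probability mass of the quantization bin [x - Q/2, x + Q/2] under N(mu, sigma),
    # turned into a differentiable estimate of the coding cost in bits.
    normal = torch.distributions.Normal(mu, sigma)
    p = normal.cdf(x + Q / 2) - normal.cdf(x - Q / 2)
    return -torch.log2(p.clamp_min(1e-10))
```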
37 |
38 | ## Performance
39 |
40 |
42 |
43 |
44 |
45 | ## Installation
46 |
47 | We tested our code on a server with Ubuntu 20.04.1, CUDA 11.8, and gcc 9.4.0.
48 | 1. Unzip files
49 | ```
50 | cd submodules
51 | unzip diff-gaussian-rasterization.zip
52 | unzip gridencoder.zip
53 | unzip simple-knn.zip
54 | unzip arithmetic.zip
55 | cd ..
56 | ```
57 | 2. Install environment
58 | ```
59 | conda env create --file environment.yml
60 | conda activate HAC_env
61 | ```
62 |
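Optionally, you can sanity-check the installation before training. A minimal sketch (the module name `diff_gaussian_rasterization` follows the import in `gaussian_renderer/__init__.py`):

```python
import torch
import diff_gaussian_rasterization  # built from submodules/diff-gaussian-rasterization.zip

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
```
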
63 | ## Data
64 |
65 | First, create a ```data/``` folder inside the project path:
66 | ```
67 | mkdir data
68 | ```
69 |
70 | The data structure should be organised as follows:
71 |
72 | ```
73 | data/
74 | ├── dataset_name
75 | │ ├── scene1/
76 | │ │ ├── images
77 | │ │ │ ├── IMG_0.jpg
78 | │ │ │ ├── IMG_1.jpg
79 | │ │ │ ├── ...
80 | │ │ └── sparse/
81 | │ │     └── 0/
82 | │ ├── scene2/
83 | │ │ ├── images
84 | │ │ │ ├── IMG_0.jpg
85 | │ │ │ ├── IMG_1.jpg
86 | │ │ │ ├── ...
87 | │ │ └── sparse/
88 | │ │     └── 0/
89 | ...
90 | ```
91 |
92 | - For instance: `./data/blending/drjohnson/`
93 | - For instance: `./data/bungeenerf/amsterdam/`
94 | - For instance: `./data/mipnerf360/bicycle/`
95 | - For instance: `./data/nerf_synthetic/chair/`
96 | - For instance: `./data/tandt/train/`
97 |
98 |
99 | ### Public Data (We follow suggestions from [Scaffold-GS](https://github.com/city-super/Scaffold-GS))
100 |
101 | - The **BungeeNeRF** dataset is available on [Google Drive](https://drive.google.com/file/d/1nBLcf9Jrr6sdxKa1Hbd47IArQQ_X8lww/view?usp=sharing)/[Baidu Netdisk (code: 4whv)](https://pan.baidu.com/s/1AUYUJojhhICSKO2JrmOnCA).
102 | - The **MipNeRF360** scenes are provided by the paper authors [here](https://jonbarron.info/mipnerf360/). We test on all 9 scenes: ```bicycle, bonsai, counter, garden, kitchen, room, stump, flowers, treehill```.
103 | - The SfM datasets for **Tanks&Temples** and **Deep Blending** are hosted by 3D-Gaussian-Splatting [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip). Download and uncompress them into the ```data/``` folder.
104 |
105 | ### Custom Data
106 |
107 | For custom data, you should process the image sequences with [Colmap](https://colmap.github.io/) to obtain the SfM points and camera poses. Then, place the results into the ```data/``` folder.
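
As a rough guide, a minimal COLMAP pipeline written in the style of the `run_shell_xxx.py` scripts could look like the sketch below (the scene path and flags are illustrative; adjust them to your capture):

```python
import os

scene = 'data/my_scene'  # hypothetical folder containing an images/ subfolder
db = f'{scene}/database.db'
os.makedirs(f'{scene}/sparse', exist_ok=True)

os.system(f'colmap feature_extractor --database_path {db} --image_path {scene}/images')
os.system(f'colmap exhaustive_matcher --database_path {db}')
os.system(f'colmap mapper --database_path {db} --image_path {scene}/images --output_path {scene}/sparse')
```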
108 |
109 | ## Training
110 |
111 | To train scenes, we provide the following training scripts:
112 | - Tanks&Temples: ```run_shell_tnt.py```
113 | - MipNeRF360: ```run_shell_mip360.py```
114 | - BungeeNeRF: ```run_shell_bungee.py```
115 | - Deep Blending: ```run_shell_db.py```
116 | - NeRF Synthetic: ```run_shell_blender.py```
117 |
118 | Run them with:
119 | ```
120 | python run_shell_xxx.py
121 | ```
122 |
123 | The code will automatically run the entire pipeline: **training, encoding, decoding, and testing**.
124 | - The training log will be recorded in `output.log` in the output directory, including **detailed fidelity, size, and timing** results.
125 | - Encoded bitstreams will be stored in `./bitstreams` of the output directory.
126 | - Evaluated output images will be saved in `./test/ours_30000/renders` of the output directory.
127 | - Optionally, you can change `lmbda` in these `run_shell_xxx.py` scripts to obtain variable bitrates (see the sketch after this list).
128 | - **After training, the original model `point_cloud.ply` is losslessly compressed into `./bitstreams`. Refer to `./bitstreams` (not `point_cloud.ply`) for the final model size. You can even delete `point_cloud.ply` if you like :).**
129 |
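For example, a rate sweep on a single Deep Blending scene can mirror `run_shell_db.py` with several `lmbda` values (a sketch; the scene and the exact values are illustrative):

```python
import os

for lmbda in [0.004, 0.002, 0.0005]:  # each value gives a different size/fidelity trade-off
    scene = 'playroom'
    one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/blending/{scene} --eval --lod 0 ' \
              f'--voxel_size 0.005 --update_init_factor 16 --iterations 30_000 ' \
              f'-m outputs/blending/{scene}/{lmbda} --lmbda {lmbda}'
    os.system(one_cmd)
```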
130 |
131 | ## Contact
132 |
133 | - Yihang Chen: yhchen.ee@sjtu.edu.cn
134 |
135 | ## Citation
136 |
137 | If you find our work helpful, please consider citing:
138 |
139 | ```bibtex
140 | @article{hac++2025,
141 | title={HAC++: Towards 100X Compression of 3D Gaussian Splatting},
142 | author={Chen, Yihang and Wu, Qianyi and Lin, Weiyao and Harandi, Mehrtash and Cai, Jianfei},
143 | year={2025}
144 | }
145 | ```
146 | ```bibtex
147 | @inproceedings{hac2024,
148 | title={HAC: Hash-grid Assisted Context for 3D Gaussian Splatting Compression},
149 | author={Chen, Yihang and Wu, Qianyi and Lin, Weiyao and Harandi, Mehrtash and Cai, Jianfei},
150 | booktitle={European Conference on Computer Vision},
151 | year={2024}
152 | }
153 | ```
154 |
155 |
156 | ## LICENSE
157 |
158 | Please follow the LICENSE of [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting).
159 |
160 | ## Acknowledgement
161 |
162 | - We thank all the authors of [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting) for their excellent work.
163 | - We thank all the authors of [Scaffold-GS](https://github.com/city-super/Scaffold-GS) for their excellent work.
164 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
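# Illustration: attributes declared in a ParamGroup subclass become command-line flags
# of the same name, booleans become store_true flags, and a leading underscore also
# registers a one-letter shorthand (e.g. ModelParams' self._source_path is exposed as
# --source_path / -s).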
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self.feat_dim = 50
51 | self.n_offsets = 10
52 | self.voxel_size = 0.001 # if voxel_size<=0, using 1nn dist
53 | self.update_depth = 3
54 | self.update_init_factor = 16
55 | self.update_hierachy_factor = 4
56 |
57 | self.use_feat_bank = False
58 | self._source_path = ""
59 | self._model_path = ""
60 | self._images = "images"
61 | self._resolution = -1
62 | self._white_background = False
63 | self.data_device = "cuda"
64 | self.eval = True
65 | self.lod = 0
66 | super().__init__(parser, "Loading Parameters", sentinel)
67 |
68 | def extract(self, args):
69 | g = super().extract(args)
70 | g.source_path = os.path.abspath(g.source_path)
71 | return g
72 |
73 | class PipelineParams(ParamGroup):
74 | def __init__(self, parser):
75 | self.convert_SHs_python = False
76 | self.compute_cov3D_python = False
77 | self.debug = False
78 | super().__init__(parser, "Pipeline Parameters")
79 |
80 | class OptimizationParams(ParamGroup):
81 | def __init__(self, parser):
82 | self.iterations = 30_000
83 | self.position_lr_init = 0.0
84 | self.position_lr_final = 0.0
85 | self.position_lr_delay_mult = 0.01
86 | self.position_lr_max_steps = 30_000
87 |
88 | self.offset_lr_init = 0.01
89 | self.offset_lr_final = 0.0001
90 | self.offset_lr_delay_mult = 0.01
91 | self.offset_lr_max_steps = 30_000
92 |
93 | self.mask_lr_init = 0.01
94 | self.mask_lr_final = 0.0001
95 | self.mask_lr_delay_mult = 0.01
96 | self.mask_lr_max_steps = 30_000
97 |
98 | self.feature_lr = 0.0075
99 | self.opacity_lr = 0.02
100 | self.scaling_lr = 0.007
101 | self.rotation_lr = 0.002
102 |
103 | self.mlp_opacity_lr_init = 0.002
104 | self.mlp_opacity_lr_final = 0.00002
105 | self.mlp_opacity_lr_delay_mult = 0.01
106 | self.mlp_opacity_lr_max_steps = 30_000
107 |
108 | self.mlp_cov_lr_init = 0.004
109 | self.mlp_cov_lr_final = 0.004
110 | self.mlp_cov_lr_delay_mult = 0.01
111 | self.mlp_cov_lr_max_steps = 30_000
112 |
113 | self.mlp_color_lr_init = 0.008
114 | self.mlp_color_lr_final = 0.00005
115 | self.mlp_color_lr_delay_mult = 0.01
116 | self.mlp_color_lr_max_steps = 30_000
117 |
118 | self.mlp_featurebank_lr_init = 0.01
119 | self.mlp_featurebank_lr_final = 0.00001
120 | self.mlp_featurebank_lr_delay_mult = 0.01
121 | self.mlp_featurebank_lr_max_steps = 30_000
122 |
123 | self.encoding_xyz_lr_init = 0.005
124 | self.encoding_xyz_lr_final = 0.00001
125 | self.encoding_xyz_lr_delay_mult = 0.33
126 | self.encoding_xyz_lr_max_steps = 30_000
127 |
128 | self.mlp_grid_lr_init = 0.005
129 | self.mlp_grid_lr_final = 0.00001
130 | self.mlp_grid_lr_delay_mult = 0.01
131 | self.mlp_grid_lr_max_steps = 30_000
132 |
133 | self.mlp_deform_lr_init = 0.005
134 | self.mlp_deform_lr_final = 0.0005
135 | self.mlp_deform_lr_delay_mult = 0.01
136 | self.mlp_deform_lr_max_steps = 30_000
137 |
138 | self.percent_dense = 0.01
139 | self.lambda_dssim = 0.2
140 |
141 | # for anchor densification
142 | self.start_stat = 500
143 | self.update_from = 1500
144 | self.update_interval = 100
145 | self.update_until = 15_000
146 |
147 | self.min_opacity = 0.005 # 0.2
148 | self.success_threshold = 0.8
149 | self.densify_grad_threshold = 0.0002
150 |
151 | super().__init__(parser, "Optimization Parameters")
152 |
153 | def get_combined_args(parser : ArgumentParser):
154 | cmdlne_string = sys.argv[1:]
155 | cfgfile_string = "Namespace()"
156 | args_cmdline = parser.parse_args(cmdlne_string)
157 |
158 | try:
159 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
160 | print("Looking for config file in", cfgfilepath)
161 | with open(cfgfilepath) as cfg_file:
162 | print("Config file found: {}".format(cfgfilepath))
163 | cfgfile_string = cfg_file.read()
164 | except TypeError:
165 | print("Config file not found at")
166 | pass
167 | args_cfgfile = eval(cfgfile_string)
168 |
169 | merged_dict = vars(args_cfgfile).copy()
170 | for k,v in vars(args_cmdline).items():
171 | if v != None:
172 | merged_dict[k] = v
173 | return Namespace(**merged_dict)
174 |
--------------------------------------------------------------------------------
/assets/main_performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/assets/main_performance.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/assets/teaser.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: HAC_env
2 | channels:
3 | - pytorch
4 | - pyg
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - cudatoolkit=11.6
9 | - plyfile=0.8.1
10 | - python=3.7.13
11 | - pip=22.3.1
12 | - pytorch=1.12.1
13 | - torchaudio=0.12.1
14 | - torchvision=0.13.1
15 | - pytorch-scatter
16 | - tqdm
17 | - pip:
18 | - einops
19 | - wandb
20 | - lpips
21 | - submodules/diff-gaussian-rasterization
22 | - submodules/simple-knn
23 | - submodules/gridencoder
24 | - submodules/arithmetic
25 |
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import os.path
12 | import time
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as nnf
17 | from einops import repeat
18 |
19 | import math
20 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
21 | from scene.gaussian_model import GaussianModel
22 | from utils.encodings import STE_binary, STE_multistep
23 |
24 |
25 | def generate_neural_gaussians(viewpoint_camera, pc : GaussianModel, visible_mask=None, is_training=False, step=0):
26 | ## view frustum filtering for acceleration
27 |
28 | time_sub = 0
29 |
30 | if visible_mask is None:
31 | visible_mask = torch.ones(pc.get_anchor.shape[0], dtype=torch.bool, device = pc.get_anchor.device)
32 |
33 | anchor = pc.get_anchor[visible_mask]
34 | #
35 | feat = pc._anchor_feat[visible_mask]
36 | grid_offsets = pc._offset[visible_mask]
37 | grid_scaling = pc.get_scaling[visible_mask]
38 | binary_grid_masks = pc.get_mask[visible_mask]
39 | mask_anchor = pc.get_mask_anchor[visible_mask]
40 | mask_anchor_bool = mask_anchor.to(torch.bool)
41 | mask_anchor_rate = (mask_anchor.sum() / mask_anchor.numel()).detach()
42 |
43 | bit_per_param = None
44 | bit_per_feat_param = None
45 | bit_per_scaling_param = None
46 | bit_per_offsets_param = None
47 | Q_feat = 1
48 | Q_scaling = 0.001
49 | Q_offsets = 0.2
50 | if is_training:
51 | if step > 3000 and step <= 10000:
52 | # quantization
53 | feat = feat + torch.empty_like(feat).uniform_(-0.5, 0.5) * Q_feat
54 | grid_scaling = grid_scaling + torch.empty_like(grid_scaling).uniform_(-0.5, 0.5) * Q_scaling
55 | grid_offsets = grid_offsets + torch.empty_like(grid_offsets).uniform_(-0.5, 0.5) * Q_offsets
56 |
57 | if step == 10000:
58 | pc.update_anchor_bound()
59 |
60 | if step > 10000:
61 | feat_context = pc.calc_interp_feat(anchor)
62 | feat_context = pc.get_grid_mlp(feat_context)
63 | mean, scale, mean_scaling, scale_scaling, mean_offsets, scale_offsets, Q_feat_adj, Q_scaling_adj, Q_offsets_adj = \
64 | torch.split(feat_context, split_size_or_sections=[pc.feat_dim, pc.feat_dim, 6, 6, 3*pc.n_offsets, 3*pc.n_offsets, 1, 1, 1], dim=-1)
65 |
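            # Adaptive quantization: context features interpolated from the hash grid
            # predict per-anchor adjustments, and each base step size Q is scaled by
            # (1 + tanh(Q_adj)).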
66 | Q_feat = Q_feat * (1 + torch.tanh(Q_feat_adj))
67 | Q_scaling = Q_scaling * (1 + torch.tanh(Q_scaling_adj))
68 | Q_offsets = Q_offsets * (1 + torch.tanh(Q_offsets_adj))
69 | feat = feat + torch.empty_like(feat).uniform_(-0.5, 0.5) * Q_feat
70 | grid_scaling = grid_scaling + torch.empty_like(grid_scaling).uniform_(-0.5, 0.5) * Q_scaling
71 | grid_offsets = grid_offsets + torch.empty_like(grid_offsets).uniform_(-0.5, 0.5) * Q_offsets.unsqueeze(1)
72 |
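            # Estimate the rate term on a random ~5% subset of valid (unmasked) anchors
            # with the Gaussian entropy model to keep the per-step overhead low.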
73 | choose_idx = torch.rand_like(anchor[:, 0]) <= 0.05
74 | choose_idx = choose_idx & mask_anchor_bool
75 | feat_chosen = feat[choose_idx]
76 | grid_scaling_chosen = grid_scaling[choose_idx]
77 | grid_offsets_chosen = grid_offsets[choose_idx].view(-1, 3*pc.n_offsets)
78 | mean = mean[choose_idx]
79 | scale = scale[choose_idx]
80 | mean_scaling = mean_scaling[choose_idx]
81 | scale_scaling = scale_scaling[choose_idx]
82 | mean_offsets = mean_offsets[choose_idx]
83 | scale_offsets = scale_offsets[choose_idx]
84 | Q_feat = Q_feat[choose_idx]
85 | Q_scaling = Q_scaling[choose_idx]
86 | Q_offsets = Q_offsets[choose_idx]
87 | binary_grid_masks_chosen = binary_grid_masks[choose_idx].repeat(1, 1, 3).view(-1, 3*pc.n_offsets)
88 | bit_feat = pc.entropy_gaussian.forward(feat_chosen, mean, scale, Q_feat, pc._anchor_feat.mean())
89 | bit_scaling = pc.entropy_gaussian.forward(grid_scaling_chosen, mean_scaling, scale_scaling, Q_scaling, pc.get_scaling.mean())
90 | bit_offsets = pc.entropy_gaussian.forward(grid_offsets_chosen, mean_offsets, scale_offsets, Q_offsets, pc._offset.mean())
91 | bit_offsets = bit_offsets * binary_grid_masks_chosen
92 | bit_per_feat_param = torch.sum(bit_feat) / bit_feat.numel() * mask_anchor_rate
93 | bit_per_scaling_param = torch.sum(bit_scaling) / bit_scaling.numel() * mask_anchor_rate
94 | bit_per_offsets_param = torch.sum(bit_offsets) / bit_offsets.numel() * mask_anchor_rate
95 | bit_per_param = (torch.sum(bit_feat) + torch.sum(bit_scaling) + torch.sum(bit_offsets)) / \
96 | (bit_feat.numel() + bit_scaling.numel() + bit_offsets.numel()) * mask_anchor_rate
97 |
98 | elif not pc.decoded_version:
99 | torch.cuda.synchronize(); t1 = time.time()
100 | feat_context = pc.calc_interp_feat(anchor)
101 | mean, scale, mean_scaling, scale_scaling, mean_offsets, scale_offsets, Q_feat_adj, Q_scaling_adj, Q_offsets_adj = \
102 | torch.split(pc.get_grid_mlp(feat_context), split_size_or_sections=[pc.feat_dim, pc.feat_dim, 6, 6, 3*pc.n_offsets, 3*pc.n_offsets, 1, 1, 1], dim=-1)
103 |
104 | Q_feat = Q_feat * (1 + torch.tanh(Q_feat_adj))
105 | Q_scaling = Q_scaling * (1 + torch.tanh(Q_scaling_adj))
106 | Q_offsets = Q_offsets * (1 + torch.tanh(Q_offsets_adj)) # [N_visible_anchor, 1]
107 | feat = (STE_multistep.apply(feat, Q_feat, pc._anchor_feat.mean())).detach()
108 | grid_scaling = (STE_multistep.apply(grid_scaling, Q_scaling, pc.get_scaling.mean())).detach()
109 | grid_offsets = (STE_multistep.apply(grid_offsets, Q_offsets.unsqueeze(1), pc._offset.mean())).detach()
110 | torch.cuda.synchronize(); time_sub = time.time() - t1
111 |
112 | else:
113 | pass
114 |
115 | ob_view = anchor - viewpoint_camera.camera_center
116 | ob_dist = ob_view.norm(dim=1, keepdim=True)
117 | ob_view = ob_view / ob_dist
118 |
119 | ## view-adaptive feature
120 | if pc.use_feat_bank:
121 | cat_view = torch.cat([ob_view, ob_dist], dim=1) # [3+1]
122 |
123 | bank_weight = pc.get_featurebank_mlp(cat_view).unsqueeze(dim=1) # [N_visible_anchor, 1, 3]
124 |
125 | feat = feat.unsqueeze(dim=-1) # feat: [N_visible_anchor, 32]
126 | feat = \
127 | feat[:, ::4, :1].repeat([1, 4, 1])*bank_weight[:, :, :1] + \
128 | feat[:, ::2, :1].repeat([1, 2, 1])*bank_weight[:, :, 1:2] + \
129 | feat[:, ::1, :1]*bank_weight[:, :, 2:]
130 | feat = feat.squeeze(dim=-1) # [N_visible_anchor, 32]
131 |
132 | cat_local_view = torch.cat([feat, ob_view, ob_dist], dim=1) # [N_visible_anchor, 32+3+1]
133 |
134 | neural_opacity = pc.get_opacity_mlp(cat_local_view) # [N_visible_anchor, K]
135 | neural_opacity = neural_opacity.reshape([-1, 1]) # [N_visible_anchor*K, 1]
136 | neural_opacity = neural_opacity * binary_grid_masks.view(-1, 1)
137 | mask = (neural_opacity > 0.0)
138 | mask = mask.view(-1) # [N_visible_anchor*K]
139 |
140 | # select opacity
141 | opacity = neural_opacity[mask] # [N_opacity_pos_gaussian, 1]
142 |
143 | # get offset's color
144 | color = pc.get_color_mlp(cat_local_view) # [N_visible_anchor, K*3]
145 | color = color.reshape([anchor.shape[0] * pc.n_offsets, 3]) # [N_visible_anchor*K, 3]
146 |
147 | # get offset's cov
148 | scale_rot = pc.get_cov_mlp(cat_local_view) # [N_visible_anchor, K*7]
149 | scale_rot = scale_rot.reshape([anchor.shape[0] * pc.n_offsets, 7]) # [N_visible_anchor*K, 7]
150 |
151 | offsets = grid_offsets.view([-1, 3]) # [N_visible_anchor*K, 3]
152 |
153 | # combine for parallel masking
154 | concatenated = torch.cat([grid_scaling, anchor], dim=-1) # [N_visible_anchor, 6+3]
155 | concatenated_repeated = repeat(concatenated, 'n (c) -> (n k) (c)', k=pc.n_offsets) # [N_visible_anchor*K, 6+3]
156 | concatenated_all = torch.cat([concatenated_repeated, color, scale_rot, offsets],
157 | dim=-1) # [N_visible_anchor*K, (6+3)+3+7+3]
158 | masked = concatenated_all[mask] # [N_opacity_pos_gaussian, (6+3)+3+7+3]
159 | scaling_repeat, repeat_anchor, color, scale_rot, offsets = masked.split([6, 3, 3, 7, 3], dim=-1)
160 |
161 | # post-process cov
162 | scaling = scaling_repeat[:, 3:] * torch.sigmoid(
163 | scale_rot[:, :3])
164 | rot = pc.rotation_activation(scale_rot[:, 3:7]) # [N_opacity_pos_gaussian, 4]
165 |
166 | offsets = offsets * scaling_repeat[:, :3] # [N_opacity_pos_gaussian, 3]
167 | xyz = repeat_anchor + offsets # [N_opacity_pos_gaussian, 3]
168 |
169 | if is_training:
170 | return xyz, color, opacity, scaling, rot, neural_opacity, mask, bit_per_param, bit_per_feat_param, bit_per_scaling_param, bit_per_offsets_param
171 | else:
172 | return xyz, color, opacity, scaling, rot, time_sub
173 |
174 |
175 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, visible_mask=None, retain_grad=False, step=0):
176 | """
177 | Render the scene.
178 |
179 | Background tensor (bg_color) must be on GPU!
180 | """
181 | is_training = pc.get_color_mlp.training
182 |
183 | if is_training:
184 | xyz, color, opacity, scaling, rot, neural_opacity, mask, bit_per_param, bit_per_feat_param, bit_per_scaling_param, bit_per_offsets_param = generate_neural_gaussians(viewpoint_camera, pc, visible_mask, is_training=is_training, step=step)
185 | else:
186 | xyz, color, opacity, scaling, rot, time_sub = generate_neural_gaussians(viewpoint_camera, pc, visible_mask, is_training=is_training, step=step)
187 |
188 | screenspace_points = torch.zeros_like(xyz, dtype=pc.get_anchor.dtype, requires_grad=True, device="cuda") + 0
189 | if retain_grad:
190 | try:
191 | screenspace_points.retain_grad()
192 | except:
193 | pass
194 |
195 | # Set up rasterization configuration
196 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
197 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
198 |
199 | raster_settings = GaussianRasterizationSettings(
200 | image_height=int(viewpoint_camera.image_height),
201 | image_width=int(viewpoint_camera.image_width),
202 | tanfovx=tanfovx,
203 | tanfovy=tanfovy,
204 | bg=bg_color,
205 | scale_modifier=scaling_modifier,
206 | viewmatrix=viewpoint_camera.world_view_transform,
207 | projmatrix=viewpoint_camera.full_proj_transform,
208 | sh_degree=1,
209 | campos=viewpoint_camera.camera_center,
210 | prefiltered=False,
211 | debug=pipe.debug
212 | )
213 |
214 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
215 |
216 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
217 | rendered_image, radii = rasterizer(
218 | means3D = xyz,
219 | means2D = screenspace_points,
220 | shs = None,
221 | colors_precomp = color,
222 | opacities = opacity,
223 | scales = scaling,
224 | rotations = rot,
225 | cov3D_precomp = None)
226 |
227 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
228 | if is_training:
229 | return {"render": rendered_image,
230 | "viewspace_points": screenspace_points,
231 | "visibility_filter" : radii > 0,
232 | "radii": radii,
233 | "selection_mask": mask,
234 | "neural_opacity": neural_opacity,
235 | "scaling": scaling,
236 | "bit_per_param": bit_per_param,
237 | "bit_per_feat_param": bit_per_feat_param,
238 | "bit_per_scaling_param": bit_per_scaling_param,
239 | "bit_per_offsets_param": bit_per_offsets_param,
240 | }
241 | else:
242 | return {"render": rendered_image,
243 | "viewspace_points": screenspace_points,
244 | "visibility_filter" : radii > 0,
245 | "radii": radii,
246 | "time_sub": time_sub,
247 | }
248 |
249 |
250 | def prefilter_voxel(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0,
251 | override_color=None):
252 | """
253 | Render the scene.
254 |
255 | Background tensor (bg_color) must be on GPU!
256 | """
257 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
258 | screenspace_points = torch.zeros_like(pc.get_anchor, dtype=pc.get_anchor.dtype, requires_grad=True,
259 | device="cuda") + 0
260 | try:
261 | screenspace_points.retain_grad()
262 | except:
263 | pass
264 |
265 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
266 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
267 |
268 | raster_settings = GaussianRasterizationSettings(
269 | image_height=int(viewpoint_camera.image_height),
270 | image_width=int(viewpoint_camera.image_width),
271 | tanfovx=tanfovx,
272 | tanfovy=tanfovy,
273 | bg=bg_color,
274 | scale_modifier=scaling_modifier,
275 | viewmatrix=viewpoint_camera.world_view_transform,
276 | projmatrix=viewpoint_camera.full_proj_transform,
277 | sh_degree=1,
278 | campos=viewpoint_camera.camera_center,
279 | prefiltered=False,
280 | debug=pipe.debug
281 | )
282 |
283 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
284 |
285 | means3D = pc.get_anchor
286 |
287 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
288 | # scaling / rotation by the rasterizer.
289 | scales = None
290 | rotations = None
291 | cov3D_precomp = None
292 | if pipe.compute_cov3D_python: # False
293 | cov3D_precomp = pc.get_covariance(scaling_modifier)
294 | else: # into here
295 | scales = pc.get_scaling # requires_grad = True
296 | rotations = pc.get_rotation # requires_grad = True
297 |
298 | radii_pure = rasterizer.visible_filter(
299 | means3D=means3D,
300 | scales=scales[:, :3],
301 | rotations=rotations,
302 | cov3D_precomp=cov3D_precomp, # None
303 | )
304 |
305 | return radii_pure > 0
306 |
--------------------------------------------------------------------------------
/gaussian_renderer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/gaussian_renderer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/gaussian_renderer/__pycache__/network_gui.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/gaussian_renderer/__pycache__/network_gui.cpython-37.pyc
--------------------------------------------------------------------------------
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 | def init(wish_host, wish_port):
27 | global host, port, listener
28 | host = wish_host
29 | port = wish_port
30 | listener.bind((host, port))
31 | listener.listen()
32 | listener.settimeout(0)
33 |
34 | def try_connect():
35 | global conn, addr, listener
36 | try:
37 | conn, addr = listener.accept()
38 | print(f"\nConnected by {addr}")
39 | conn.settimeout(None)
40 | except Exception as inst:
41 | pass
42 |
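# Wire format: each message is a 4-byte little-endian length prefix followed by a
# UTF-8 encoded JSON payload of that length.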
43 | def read():
44 | global conn
45 | messageLength = conn.recv(4)
46 | messageLength = int.from_bytes(messageLength, 'little')
47 | message = conn.recv(messageLength)
48 | return json.loads(message.decode("utf-8"))
49 |
50 | def send(message_bytes, verify):
51 | global conn
52 | if message_bytes != None:
53 | conn.sendall(message_bytes)
54 | conn.sendall(len(verify).to_bytes(4, 'little'))
55 | conn.sendall(bytes(verify, 'ascii'))
56 |
57 | def receive():
58 | message = read()
59 |
60 | width = message["resolution_x"]
61 | height = message["resolution_y"]
62 |
63 | if width != 0 and height != 0:
64 | try:
65 | do_training = bool(message["train"])
66 | fovy = message["fov_y"]
67 | fovx = message["fov_x"]
68 | znear = message["z_near"]
69 | zfar = message["z_far"]
70 | do_shs_python = bool(message["shs_python"])
71 | do_rot_scale_python = bool(message["rot_scale_python"])
72 | keep_alive = bool(message["keep_alive"])
73 | scaling_modifier = message["scaling_modifier"]
74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75 | world_view_transform[:,1] = -world_view_transform[:,1]
76 | world_view_transform[:,2] = -world_view_transform[:,2]
77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78 | full_proj_transform[:,1] = -full_proj_transform[:,1]
79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80 | except Exception as e:
81 | print("")
82 | traceback.print_exc()
83 | raise e
84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85 | else:
86 | return None, None, None, None, None, None
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
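# Minimal usage sketch (illustrative; assumes two image batches of shape [B, 3, H, W]
# on the same device, scaled as the calling pipeline expects):
#
#   x = torch.rand(1, 3, 64, 64)
#   y = torch.rand(1, 3, 64, 64)
#   score = lpips(x, y, net_type='vgg')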
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/results/DeepBlending/drjohnson.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,29.8491783,0.9062452,0.2545068,8044150.784,800442
3 | HAC-lowrate,29.5277023,0.9025214,0.2650948,5815821.9264,757083
4 |
--------------------------------------------------------------------------------
/results/DeepBlending/playroom.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,30.8383617,0.905672,0.2616119,5277483.008,448521
3 | HAC-lowrate,30.4393959,0.902317,0.2723858,3306474.7008,380598
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/bicycle.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,25.1073151,0.7424726,0.2588055,41048080.384,4462216
3 | HAC-lowrate,25.0472488,0.7415894,0.2642056,28876944.1792,4337034
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/bonsai.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,32.9719124,0.9480926,0.1801753,13334216.704,1110454
3 | HAC-lowrate,32.2754822,0.9419372,0.1894576,8976754.2784,1016354
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/counter.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,29.7447491,0.917532,0.1838127,10949964.5952,910833
3 | HAC-lowrate,29.347311,0.9105615,0.1951433,7614024.9088,786790
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/flowers.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,21.2706013,0.5747946,0.3767022,28892463.104,3111616
3 | HAC-lowrate,21.2626057,0.572156,0.3812912,20542232.9856,3030281
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/garden.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,27.4615631,0.8488841,0.1394164,33734157.9264,3159104
3 | HAC-lowrate,27.2805805,0.8419042,0.1514979,23796278.8864,2783590
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/kitchen.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,31.6308594,0.9295087,0.1216748,12658094.8992,1160861
3 | HAC-lowrate,31.1647701,0.9230514,0.1306984,8445755.392,955774
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/room.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,31.9047031,0.9261015,0.1975124,8234886.7584,638887
3 | HAC-lowrate,31.5470905,0.9211181,0.2078761,5796528.128,613008
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/stump.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,26.5885735,0.7628535,0.2635286,26482520.8832,2811861
3 | HAC-lowrate,26.5819645,0.7618363,0.2692434,18989711.36,3139864
4 |
--------------------------------------------------------------------------------
/results/MipNeRF360/treehill.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,23.2636204,0.6480887,0.3448641,31086503.5264,3016897
3 | HAC-lowrate,23.3030071,0.6446206,0.3564054,21010422.1696,2835570
4 |
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | Results formatted following [3DGS.zip](https://github.com/w-m/3dgs-compression-survey?tab=readme-ov-file#including-your-own-results).
2 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/chair.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,35.4915924,0.9860504,0.0126323,1753219.072,127611
3 | HAC-lowrate,34.7256927,0.9837951,0.0156229,1080662.4256,82370
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/drums.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,26.4458008,0.9520242,0.040858,2432906.0352,188193
3 | HAC-lowrate,26.3249664,0.9518869,0.0425592,1524944.0768,130122
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/ficus.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,35.3031387,0.9861924,0.0125956,1608830.1568,111752
3 | HAC-lowrate,34.8955841,0.9849692,0.0141863,990170.3168,70576
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/hotdog.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,37.8689346,0.983838,0.0237683,1020054.7328,62248
3 | HAC-lowrate,37.1095505,0.9814867,0.0286784,669201.2032,39299
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/lego.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,35.6688728,0.9810752,0.0185425,1997012.992,128854
3 | HAC-lowrate,35.0378571,0.9789792,0.0217885,1313656.0128,106010
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/materials.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,30.7035522,0.9620085,0.0386389,2171391.1808,204113
3 | HAC-lowrate,30.5300407,0.9608711,0.0413285,1518967.1936,152503
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/mic.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,36.7093277,0.9918905,0.0075859,1063256.064,83890
3 | HAC-lowrate,35.9191742,0.9903358,0.0095429,700763.3408,54533
4 |
--------------------------------------------------------------------------------
/results/SyntheticNeRF/ship.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,31.5231514,0.9036886,0.1153168,3558866.944,238351
3 | HAC-lowrate,31.3790035,0.903459,0.1189712,2089707.1104,163975
4 |
--------------------------------------------------------------------------------
/results/TanksAndTemples/train.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,22.7802544,0.8230955,0.2065995,10554232.0128,888307
3 | HAC-lowrate,22.1893082,0.8151628,0.2155594,7277641.728,746258
4 |
--------------------------------------------------------------------------------
/results/TanksAndTemples/truck.csv:
--------------------------------------------------------------------------------
1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians
2 | HAC-highrate,26.0207634,0.8831319,0.1471665,13024677.0688,1163467
3 | HAC-lowrate,25.8835316,0.8775318,0.1575988,9709813.76,951290
4 |
--------------------------------------------------------------------------------
/run_shell_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005
4 | for cuda, scene in enumerate(['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']):
5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/nerf_synthetic/{scene} --eval --lod 0 --voxel_size 0.001 --update_init_factor 4 --iterations 30_000 -m outputs/nerf_synthetic/{scene}/{lmbda} --lmbda {lmbda}'
6 | os.system(one_cmd)
7 |
--------------------------------------------------------------------------------
/run_shell_bungee.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005
4 | for cuda, scene in enumerate(['amsterdam', 'bilbao', 'hollywood', 'pompidou', 'quebec', 'rome']):
5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/bungeenerf/{scene} --eval --lod 30 --voxel_size 0 --update_init_factor 128 --iterations 30_000 -m outputs/bungeenerf/{scene}/{lmbda} --lmbda {lmbda}'
6 | os.system(one_cmd)
7 |
--------------------------------------------------------------------------------
/run_shell_db.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005
4 | for cuda, scene in enumerate(['playroom', 'drjohnson']):
5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/blending/{scene} --eval --lod 0 --voxel_size 0.005 --update_init_factor 16 --iterations 30_000 -m outputs/blending/{scene}/{lmbda} --lmbda {lmbda}'
6 | os.system(one_cmd)
7 |
--------------------------------------------------------------------------------
/run_shell_mip360.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005
4 | for cuda, scene in enumerate(['bicycle', 'garden', 'stump', 'room', 'counter', 'kitchen', 'bonsai', 'flowers', 'treehill']):
5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/mipnerf360/{scene} --eval --lod 0 --voxel_size 0.001 --update_init_factor 16 --iterations 30_000 -m outputs/mipnerf360/{scene}/{lmbda} --lmbda {lmbda}'
6 | os.system(one_cmd)
--------------------------------------------------------------------------------
/run_shell_tnt.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | for lmbda in [0.004]: # Optionally, you can try: 0.003, 0.002, 0.001, 0.0005
4 | for cuda, scene in enumerate(['truck', 'train']):
5 | one_cmd = f'CUDA_VISIBLE_DEVICES={0} python train.py -s data/tandt/{scene} --eval --lod 0 --voxel_size 0.01 --update_init_factor 16 --iterations 30_000 -m outputs/tandt/{scene}/{lmbda} --lmbda {lmbda}'
6 | os.system(one_cmd)
7 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 |
21 | class Scene:
22 |
23 | gaussians : GaussianModel
24 |
25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], ply_path=None):
26 | """b
27 | :param path: Path to colmap scene main folder.
28 | """
29 | self.model_path = args.model_path
30 | self.loaded_iter = None
31 | self.gaussians = gaussians
32 |
33 | if load_iteration:
34 | if load_iteration == -1:
35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
36 | else:
37 | self.loaded_iter = load_iteration
38 |
39 | print("Loading trained model at iteration {}".format(self.loaded_iter))
40 |
41 | self.train_cameras = {}
42 | self.test_cameras = {}
43 |
44 | self.x_bound = None
45 | if os.path.exists(os.path.join(args.source_path, "sparse")):
46 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args.lod)
47 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
48 | print("Found transforms_train.json file, assuming Blender data set!")
49 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, ply_path=ply_path)
50 | self.x_bound = 1.3
51 | else:
52 | assert False, "Could not recognize scene type!"
53 |
54 | if not self.loaded_iter:
55 | if ply_path is not None:
56 | with open(ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
57 | dest_file.write(src_file.read())
58 | else:
59 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
60 | dest_file.write(src_file.read())
61 | json_cams = []
62 | camlist = []
63 | if scene_info.test_cameras:
64 | camlist.extend(scene_info.test_cameras)
65 | if scene_info.train_cameras:
66 | camlist.extend(scene_info.train_cameras)
67 | for id, cam in enumerate(camlist):
68 | json_cams.append(camera_to_JSON(id, cam))
69 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
70 | json.dump(json_cams, file)
71 |
72 | if shuffle:
73 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
74 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
75 |
76 | self.cameras_extent = scene_info.nerf_normalization["radius"]
77 |
78 | # print(f'self.cameras_extent: {self.cameras_extent}')
79 |
80 | for resolution_scale in resolution_scales:
81 | print("Loading Training Cameras")
82 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
83 | print("Loading Test Cameras")
84 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
85 |
86 | if self.loaded_iter:
87 | self.gaussians.load_ply_sparse_gaussian(os.path.join(self.model_path,
88 | "point_cloud",
89 | "iteration_" + str(self.loaded_iter),
90 | "point_cloud.ply"))
91 | self.gaussians.load_mlp_checkpoints(os.path.join(self.model_path,
92 | "point_cloud",
93 | "iteration_" + str(self.loaded_iter),
94 | "checkpoint.pth"))
95 | else:
96 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
97 |
98 | def save(self, iteration):
99 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
100 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
101 | self.gaussians.save_mlp_checkpoints(os.path.join(point_cloud_path, "checkpoint.pth"))
102 |
103 | def getTrainCameras(self, scale=1.0):
104 | return self.train_cameras[scale]
105 |
106 | def getTestCameras(self, scale=1.0):
107 | return self.test_cameras[scale]
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16 |
17 | class Camera(nn.Module):
18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
19 | image_name, uid,
20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
21 | ):
22 | super(Camera, self).__init__()
23 |
24 | self.uid = uid
25 | self.colmap_id = colmap_id
26 | self.R = R
27 | self.T = T
28 | self.FoVx = FoVx
29 | self.FoVy = FoVy
30 | self.image_name = image_name
31 |
32 | try:
33 | self.data_device = torch.device(data_device)
34 | except Exception as e:
35 | print(e)
36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
37 | self.data_device = torch.device("cuda")
38 |
39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
40 | self.image_width = self.original_image.shape[2]
41 | self.image_height = self.original_image.shape[1]
42 |
43 | if gt_alpha_mask is not None:
44 | self.original_image *= gt_alpha_mask.to(self.data_device)
45 | else:
46 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
47 |
48 | self.zfar = 100.0
49 | self.znear = 0.01
50 |
51 | self.trans = trans
52 | self.scale = scale
53 |
54 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
55 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
56 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
57 | self.camera_center = self.world_view_transform.inverse()[3, :3]
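# Note: these matrices are stored transposed (row-vector convention, to match the glm-based
# CUDA rasterizer), so a homogeneous world-space point p maps to clip space as
# p @ full_proj_transform, and row 3 of the inverted view matrix holds the camera center.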
58 |
59 | class MiniCam:
60 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
61 | self.image_width = width
62 | self.image_height = height
63 | self.FoVy = fovy
64 | self.FoVx = fovx
65 | self.znear = znear
66 | self.zfar = zfar
67 | self.world_view_transform = world_view_transform
68 | self.full_proj_transform = full_proj_transform
69 | view_inv = torch.inverse(self.world_view_transform)
70 | self.camera_center = view_inv[3][:3]
71 |
72 |
--------------------------------------------------------------------------------
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54 |
55 | def rotmat2qvec(R):
56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57 | K = np.array([
58 | [Rxx - Ryy - Rzz, 0, 0, 0],
59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62 | eigvals, eigvecs = np.linalg.eigh(K)
63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64 | if qvec[0] < 0:
65 | qvec *= -1
66 | return qvec
67 |
68 | class Image(BaseImage):
69 | def qvec2rotmat(self):
70 | return qvec2rotmat(self.qvec)
71 |
72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73 | """Read and unpack the next bytes from a binary file.
74 | :param fid:
75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77 | :param endian_character: Any of {@, =, <, >, !}
78 | :return: Tuple of read and unpacked values.
79 | """
80 | data = fid.read(num_bytes)
81 | return struct.unpack(endian_character + format_char_sequence, data)
82 |
83 | def read_points3D_text(path):
84 | """
85 | see: src/base/reconstruction.cc
86 | void Reconstruction::ReadPoints3DText(const std::string& path)
87 | void Reconstruction::WritePoints3DText(const std::string& path)
88 | """
89 | xyzs = None
90 | rgbs = None
91 | errors = None
92 | num_points = 0
93 | with open(path, "r") as fid:
94 | while True:
95 | line = fid.readline()
96 | if not line:
97 | break
98 | line = line.strip()
99 | if len(line) > 0 and line[0] != "#":
100 | num_points += 1
101 |
102 |
103 | xyzs = np.empty((num_points, 3))
104 | rgbs = np.empty((num_points, 3))
105 | errors = np.empty((num_points, 1))
106 | count = 0
107 | with open(path, "r") as fid:
108 | while True:
109 | line = fid.readline()
110 | if not line:
111 | break
112 | line = line.strip()
113 | if len(line) > 0 and line[0] != "#":
114 | elems = line.split()
115 | xyz = np.array(tuple(map(float, elems[1:4])))
116 | rgb = np.array(tuple(map(int, elems[4:7])))
117 | error = np.array(float(elems[7]))
118 | xyzs[count] = xyz
119 | rgbs[count] = rgb
120 | errors[count] = error
121 | count += 1
122 |
123 | return xyzs, rgbs, errors
124 |
125 | def read_points3D_binary(path_to_model_file):
126 | """
127 | see: src/base/reconstruction.cc
128 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
129 | void Reconstruction::WritePoints3DBinary(const std::string& path)
130 | """
131 |
132 |
133 | with open(path_to_model_file, "rb") as fid:
134 | num_points = read_next_bytes(fid, 8, "Q")[0]
135 |
136 | xyzs = np.empty((num_points, 3))
137 | rgbs = np.empty((num_points, 3))
138 | errors = np.empty((num_points, 1))
139 |
140 | for p_id in range(num_points):
141 | binary_point_line_properties = read_next_bytes(
142 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
143 | xyz = np.array(binary_point_line_properties[1:4])
144 | rgb = np.array(binary_point_line_properties[4:7])
145 | error = np.array(binary_point_line_properties[7])
146 | track_length = read_next_bytes(
147 | fid, num_bytes=8, format_char_sequence="Q")[0]
148 | track_elems = read_next_bytes(
149 | fid, num_bytes=8*track_length,
150 | format_char_sequence="ii"*track_length)
151 | xyzs[p_id] = xyz
152 | rgbs[p_id] = rgb
153 | errors[p_id] = error
154 | return xyzs, rgbs, errors
155 |
156 | def read_intrinsics_text(path):
157 | """
158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159 | """
160 | cameras = {}
161 | with open(path, "r") as fid:
162 | while True:
163 | line = fid.readline()
164 | if not line:
165 | break
166 | line = line.strip()
167 | if len(line) > 0 and line[0] != "#":
168 | elems = line.split()
169 | camera_id = int(elems[0])
170 | model = elems[1]
171 |                 assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
172 | width = int(elems[2])
173 | height = int(elems[3])
174 | params = np.array(tuple(map(float, elems[4:])))
175 | cameras[camera_id] = Camera(id=camera_id, model=model,
176 | width=width, height=height,
177 | params=params)
178 | return cameras
179 |
180 | def read_extrinsics_binary(path_to_model_file):
181 | """
182 | see: src/base/reconstruction.cc
183 | void Reconstruction::ReadImagesBinary(const std::string& path)
184 | void Reconstruction::WriteImagesBinary(const std::string& path)
185 | """
186 | images = {}
187 | with open(path_to_model_file, "rb") as fid:
188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189 | for _ in range(num_reg_images):
190 | binary_image_properties = read_next_bytes(
191 | fid, num_bytes=64, format_char_sequence="idddddddi")
192 | image_id = binary_image_properties[0]
193 | qvec = np.array(binary_image_properties[1:5])
194 | tvec = np.array(binary_image_properties[5:8])
195 | camera_id = binary_image_properties[8]
196 | image_name = ""
197 | current_char = read_next_bytes(fid, 1, "c")[0]
198 | while current_char != b"\x00": # look for the ASCII 0 entry
199 | image_name += current_char.decode("utf-8")
200 | current_char = read_next_bytes(fid, 1, "c")[0]
201 | num_points2D = read_next_bytes(fid, num_bytes=8,
202 | format_char_sequence="Q")[0]
203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204 | format_char_sequence="ddq"*num_points2D)
205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206 | tuple(map(float, x_y_id_s[1::3]))])
207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208 | images[image_id] = Image(
209 | id=image_id, qvec=qvec, tvec=tvec,
210 | camera_id=camera_id, name=image_name,
211 | xys=xys, point3D_ids=point3D_ids)
212 | return images
213 |
214 |
215 | def read_intrinsics_binary(path_to_model_file):
216 | """
217 | see: src/base/reconstruction.cc
218 | void Reconstruction::WriteCamerasBinary(const std::string& path)
219 | void Reconstruction::ReadCamerasBinary(const std::string& path)
220 | """
221 | cameras = {}
222 | with open(path_to_model_file, "rb") as fid:
223 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
224 | for _ in range(num_cameras):
225 | camera_properties = read_next_bytes(
226 | fid, num_bytes=24, format_char_sequence="iiQQ")
227 | camera_id = camera_properties[0]
228 | model_id = camera_properties[1]
229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230 | width = camera_properties[2]
231 | height = camera_properties[3]
232 | num_params = CAMERA_MODEL_IDS[model_id].num_params
233 | params = read_next_bytes(fid, num_bytes=8*num_params,
234 | format_char_sequence="d"*num_params)
235 | cameras[camera_id] = Camera(id=camera_id,
236 | model=model_name,
237 | width=width,
238 | height=height,
239 | params=np.array(params))
240 | assert len(cameras) == num_cameras
241 | return cameras
242 |
243 |
244 | def read_extrinsics_text(path):
245 | """
246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247 | """
248 | images = {}
249 | with open(path, "r") as fid:
250 | while True:
251 | line = fid.readline()
252 | if not line:
253 | break
254 | line = line.strip()
255 | if len(line) > 0 and line[0] != "#":
256 | elems = line.split()
257 | image_id = int(elems[0])
258 | qvec = np.array(tuple(map(float, elems[1:5])))
259 | tvec = np.array(tuple(map(float, elems[5:8])))
260 | camera_id = int(elems[8])
261 | image_name = elems[9]
262 | elems = fid.readline().split()
263 | xys = np.column_stack([tuple(map(float, elems[0::3])),
264 | tuple(map(float, elems[1::3]))])
265 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
266 | images[image_id] = Image(
267 | id=image_id, qvec=qvec, tvec=tvec,
268 | camera_id=camera_id, name=image_name,
269 | xys=xys, point3D_ids=point3D_ids)
270 | return images
271 |
272 |
273 | def read_colmap_bin_array(path):
274 | """
275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276 |
277 | :param path: path to the colmap binary file.
278 |     :return: nd array with the floating point values read from the file.
279 | """
280 | with open(path, "rb") as fid:
281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282 | usecols=(0, 1, 2), dtype=int)
283 | fid.seek(0)
284 | num_delimiter = 0
285 | byte = fid.read(1)
286 | while True:
287 | if byte == b"&":
288 | num_delimiter += 1
289 | if num_delimiter >= 3:
290 | break
291 | byte = fid.read(1)
292 | array = np.fromfile(fid, np.float32)
293 | array = array.reshape((width, height, channels), order="F")
294 | return np.transpose(array, (1, 0, 2)).squeeze()
295 |
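A minimal sanity check of the quaternion helpers above (illustrative only; COLMAP stores quaternions in [w, x, y, z] order, which these functions assume):

    import numpy as np
    from scene.colmap_loader import qvec2rotmat, rotmat2qvec

    # 45-degree rotation about the y-axis, in [w, x, y, z] order.
    qvec = np.array([np.cos(np.pi / 8), 0.0, np.sin(np.pi / 8), 0.0])
    R = qvec2rotmat(qvec)
    assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)    # a unit quaternion yields an orthonormal matrix
    assert np.allclose(rotmat2qvec(R), qvec, atol=1e-6)  # round-trip recovers the quaternion (w >= 0)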
--------------------------------------------------------------------------------
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from tqdm import tqdm
16 | from typing import NamedTuple
17 | from colorama import Fore, init, Style
18 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
19 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
20 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
21 | import numpy as np
22 | import json
23 | from pathlib import Path
24 | from plyfile import PlyData, PlyElement
25 | from utils.sh_utils import SH2RGB
26 | from scene.gaussian_model import BasicPointCloud
27 |
28 | class CameraInfo(NamedTuple):
29 | uid: int
30 | R: np.array
31 | T: np.array
32 | FovY: np.array
33 | FovX: np.array
34 | image: np.array
35 | image_path: str
36 | image_name: str
37 | width: int
38 | height: int
39 |
40 | class SceneInfo(NamedTuple):
41 | point_cloud: BasicPointCloud
42 | train_cameras: list
43 | test_cameras: list
44 | nerf_normalization: dict
45 | ply_path: str
46 |
47 | def getNerfppNorm(cam_info):
48 | def get_center_and_diag(cam_centers):
49 | cam_centers = np.hstack(cam_centers)
50 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
51 | center = avg_cam_center
52 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
53 | diagonal = np.max(dist)
54 | return center.flatten(), diagonal
55 |
56 | cam_centers = []
57 |
58 | for cam in cam_info:
59 | W2C = getWorld2View2(cam.R, cam.T)
60 | C2W = np.linalg.inv(W2C)
61 | cam_centers.append(C2W[:3, 3:4])
62 |
63 | center, diagonal = get_center_and_diag(cam_centers)
64 | radius = diagonal * 1.1
65 |
66 | translate = -center
67 |
68 | return {"translate": translate, "radius": radius}
69 |
70 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
71 | cam_infos = []
72 | for idx, key in enumerate(cam_extrinsics):
73 | sys.stdout.write('\r')
74 |         # overwrite the previous progress line in place
75 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
76 | sys.stdout.flush()
77 |
78 | extr = cam_extrinsics[key]
79 | intr = cam_intrinsics[extr.camera_id]
80 | height = intr.height
81 | width = intr.width
82 |
83 | uid = intr.id
84 | R = np.transpose(qvec2rotmat(extr.qvec))
85 | T = np.array(extr.tvec)
86 |
87 | # if intr.model=="SIMPLE_PINHOLE":
88 | if intr.model=="SIMPLE_PINHOLE" or intr.model == "SIMPLE_RADIAL":
89 | focal_length_x = intr.params[0]
90 | FovY = focal2fov(focal_length_x, height)
91 | FovX = focal2fov(focal_length_x, width)
92 | elif intr.model=="PINHOLE":
93 | focal_length_x = intr.params[0]
94 | focal_length_y = intr.params[1]
95 | FovY = focal2fov(focal_length_y, height)
96 | FovX = focal2fov(focal_length_x, width)
97 | else:
98 |             assert False, "Colmap camera model not handled: only PINHOLE, SIMPLE_PINHOLE or SIMPLE_RADIAL cameras are supported!"
99 |
100 | # print(f'FovX: {FovX}, FovY: {FovY}')
101 |
102 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
103 | image_name = os.path.basename(image_path).split(".")[0]
104 | image = Image.open(image_path)
105 |
106 | # print(f'image: {image.size}')
107 |
108 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
109 | image_path=image_path, image_name=image_name, width=width, height=height)
110 | cam_infos.append(cam_info)
111 | sys.stdout.write('\n')
112 | return cam_infos
113 |
114 | def fetchPly(path):
115 | plydata = PlyData.read(path)
116 | vertices = plydata['vertex']
117 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
118 | try:
119 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
120 | except:
121 | colors = np.random.rand(positions.shape[0], positions.shape[1])
122 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
123 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
124 |
125 | def storePly(path, xyz, rgb):
126 | # Define the dtype for the structured array
127 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
128 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
129 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
130 |
131 | normals = np.zeros_like(xyz)
132 |
133 | elements = np.empty(xyz.shape[0], dtype=dtype)
134 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
135 | elements[:] = list(map(tuple, attributes))
136 |
137 | # Create the PlyData object and write to file
138 | vertex_element = PlyElement.describe(elements, 'vertex')
139 | ply_data = PlyData([vertex_element])
140 | ply_data.write(path)
141 |
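# A minimal round-trip sketch for the two helpers above (illustrative; the path is a placeholder):
#   xyz = np.random.random((100, 3))
#   rgb = (np.random.random((100, 3)) * 255).astype(np.uint8)
#   storePly("/tmp/points3d_example.ply", xyz, rgb)
#   pcd = fetchPly("/tmp/points3d_example.ply")  # BasicPointCloud with colors rescaled to [0, 1]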
142 | def readColmapSceneInfo(path, images, eval, lod, llffhold=8):
143 | try:
144 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
145 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
146 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
147 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
148 | except:
149 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
150 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
151 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
152 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
153 |
154 | reading_dir = "images" if images == None else images
155 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
156 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
157 |
158 | if eval:
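# With eval enabled, `lod` acts as a split index: for 0 < lod < 50 the first lod + 1 images are
# held out for testing, for lod >= 50 the images after index lod are held out, and lod == 0
# falls back to holding out every llffhold-th (by default every 8th) image.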
159 | if lod>0:
160 | print(f'using lod, using eval')
161 | if lod < 50:
162 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx > lod]
163 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx <= lod]
164 | print(f'test_cam_infos: {len(test_cam_infos)}')
165 | else:
166 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx <= lod]
167 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx > lod]
168 |
169 | else:
170 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
171 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
172 |
173 | else:
174 | train_cam_infos = cam_infos
175 | test_cam_infos = []
176 |
177 | nerf_normalization = getNerfppNorm(train_cam_infos)
178 |
179 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
180 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
181 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
182 | if not os.path.exists(ply_path):
183 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
184 | try:
185 | xyz, rgb, _ = read_points3D_binary(bin_path)
186 | except:
187 | xyz, rgb, _ = read_points3D_text(txt_path)
188 | storePly(ply_path, xyz, rgb)
189 | # try:
190 | print(f'start fetching data from ply file')
191 | pcd = fetchPly(ply_path)
192 | # except:
193 | # pcd = None
194 |
195 | scene_info = SceneInfo(point_cloud=pcd,
196 | train_cameras=train_cam_infos,
197 | test_cameras=test_cam_infos,
198 | nerf_normalization=nerf_normalization,
199 | ply_path=ply_path)
200 | return scene_info
201 |
202 | # def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
203 | # cam_infos = []
204 |
205 | # with open(os.path.join(path, transformsfile)) as json_file:
206 | # contents = json.load(json_file)
207 | # fovx = contents["camera_angle_x"]
208 |
209 | # frames = contents["frames"]
210 | # for idx, frame in enumerate(frames):
211 | # cam_name = os.path.join(path, frame["file_path"] + extension)
212 |
213 | # # NeRF 'transform_matrix' is a camera-to-world transform
214 | # c2w = np.array(frame["transform_matrix"])
215 | # # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
216 | # c2w[:3, 1:3] *= -1
217 |
218 | # # get the world-to-camera transform and set R, T
219 | # w2c = np.linalg.inv(c2w)
220 | # R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
221 | # T = w2c[:3, 3]
222 |
223 | # image_path = os.path.join(path, cam_name)
224 | # image_name = Path(cam_name).stem
225 | # image = Image.open(image_path)
226 |
227 | # im_data = np.array(image.convert("RGBA"))
228 |
229 | # bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
230 |
231 | # norm_data = im_data / 255.0
232 | # arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
233 | # image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
234 |
235 | # fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
236 | # FovY = fovy
237 | # FovX = fovx
238 |
239 | # cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
240 | # image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
241 |
242 | # return cam_infos
243 |
244 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png", is_debug=False):
245 | cam_infos = []
246 | with open(os.path.join(path, transformsfile)) as json_file:
247 | contents = json.load(json_file)
248 | try:
249 | fovx = contents["camera_angle_x"]
250 | except:
251 | fovx = None
252 |
253 | frames = contents["frames"]
254 | # check if filename already contain postfix
255 | if frames[0]["file_path"].split('.')[-1] in ['jpg', 'jpeg', 'JPG', 'png']:
256 | extension = ""
257 |
258 | c2ws = np.array([frame["transform_matrix"] for frame in frames])
259 |
260 | Ts = c2ws[:,:3,3]
261 |
262 | ct = 0
263 |
264 | progress_bar = tqdm(frames, desc="Loading dataset")
265 |
266 | for idx, frame in enumerate(frames):
267 | cam_name = os.path.join(path, frame["file_path"] + extension)
268 | if not os.path.exists(cam_name):
269 | continue
270 | # NeRF 'transform_matrix' is a camera-to-world transform
271 | c2w = np.array(frame["transform_matrix"])
272 |
273 | if idx % 10 == 0:
274 | progress_bar.set_postfix({"num": Fore.YELLOW+f"{ct}/{len(frames)}"+Style.RESET_ALL})
275 | progress_bar.update(10)
276 | if idx == len(frames) - 1:
277 | progress_bar.close()
278 |
279 | ct += 1
280 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
281 | c2w[:3, 1:3] *= -1
282 | if "small_city_img" in path:
283 | c2w[-1,-1] = 1
284 |
285 | # get the world-to-camera transform and set R, T
286 | w2c = np.linalg.inv(c2w)
287 |
288 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
289 | T = w2c[:3, 3]
290 |
291 | image_path = os.path.join(path, cam_name)
292 | image_name = Path(cam_name).stem
293 | image = Image.open(image_path)
294 |
295 | im_data = np.array(image.convert("RGBA"))
296 |
297 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
298 |
299 | norm_data = im_data / 255.0
300 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
301 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
302 |
303 | if fovx is not None:
304 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
305 | FovY = fovy
306 | FovX = fovx
307 | else:
308 |                 # focal length given in pixels
309 | FovY = focal2fov(frame["fl_y"], image.size[1])
310 | FovX = focal2fov(frame["fl_x"], image.size[0])
311 |
312 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
313 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
314 |
315 | if is_debug and idx > 50:
316 | break
317 | return cam_infos
318 |
319 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png", ply_path=None):
320 | print("Reading Training Transforms")
321 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
322 | print("Reading Test Transforms")
323 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
324 |
325 | if not eval:
326 | train_cam_infos.extend(test_cam_infos)
327 | test_cam_infos = []
328 |
329 | nerf_normalization = getNerfppNorm(train_cam_infos)
330 | if ply_path is None:
331 | ply_path = os.path.join(path, "points3d.ply")
332 | if not os.path.exists(ply_path):
333 | # Since this data set has no colmap data, we start with random points
334 | num_pts = 10_000
335 | print(f"Generating random point cloud ({num_pts})...")
336 |
337 | # We create random points inside the bounds of the synthetic Blender scenes
338 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
339 | shs = np.random.random((num_pts, 3)) / 255.0
340 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
341 |
342 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
343 | try:
344 | pcd = fetchPly(ply_path)
345 | except:
346 | pcd = None
347 |
348 | scene_info = SceneInfo(point_cloud=pcd,
349 | train_cameras=train_cam_infos,
350 | test_cameras=test_cam_infos,
351 | nerf_normalization=nerf_normalization,
352 | ply_path=ply_path)
353 | return scene_info
354 |
355 |
356 | sceneLoadTypeCallbacks = {
357 | "Colmap": readColmapSceneInfo,
358 | "Blender": readNerfSyntheticInfo,
359 | }
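As a minimal usage sketch (the path is a placeholder), this dictionary is how scene/__init__.py dispatches loading; a COLMAP-format scene resolves to readColmapSceneInfo above:

    from scene.dataset_readers import sceneLoadTypeCallbacks

    # Expects <path>/sparse/0 with COLMAP output and an images/ folder next to it.
    scene_info = sceneLoadTypeCallbacks["Colmap"]("data/tandt/truck", None, eval=True, lod=0)
    print(len(scene_info.train_cameras), len(scene_info.test_cameras), scene_info.ply_path)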
--------------------------------------------------------------------------------
/submodules/arithmetic.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/arithmetic.zip
--------------------------------------------------------------------------------
/submodules/diff-gaussian-rasterization.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/diff-gaussian-rasterization.zip
--------------------------------------------------------------------------------
/submodules/gridencoder.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/gridencoder.zip
--------------------------------------------------------------------------------
/submodules/simple-knn.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YihangChen-ee/HAC/6a04feae89cf1eb6fa369af449effddfb1a5724f/submodules/simple-knn.zip
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import numpy as np
14 |
15 | import subprocess
16 | # cmd = 'nvidia-smi -q -d Memory |grep -A4 GPU|grep Used'
17 | # result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode().split('\n')
18 | # os.environ['CUDA_VISIBLE_DEVICES']=str(np.argmin([int(x.split()[2]) for x in result[:-1]]))
19 |
20 | # os.system('echo $CUDA_VISIBLE_DEVICES')
21 |
22 |
23 | import torch
24 | import torchvision
25 | import json
26 | import wandb
27 | import time
28 | from os import makedirs
29 | import shutil, pathlib
30 | from pathlib import Path
31 | from PIL import Image
32 | import torchvision.transforms.functional as tf
33 | # from lpipsPyTorch import lpips
34 | # import lpips
35 | from random import randint
36 | from utils.loss_utils import l1_loss, ssim
37 | from gaussian_renderer import prefilter_voxel, render, network_gui
38 | import sys
39 | from scene import Scene, GaussianModel
40 | from utils.general_utils import safe_state
41 | import uuid
42 | from tqdm import tqdm
43 | from utils.image_utils import psnr
44 | from argparse import ArgumentParser, Namespace
45 | from arguments import ModelParams, PipelineParams, OptimizationParams
46 | from utils.encodings import get_binary_vxl_size
47 |
48 | # torch.set_num_threads(32)
49 | # lpips_fn = lpips.LPIPS(net='vgg').to('cuda')
50 |
51 | from lpipsPyTorch import lpips
52 |
53 | bit2MB_scale = 8 * 1024 * 1024  # bits per megabyte, for converting bit counts to MB
54 | run_codec = True
55 |
56 | try:
57 | from torch.utils.tensorboard import SummaryWriter
58 | TENSORBOARD_FOUND = True
59 | print("found tf board")
60 | except ImportError:
61 | TENSORBOARD_FOUND = False
62 | print("not found tf board")
63 |
64 | def saveRuntimeCode(dst: str) -> None:
65 | additionalIgnorePatterns = ['.git', '.gitignore']
66 | ignorePatterns = set()
67 | ROOT = '.'
68 | with open(os.path.join(ROOT, '.gitignore')) as gitIgnoreFile:
69 | for line in gitIgnoreFile:
70 | if not line.startswith('#'):
71 | if line.endswith('\n'):
72 | line = line[:-1]
73 | if line.endswith('/'):
74 | line = line[:-1]
75 | ignorePatterns.add(line)
76 | ignorePatterns = list(ignorePatterns)
77 | for additionalPattern in additionalIgnorePatterns:
78 | ignorePatterns.append(additionalPattern)
79 |
80 | log_dir = pathlib.Path(__file__).parent.resolve()
81 |
82 |
83 | shutil.copytree(log_dir, dst, ignore=shutil.ignore_patterns(*ignorePatterns))
84 |
85 | print('Backup Finished!')
86 |
87 |
88 | def training(args_param, dataset, opt, pipe, dataset_name, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, wandb=None, logger=None, ply_path=None):
89 | first_iter = 0
90 | tb_writer = prepare_output_and_logger(dataset)
91 |
92 | gaussians = GaussianModel(
93 | dataset.feat_dim,
94 | dataset.n_offsets,
95 | dataset.voxel_size,
96 | dataset.update_depth,
97 | dataset.update_init_factor,
98 | dataset.update_hierachy_factor,
99 | dataset.use_feat_bank,
100 | n_features_per_level=args_param.n_features,
101 | log2_hashmap_size=args_param.log2,
102 | log2_hashmap_size_2D=args_param.log2_2D,
103 | )
104 | scene = Scene(dataset, gaussians, ply_path=ply_path)
105 | gaussians.update_anchor_bound()
106 |
107 | gaussians.training_setup(opt)
108 | if checkpoint:
109 | (model_params, first_iter) = torch.load(checkpoint)
110 | gaussians.restore(model_params, opt)
111 |
112 | iter_start = torch.cuda.Event(enable_timing = True)
113 | iter_end = torch.cuda.Event(enable_timing = True)
114 |
115 | viewpoint_stack = None
116 | ema_loss_for_log = 0.0
117 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
118 | first_iter += 1
119 | torch.cuda.synchronize(); t_start = time.time()
120 | log_time_sub = 0
121 | for iteration in range(first_iter, opt.iterations + 1):
122 | # network gui not available in scaffold-gs yet
123 | if network_gui.conn == None:
124 | network_gui.try_connect()
125 | while network_gui.conn != None:
126 | try:
127 | net_image_bytes = None
128 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
129 | if custom_cam != None:
130 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
131 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
132 | network_gui.send(net_image_bytes, dataset.source_path)
133 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
134 | break
135 | except Exception as e:
136 | network_gui.conn = None
137 |
138 | iter_start.record()
139 |
140 | gaussians.update_learning_rate(iteration)
141 |
142 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
143 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
144 |
145 | # Pick a random Camera
146 | if not viewpoint_stack:
147 | viewpoint_stack = scene.getTrainCameras().copy()
148 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
149 |
150 | # Render
151 | if (iteration - 1) == debug_from:
152 | pipe.debug = True
153 |
154 | voxel_visible_mask = prefilter_voxel(viewpoint_cam, gaussians, pipe, background)
155 |         # voxel_visible_mask: bool = radii_pure > 0; presumably of shape [N_anchor]
156 | retain_grad = (iteration < opt.update_until and iteration >= 0)
157 | render_pkg = render(viewpoint_cam, gaussians, pipe, background, visible_mask=voxel_visible_mask, retain_grad=retain_grad, step=iteration)
158 | image, viewspace_point_tensor, visibility_filter, offset_selection_mask, radii, scaling, opacity = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["selection_mask"], render_pkg["radii"], render_pkg["scaling"], render_pkg["neural_opacity"]
159 | # image: [3, H, W]. inited as: torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
160 | # viewspace_point_tensor=screenspace_points: [N_opacity_pos_gaussian, 3]
161 |         # visibility_filter: radii > 0, where radii is initialized as torch::full({P}, 0, means3D.options().dtype(torch::kInt32)) with P = N_opacity_pos_gaussian
162 |         # offset_selection_mask: [N_visible_anchor*k]; indicates which Gaussians of the visible anchors are valid, derived from opacity > 0.0
163 |         # radii: [N_opacity_pos_gaussian]; initialized as torch::full({P}, 0, means3D.options().dtype(torch::kInt32)) with P = N_opacity_pos_gaussian
164 | # scaling: [N_opacity_pos_gaussian, 3]
165 | # opacity: [N_visible_anchor*K, 1]
166 |
167 | bit_per_param = render_pkg["bit_per_param"]
168 | bit_per_feat_param = render_pkg["bit_per_feat_param"]
169 | bit_per_scaling_param = render_pkg["bit_per_scaling_param"]
170 | bit_per_offsets_param = render_pkg["bit_per_offsets_param"]
171 |
172 | if iteration % 2000 == 0 and bit_per_param is not None:
173 |
174 | ttl_size_feat_MB = bit_per_feat_param.item() * gaussians.get_anchor.shape[0] * gaussians.feat_dim / bit2MB_scale
175 | ttl_size_scaling_MB = bit_per_scaling_param.item() * gaussians.get_anchor.shape[0] * 6 / bit2MB_scale
176 | ttl_size_offsets_MB = bit_per_offsets_param.item() * gaussians.get_anchor.shape[0] * 3 * gaussians.n_offsets / bit2MB_scale
177 | ttl_size_MB = ttl_size_feat_MB + ttl_size_scaling_MB + ttl_size_offsets_MB
178 |
179 | logger.info("\n----------------------------------------------------------------------------------------")
180 | logger.info("\n-----[ITER {}] bits info: bit_per_feat_param={}, anchor_num={}, ttl_size_feat_MB={}-----".format(iteration, bit_per_feat_param.item(), gaussians.get_anchor.shape[0], ttl_size_feat_MB))
181 | logger.info("\n-----[ITER {}] bits info: bit_per_scaling_param={}, anchor_num={}, ttl_size_scaling_MB={}-----".format(iteration, bit_per_scaling_param.item(), gaussians.get_anchor.shape[0], ttl_size_scaling_MB))
182 | logger.info("\n-----[ITER {}] bits info: bit_per_offsets_param={}, anchor_num={}, ttl_size_offsets_MB={}-----".format(iteration, bit_per_offsets_param.item(), gaussians.get_anchor.shape[0], ttl_size_offsets_MB))
183 | logger.info("\n-----[ITER {}] bits info: bit_per_param={}, anchor_num={}, ttl_size_MB={}-----".format(iteration, bit_per_param.item(), gaussians.get_anchor.shape[0], ttl_size_MB))
184 | with torch.no_grad():
185 | grid_masks = gaussians._mask.data
186 | binary_grid_masks = (torch.sigmoid(grid_masks) > 0.01).float()
187 | mask_1_rate, mask_size_bit, mask_size_MB, mask_numel = get_binary_vxl_size(binary_grid_masks + 0.0) # [0, 1] -> [-1, 1]
188 | logger.info("\n-----[ITER {}] bits info: 1_rate_mask={}, mask_numel={}, mask_size_MB={}-----".format(iteration, mask_1_rate, mask_numel, mask_size_MB))
189 |
190 | gt_image = viewpoint_cam.original_image.cuda()
191 | Ll1 = l1_loss(image, gt_image)
192 |
193 | ssim_loss = (1.0 - ssim(image, gt_image))
194 | scaling_reg = scaling.prod(dim=1).mean()
195 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ssim_loss + 0.01*scaling_reg
196 |
197 | if bit_per_param is not None:
198 | _, bit_hash_grid, MB_hash_grid, _ = get_binary_vxl_size((gaussians.get_encoding_params()+1)/2)
199 | denom = gaussians._anchor.shape[0]*(gaussians.feat_dim+6+3*gaussians.n_offsets)
200 | loss = loss + args_param.lmbda * (bit_per_param + bit_hash_grid / denom)
201 |
202 | loss = loss + 5e-4 * torch.mean(torch.sigmoid(gaussians._mask))
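# Total objective assembled above:
#   (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM) + 0.01 * scaling_reg   -> rendering terms
#   + lmbda * (bit_per_param + bit_hash_grid / denom)                          -> estimated rate (when bit estimates are available)
#   + 5e-4 * mean(sigmoid(mask))                                               -> anchor-mask sparsity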
203 |
204 | loss.backward()
205 |
206 | iter_end.record()
207 |
208 | with torch.no_grad():
209 | # Progress bar
210 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
211 |
212 | if iteration % 10 == 0:
213 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
214 | progress_bar.update(10)
215 | if iteration == opt.iterations:
216 | progress_bar.close()
217 |
218 | # Log and save
219 | torch.cuda.synchronize(); t_start_log = time.time()
220 | training_report(tb_writer, dataset_name, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), wandb, logger, args_param.model_path)
221 | if (iteration in saving_iterations):
222 | logger.info("\n[ITER {}] Saving Gaussians".format(iteration))
223 | scene.save(iteration)
224 | torch.cuda.synchronize(); t_end_log = time.time()
225 | t_log = t_end_log - t_start_log
226 | log_time_sub += t_log
227 |
228 | # densification
229 | if iteration < opt.update_until and iteration > opt.start_stat:
230 |             # accumulate densification statistics
231 | # viewspace_point_tensor=screenspace_points: [N_opacity_pos_gaussian, 3]
232 | # opacity: [N_visible_anchor*K, 1]
233 |             # visibility_filter: radii > 0, where radii is initialized as torch::full({P}, 0, means3D.options().dtype(torch::kInt32)) with P = N_opacity_pos_gaussian
234 |             # offset_selection_mask: [N_visible_anchor*k]; indicates which Gaussians of the visible anchors are valid, derived from opacity > 0.0
235 |             # voxel_visible_mask: bool = radii_pure > 0; presumably of shape [N_anchor]; voxel_visible_mask.sum() = N_visible_anchor
236 | gaussians.training_statis(viewspace_point_tensor, opacity, visibility_filter, offset_selection_mask, voxel_visible_mask)
237 |             if iteration not in range(3000, 4000): # let the model adapt to quantization
238 | # densification
239 | if iteration > opt.update_from and iteration % opt.update_interval == 0:
240 | gaussians.adjust_anchor(check_interval=opt.update_interval, success_threshold=opt.success_threshold, grad_threshold=opt.densify_grad_threshold, min_opacity=opt.min_opacity)
241 | elif iteration == opt.update_until:
242 | del gaussians.opacity_accum
243 | del gaussians.offset_gradient_accum
244 | del gaussians.offset_denom
245 | torch.cuda.empty_cache()
246 |
247 | if iteration < opt.iterations:
248 | gaussians.optimizer.step()
249 | gaussians.optimizer.zero_grad(set_to_none = True)
250 | if (iteration in checkpoint_iterations):
251 | logger.info("\n[ITER {}] Saving Checkpoint".format(iteration))
252 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
253 |
254 | torch.cuda.synchronize(); t_end = time.time()
255 | logger.info("\n Total Training time: {}".format(t_end-t_start-log_time_sub))
256 |
257 | return gaussians.x_bound_min, gaussians.x_bound_max
258 |
259 | def prepare_output_and_logger(args):
260 | if not args.model_path:
261 | if os.getenv('OAR_JOB_ID'):
262 | unique_str=os.getenv('OAR_JOB_ID')
263 | else:
264 | unique_str = str(uuid.uuid4())
265 | args.model_path = os.path.join("./output/", unique_str[0:10])
266 |
267 | # Set up output folder
268 | print("Output folder: {}".format(args.model_path))
269 | os.makedirs(args.model_path, exist_ok = True)
270 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
271 | cfg_log_f.write(str(Namespace(**vars(args))))
272 |
273 | # Create Tensorboard writer
274 | tb_writer = None
275 | if TENSORBOARD_FOUND:
276 | tb_writer = SummaryWriter(args.model_path)
277 | else:
278 | print("Tensorboard not available: not logging progress")
279 | return tb_writer
280 |
281 |
282 | def training_report(tb_writer, dataset_name, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, wandb=None, logger=None, pre_path_name=''):
283 | if tb_writer:
284 | tb_writer.add_scalar(f'{dataset_name}/train_loss_patches/l1_loss', Ll1.item(), iteration)
285 | tb_writer.add_scalar(f'{dataset_name}/train_loss_patches/total_loss', loss.item(), iteration)
286 | tb_writer.add_scalar(f'{dataset_name}/iter_time', elapsed, iteration)
287 |
288 | if wandb is not None:
289 | wandb.log({"train_l1_loss":Ll1, 'train_total_loss':loss, })
290 | # Report test and samples of training set
291 | if iteration in testing_iterations:
292 | scene.gaussians.eval()
293 |
294 | if 1:
295 | if iteration == testing_iterations[-1]:
296 | with torch.no_grad():
297 | log_info = scene.gaussians.estimate_final_bits()
298 | logger.info(log_info)
299 | if run_codec: # conduct encoding and decoding
300 | with torch.no_grad():
301 | bit_stream_path = os.path.join(pre_path_name, 'bitstreams')
302 | os.makedirs(bit_stream_path, exist_ok=True)
303 | # conduct encoding
304 | patched_infos, log_info = scene.gaussians.conduct_encoding(pre_path_name=bit_stream_path)
305 | logger.info(log_info)
306 | # conduct decoding
307 | log_info = scene.gaussians.conduct_decoding(pre_path_name=bit_stream_path, patched_infos=patched_infos)
308 | logger.info(log_info)
309 | torch.cuda.empty_cache()
310 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
311 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
312 |
313 | for config in validation_configs:
314 | # if config['name'] == 'test': assert len(config['cameras']) == 200
315 | if config['cameras'] and len(config['cameras']) > 0:
316 | l1_test = 0.0
317 | psnr_test = 0.0
318 | ssim_test = 0.0
319 | lpips_test = 0.0
320 |
321 | if wandb is not None:
322 | gt_image_list = []
323 | render_image_list = []
324 | errormap_list = []
325 |
326 | t_list = []
327 |
328 | for idx, viewpoint in enumerate(config['cameras']):
329 | torch.cuda.synchronize(); t_start = time.time()
330 | voxel_visible_mask = prefilter_voxel(viewpoint, scene.gaussians, *renderArgs)
331 | # image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs, visible_mask=voxel_visible_mask)["render"], 0.0, 1.0)
332 | render_output = renderFunc(viewpoint, scene.gaussians, *renderArgs, visible_mask=voxel_visible_mask)
333 | image = torch.clamp(render_output["render"], 0.0, 1.0)
334 | time_sub = render_output["time_sub"]
335 | torch.cuda.synchronize(); t_end = time.time()
336 | t_list.append(t_end - t_start - time_sub)
337 |
338 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
339 | if tb_writer and (idx < 30):
340 | tb_writer.add_images(f'{dataset_name}/'+config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
341 | tb_writer.add_images(f'{dataset_name}/'+config['name'] + "_view_{}/errormap".format(viewpoint.image_name), (gt_image[None]-image[None]).abs(), global_step=iteration)
342 |
343 | if wandb:
344 | render_image_list.append(image[None])
345 | errormap_list.append((gt_image[None]-image[None]).abs())
346 |
347 | if iteration == testing_iterations[0]:
348 | tb_writer.add_images(f'{dataset_name}/'+config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
349 | if wandb:
350 | gt_image_list.append(gt_image[None])
351 | l1_test += l1_loss(image, gt_image).mean().double()
352 | psnr_test += psnr(image, gt_image).mean().double()
353 | ssim_test += ssim(image, gt_image).mean().double()
354 | # lpips_test += lpips_fn(image, gt_image, normalize=True).detach().mean().double()
355 | lpips_test += lpips(image, gt_image, net_type='vgg').detach().mean().double()
356 |
357 | psnr_test /= len(config['cameras'])
358 | ssim_test /= len(config['cameras'])
359 | lpips_test /= len(config['cameras'])
360 | l1_test /= len(config['cameras'])
361 | logger.info("\n[ITER {}] Evaluating {}: L1 {} PSNR {} ssim {} lpips {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test))
362 | test_fps = 1.0 / torch.tensor(t_list[0:]).mean()
363 | logger.info(f'Test FPS: {test_fps.item():.5f}')
364 | if tb_writer:
365 | tb_writer.add_scalar(f'{dataset_name}/test_FPS', test_fps.item(), 0)
366 | if wandb is not None:
367 | wandb.log({"test_fps": test_fps, })
368 |
369 | if tb_writer:
370 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
371 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
372 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration)
373 | tb_writer.add_scalar(f'{dataset_name}/'+config['name'] + '/loss_viewpoint - lpips', lpips_test, iteration)
374 | if wandb is not None:
375 |                 wandb.log({f"{config['name']}_loss_viewpoint_l1_loss": l1_test, f"{config['name']}_PSNR": psnr_test, f"{config['name']}_SSIM": ssim_test, f"{config['name']}_LPIPS": lpips_test})
376 |
377 | if tb_writer:
378 | # tb_writer.add_histogram(f'{dataset_name}/'+"scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
379 | tb_writer.add_scalar(f'{dataset_name}/'+'total_points', scene.gaussians.get_anchor.shape[0], iteration)
380 | torch.cuda.empty_cache()
381 |
382 | scene.gaussians.train()
383 |
384 |
385 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
386 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
387 | error_path = os.path.join(model_path, name, "ours_{}".format(iteration), "errors")
388 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
389 |
390 | makedirs(render_path, exist_ok=True)
391 | makedirs(error_path, exist_ok=True)
392 | makedirs(gts_path, exist_ok=True)
393 |
394 | t_list = []
395 | visible_count_list = []
396 | name_list = []
397 | per_view_dict = {}
398 | psnr_list = []
399 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
400 |
401 | torch.cuda.synchronize(); t_start = time.time()
402 | voxel_visible_mask = prefilter_voxel(view, gaussians, pipeline, background)
403 | render_pkg = render(view, gaussians, pipeline, background, visible_mask=voxel_visible_mask)
404 | torch.cuda.synchronize(); t_end = time.time()
405 |
406 | t_list.append(t_end - t_start)
407 |
408 | # renders
409 | rendering = torch.clamp(render_pkg["render"], 0.0, 1.0)
410 | visible_count = (render_pkg["radii"] > 0).sum()
411 | visible_count_list.append(visible_count)
412 |
413 | # gts
414 | gt = view.original_image[0:3, :, :]
415 |
416 | #
417 | gt_image = torch.clamp(view.original_image.to("cuda"), 0.0, 1.0)
418 | render_image = torch.clamp(rendering.to("cuda"), 0.0, 1.0)
419 | psnr_view = psnr(render_image, gt_image).mean().double()
420 | psnr_list.append(psnr_view)
421 |
422 | # error maps
423 | errormap = (rendering - gt).abs()
424 |
425 |
426 | name_list.append('{0:05d}'.format(idx) + ".png")
427 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
428 | torchvision.utils.save_image(errormap, os.path.join(error_path, '{0:05d}'.format(idx) + ".png"))
429 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
430 | per_view_dict['{0:05d}'.format(idx) + ".png"] = visible_count.item()
431 |
432 | with open(os.path.join(model_path, name, "ours_{}".format(iteration), "per_view_count.json"), 'w') as fp:
433 | json.dump(per_view_dict, fp, indent=True)
434 |
435 |     print('testing_float_psnr:', sum(psnr_list) / len(psnr_list))
436 |
437 | return t_list, visible_count_list
438 |
439 |
440 | def render_sets(args_param, dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train=True, skip_test=False, wandb=None, tb_writer=None, dataset_name=None, logger=None, x_bound_min=None, x_bound_max=None):
441 | with torch.no_grad():
442 | gaussians = GaussianModel(
443 | dataset.feat_dim,
444 | dataset.n_offsets,
445 | dataset.voxel_size,
446 | dataset.update_depth,
447 | dataset.update_init_factor,
448 | dataset.update_hierachy_factor,
449 | dataset.use_feat_bank,
450 | n_features_per_level=args_param.n_features,
451 | log2_hashmap_size=args_param.log2,
452 | log2_hashmap_size_2D=args_param.log2_2D,
453 | decoded_version=run_codec,
454 | )
455 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
456 | gaussians.eval()
457 | if x_bound_min is not None:
458 | gaussians.x_bound_min = x_bound_min
459 | gaussians.x_bound_max = x_bound_max
460 |
461 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
462 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
463 |
464 | if not skip_train:
465 | t_train_list, _ = render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
466 | train_fps = 1.0 / torch.tensor(t_train_list[5:]).mean()
467 | logger.info(f'Train FPS: \033[1;35m{train_fps.item():.5f}\033[0m')
468 | if wandb is not None:
469 | wandb.log({"train_fps":train_fps.item(), })
470 |
471 | if not skip_test:
472 | t_test_list, visible_count = render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
473 | test_fps = 1.0 / torch.tensor(t_test_list[5:]).mean()
474 | logger.info(f'Test FPS: \033[1;35m{test_fps.item():.5f}\033[0m')
475 | if tb_writer:
476 | tb_writer.add_scalar(f'{dataset_name}/test_FPS', test_fps.item(), 0)
477 | if wandb is not None:
478 | wandb.log({"test_fps":test_fps, })
479 |
480 | return visible_count
481 |
482 |
483 | def readImages(renders_dir, gt_dir):
484 | renders = []
485 | gts = []
486 | image_names = []
487 | for fname in os.listdir(renders_dir):
488 | render = Image.open(renders_dir / fname)
489 | gt = Image.open(gt_dir / fname)
490 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
491 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
492 | image_names.append(fname)
493 | return renders, gts, image_names
494 |
495 |
496 | def evaluate(model_paths, visible_count=None, wandb=None, tb_writer=None, dataset_name=None, logger=None):
497 |
498 | full_dict = {}
499 | per_view_dict = {}
500 | full_dict_polytopeonly = {}
501 | per_view_dict_polytopeonly = {}
502 | print("")
503 |
504 | scene_dir = model_paths
505 | full_dict[scene_dir] = {}
506 | per_view_dict[scene_dir] = {}
507 | full_dict_polytopeonly[scene_dir] = {}
508 | per_view_dict_polytopeonly[scene_dir] = {}
509 |
510 | test_dir = Path(scene_dir) / "test"
511 |
512 | for method in os.listdir(test_dir):
513 |
514 | full_dict[scene_dir][method] = {}
515 | per_view_dict[scene_dir][method] = {}
516 | full_dict_polytopeonly[scene_dir][method] = {}
517 | per_view_dict_polytopeonly[scene_dir][method] = {}
518 |
519 | method_dir = test_dir / method
520 | gt_dir = method_dir/ "gt"
521 | renders_dir = method_dir / "renders"
522 | renders, gts, image_names = readImages(renders_dir, gt_dir)
523 |
524 | ssims = []
525 | psnrs = []
526 | lpipss = []
527 |
528 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
529 | ssims.append(ssim(renders[idx], gts[idx]))
530 | psnrs.append(psnr(renders[idx], gts[idx]))
531 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
532 |
533 | if wandb is not None:
534 | wandb.log({"test_SSIMS":torch.stack(ssims).mean().item(), })
535 | wandb.log({"test_PSNR_final":torch.stack(psnrs).mean().item(), })
536 | wandb.log({"test_LPIPS":torch.stack(lpipss).mean().item(), })
537 |
538 | logger.info(f"model_paths: \033[1;35m{model_paths}\033[0m")
539 | logger.info(" SSIM : \033[1;35m{:>12.7f}\033[0m".format(torch.tensor(ssims).mean(), ".5"))
540 | logger.info(" PSNR : \033[1;35m{:>12.7f}\033[0m".format(torch.tensor(psnrs).mean(), ".5"))
541 | logger.info(" LPIPS: \033[1;35m{:>12.7f}\033[0m".format(torch.tensor(lpipss).mean(), ".5"))
542 | print("")
543 |
544 |
545 | if tb_writer:
546 | tb_writer.add_scalar(f'{dataset_name}/SSIM', torch.tensor(ssims).mean().item(), 0)
547 | tb_writer.add_scalar(f'{dataset_name}/PSNR', torch.tensor(psnrs).mean().item(), 0)
548 | tb_writer.add_scalar(f'{dataset_name}/LPIPS', torch.tensor(lpipss).mean().item(), 0)
549 |
550 | tb_writer.add_scalar(f'{dataset_name}/VISIBLE_NUMS', torch.tensor(visible_count).mean().item(), 0)
551 |
552 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
553 | "PSNR": torch.tensor(psnrs).mean().item(),
554 | "LPIPS": torch.tensor(lpipss).mean().item()})
555 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
556 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
557 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)},
558 | "VISIBLE_COUNT": {name: vc for vc, name in zip(torch.tensor(visible_count).tolist(), image_names)}})
559 |
560 | with open(scene_dir + "/results.json", 'w') as fp:
561 | json.dump(full_dict[scene_dir], fp, indent=True)
562 | with open(scene_dir + "/per_view.json", 'w') as fp:
563 | json.dump(per_view_dict[scene_dir], fp, indent=True)
564 |
565 | def get_logger(path):
566 | import logging
567 |
568 | logger = logging.getLogger()
569 | logger.setLevel(logging.INFO)
570 | fileinfo = logging.FileHandler(os.path.join(path, "outputs.log"))
571 | fileinfo.setLevel(logging.INFO)
572 | controlshow = logging.StreamHandler()
573 | controlshow.setLevel(logging.INFO)
574 | formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")
575 | fileinfo.setFormatter(formatter)
576 | controlshow.setFormatter(formatter)
577 |
578 | logger.addHandler(fileinfo)
579 | logger.addHandler(controlshow)
580 |
581 | return logger
582 |
583 | if __name__ == "__main__":
584 | # Set up command line argument parser
585 | parser = ArgumentParser(description="Training script parameters")
586 | lp = ModelParams(parser)
587 | op = OptimizationParams(parser)
588 | pp = PipelineParams(parser)
589 | parser.add_argument('--ip', type=str, default="127.0.0.1")
590 | parser.add_argument('--port', type=int, default=6009)
591 | parser.add_argument('--debug_from', type=int, default=-1)
592 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
593 | parser.add_argument('--warmup', action='store_true', default=False)
594 | parser.add_argument('--use_wandb', action='store_true', default=False)
595 | # parser.add_argument("--test_iterations", nargs="+", type=int, default=[11_000, 15_000, 20_000, 25_000, 29_000, 30_000])
596 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[30_000])
597 | # parser.add_argument("--save_iterations", nargs="+", type=int, default=[11_000, 15_000, 20_000, 25_000, 29_000, 30_000])
598 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[30_000])
599 | parser.add_argument("--quiet", action="store_true")
600 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
601 | parser.add_argument("--start_checkpoint", type=str, default = None)
602 | parser.add_argument("--gpu", type=str, default = '-1')
603 | parser.add_argument("--log2", type=int, default = 13)
604 | parser.add_argument("--log2_2D", type=int, default = 15)
605 | parser.add_argument("--n_features", type=int, default = 4)
606 | parser.add_argument("--lmbda", type=float, default = 0.001)
607 | args = parser.parse_args(sys.argv[1:])
608 | args.save_iterations.append(args.iterations)
609 |
610 |
611 | # enable logging
612 |
613 | model_path = args.model_path
614 | os.makedirs(model_path, exist_ok=True)
615 |
616 | logger = get_logger(model_path)
617 |
618 |
619 | logger.info(f'args: {args}')
620 |
621 | if args.gpu != '-1':
622 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
623 | os.system("echo $CUDA_VISIBLE_DEVICES")
624 | logger.info(f'using GPU {args.gpu}')
625 |
626 | '''try:
627 | saveRuntimeCode(os.path.join(args.model_path, 'backup'))
628 | except:
629 | logger.info(f'save code failed~')'''
630 |
631 | dataset = args.source_path.split('/')[-1]
632 | exp_name = args.model_path.split('/')[-2]
633 |
634 | if args.use_wandb:
635 | wandb.login()
636 | run = wandb.init(
637 | # Set the project where this run will be logged
638 | project=f"Scaffold-GS-{dataset}",
639 | name=exp_name,
640 | # Track hyperparameters and run metadata
641 | settings=wandb.Settings(start_method="fork"),
642 | config=vars(args)
643 | )
644 | else:
645 | wandb = None
646 |
647 | logger.info("Optimizing " + args.model_path)
648 |
649 | # Initialize system state (RNG)
650 | safe_state(args.quiet)
651 |
652 | # Start GUI server, configure and run training
653 | args.port = np.random.randint(10000, 20000)
654 | # network_gui.init(args.ip, args.port)
655 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
656 |
657 | # training
658 | x_bound_min, x_bound_max = training(args, lp.extract(args), op.extract(args), pp.extract(args), dataset, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, wandb, logger)
659 | if args.warmup:
660 | logger.info("\n Warmup finished! Reboot from last checkpoints")
661 | new_ply_path = os.path.join(args.model_path, f'point_cloud/iteration_{args.iterations}', 'point_cloud.ply')
662 | x_bound_min, x_bound_max = training(args, lp.extract(args), op.extract(args), pp.extract(args), dataset, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, wandb=wandb, logger=logger, ply_path=new_ply_path)
663 |
664 | # All done
665 | logger.info("\nTraining complete.")
666 |
667 | # rendering
668 |     logger.info('\nStarting rendering...')
669 | visible_count = render_sets(args, lp.extract(args), -1, pp.extract(args), wandb=wandb, logger=logger, x_bound_min=x_bound_min, x_bound_max=x_bound_max)
670 | logger.info("\nRendering complete.")
671 |
672 | # calc metrics
673 | logger.info("\n Starting evaluation...")
674 | evaluate(args.model_path, visible_count=visible_count, wandb=wandb, logger=logger)
675 | logger.info("\nEvaluating complete.")
676 |
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from scene.cameras import Camera
13 | import numpy as np
14 | from utils.general_utils import PILtoTorch
15 | from utils.graphics_utils import fov2focal
16 |
17 | WARNED = False
18 |
19 | def loadCam(args, id, cam_info, resolution_scale):
20 | orig_w, orig_h = cam_info.image.size
21 |
22 | if args.resolution in [1, 2, 4, 8]:
23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
24 | else: # should be a type that converts to float
25 | if args.resolution == -1:
26 | if orig_w > 1600:
27 | global WARNED
28 | if not WARNED:
29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
31 | WARNED = True
32 | global_down = orig_w / 1600
33 | else:
34 | global_down = 1
35 | else:
36 | global_down = orig_w / args.resolution
37 |
38 | scale = float(global_down) * float(resolution_scale)
39 | resolution = (int(orig_w / scale), int(orig_h / scale))
40 |
41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
42 |
43 | gt_image = resized_image_rgb[:3, ...]
44 | loaded_mask = None
45 |
46 | # print(f'gt_image: {gt_image.shape}')
47 | if resized_image_rgb.shape[1] == 4:
48 | loaded_mask = resized_image_rgb[3:4, ...]
49 |
50 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
51 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
52 | image=gt_image, gt_alpha_mask=loaded_mask,
53 | image_name=cam_info.image_name, uid=id, data_device=args.data_device)
54 |
55 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
56 | camera_list = []
57 |
58 | for id, c in enumerate(cam_infos):
59 | camera_list.append(loadCam(args, id, c, resolution_scale))
60 |
61 | return camera_list
62 |
63 | def camera_to_JSON(id, camera : Camera):
64 | Rt = np.zeros((4, 4))
65 | Rt[:3, :3] = camera.R.transpose()
66 | Rt[:3, 3] = camera.T
67 | Rt[3, 3] = 1.0
68 |
69 | W2C = np.linalg.inv(Rt)
70 | pos = W2C[:3, 3]
71 | rot = W2C[:3, :3]
72 | serializable_array_2d = [x.tolist() for x in rot]
73 | camera_entry = {
74 | 'id' : id,
75 | 'img_name' : camera.image_name,
76 | 'width' : camera.width,
77 | 'height' : camera.height,
78 | 'position': pos.tolist(),
79 | 'rotation': serializable_array_2d,
80 | 'fy' : fov2focal(camera.FovY, camera.height),
81 | 'fx' : fov2focal(camera.FovX, camera.width)
82 | }
83 | return camera_entry
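# Despite its name, W2C above is the inverse of the world-to-view matrix Rt, i.e. the
# camera-to-world transform; its translation column is therefore the camera center, which
# is what gets exported as 'position'.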
84 |
85 |
86 |
--------------------------------------------------------------------------------
/utils/encodings.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | from torch.cuda.amp import custom_bwd, custom_fwd
5 | import numpy as np
6 | import math
7 | import multiprocessing
8 |
9 | import _gridencoder as _backend
10 |
11 | anchor_round_digits = 16
12 | Q_anchor = 1/(2 ** anchor_round_digits - 1)
13 | use_clamp = True
14 | use_multiprocessor = False  # Keep False; the multiprocessing path is not implemented yet.
15 |
16 | def get_binary_vxl_size(binary_vxl):
17 | # binary_vxl: {0, 1}
18 | # assert torch.unique(binary_vxl).mean() == 0.5
19 | ttl_num = binary_vxl.numel()
20 |
21 | pos_num = torch.sum(binary_vxl)
22 | neg_num = ttl_num - pos_num
23 |
24 | Pg = pos_num / ttl_num # + 1e-6
25 | Pg = torch.clamp(Pg, min=1e-6, max=1-1e-6)
26 | pos_prob = Pg
27 | neg_prob = (1 - Pg)
28 | pos_bit = pos_num * (-torch.log2(pos_prob))
29 | neg_bit = neg_num * (-torch.log2(neg_prob))
30 | ttl_bit = pos_bit + neg_bit
31 | ttl_bit += 32 # Pg
32 | # print('binary_vxl:', Pg.item(), ttl_bit.item(), ttl_num, pos_num.item(), neg_num.item())
33 | return Pg, ttl_bit, ttl_bit.item()/8.0/1024/1024, ttl_num
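# The bit count above is the ideal entropy-coding cost of the binary mask under a single
# Bernoulli(Pg) model: pos_num * (-log2 Pg) + neg_num * (-log2(1 - Pg)), plus 32 bits to
# transmit Pg itself; the third return value converts this total to megabytes.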
34 |
35 | class STE_binary(torch.autograd.Function):
36 | @staticmethod
37 | def forward(ctx, input):
38 | ctx.save_for_backward(input)
39 | input = torch.clamp(input, min=-1, max=1)
40 | # out = torch.sign(input)
41 | p = (input >= 0) * (+1.0)
42 | n = (input < 0) * (-1.0)
43 | out = p + n
44 | return out
45 | @staticmethod
46 | def backward(ctx, grad_output):
47 | # mask: to ensure x belongs to (-1, 1)
48 | input, = ctx.saved_tensors
49 | i2 = input.clone().detach()
50 | i3 = torch.clamp(i2, -1, 1)
51 | mask = (i3 == i2) + 0.0
52 | return grad_output * mask
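# STE_binary is a straight-through estimator: the forward pass binarizes to {-1, +1}, while
# the backward pass passes the gradient through unchanged wherever the pre-binarization
# input lies inside [-1, 1] and zeroes it outside, keeping the sign function trainable even
# though it has zero gradient almost everywhere.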
53 |
54 |
55 | class STE_multistep(torch.autograd.Function):
56 | @staticmethod
57 | def forward(ctx, input, Q, input_mean=None):
58 | if use_clamp:
59 | if input_mean is None:
60 | input_mean = input.mean()
61 | input_min = input_mean - 15_000 * Q
62 | input_max = input_mean + 15_000 * Q
63 | input = torch.clamp(input, min=input_min.detach(), max=input_max.detach())
64 |
65 | Q_round = torch.round(input / Q)
66 | Q_q = Q_round * Q
67 | return Q_q
68 | @staticmethod
69 | def backward(ctx, grad_output):
70 | return grad_output, None
71 |
72 |
73 | class Quantize_anchor(torch.autograd.Function):
74 | @staticmethod
75 | def forward(ctx, anchors, min_v, max_v):
76 | # if anchor_round_digits == 32:
77 | # return anchors
78 | # min_v = torch.min(anchors).detach()
79 | # max_v = torch.max(anchors).detach()
80 | # scales = 2 ** anchor_round_digits - 1
81 | interval = ((max_v - min_v) * Q_anchor + 1e-6) # avoid 0, if max_v == min_v
82 | # quantized_v = (anchors - min_v) // interval
83 | quantized_v = torch.div(anchors - min_v, interval, rounding_mode='floor')
84 | quantized_v = torch.clamp(quantized_v, 0, 2 ** anchor_round_digits - 1)
85 | anchors_q = quantized_v * interval + min_v
86 | return anchors_q, quantized_v
87 | @staticmethod
88 | def backward(ctx, grad_output, tmp): # tmp is for quantized_v:)
89 | return grad_output, None, None
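# Quantize_anchor maps anchor coordinates onto 2**anchor_round_digits uniform levels between
# min_v and max_v (step = (max_v - min_v) * Q_anchor), returning both the de-quantized
# anchors and their integer level indices; the gradient passes straight through to the anchors.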
90 |
91 |
92 | class _grid_encode(Function):
93 | @staticmethod
94 | @custom_fwd
95 | def forward(ctx, inputs, embeddings, offsets_list, resolutions_list, calc_grad_inputs=False, min_level_id=None, n_levels_calc=1, binary_vxl=None, PV=0):
96 | # inputs: [N, num_dim], float in [0, 1]
97 | # embeddings: [sO, n_features], float. self.params = nn.Parameter(torch.empty(offset, n_features))
98 | # offsets_list: [n_levels + 1], int
99 | # RETURN: [N, F], float
100 | inputs = inputs.contiguous()
101 | # embeddings_mask = torch.ones(size=[embeddings.shape[0]], dtype=torch.bool, device='cuda')
102 | # print('kkkkkkkkkk---000000000:', embeddings_mask.shape, embeddings.shape, embeddings_mask.sum())
103 |
104 | Rb = 128
105 | if binary_vxl is not None:
106 | binary_vxl = binary_vxl.contiguous()
107 | Rb = binary_vxl.shape[-1]
108 | assert len(binary_vxl.shape) == inputs.shape[-1]
109 |
110 | N, num_dim = inputs.shape # batch size, coord dim # N_rays, 3
111 |         n_levels = offsets_list.shape[0] - 1 # number of levels (e.g., 16)
112 |         n_features = embeddings.shape[1] # embedding dim (channels) per level (e.g., 2)
113 |
114 | max_level_id = min_level_id + n_levels_calc
115 |
116 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
117 | # if n_features % 2 != 0, force float, since half for atomicAdd is very slow.
118 | if torch.is_autocast_enabled() and n_features % 2 == 0:
119 | embeddings = embeddings.to(torch.half)
120 |
121 | # n_levels first, optimize cache for cuda kernel, but needs an extra permute later
122 |         outputs = torch.empty(n_levels_calc, N, n_features, device=inputs.device, dtype=embeddings.dtype) # buffer to be filled by the CUDA kernel
123 |         # outputs: [n_levels (e.g., 16), N_rays, channels (e.g., 2)]
124 |
125 | # zero init if we only calculate partial levels
126 | # if n_levels_calc < n_levels: outputs.zero_()
127 | if calc_grad_inputs: # inputs.requires_grad
128 | dy_dx = torch.empty(N, n_levels_calc * num_dim * n_features, device=inputs.device, dtype=embeddings.dtype)
129 | else:
130 | dy_dx = None
131 |
132 | # assert embeddings.shape[0] == embeddings_mask.shape[0]
133 | # assert embeddings_mask.dtype == torch.bool
134 |
135 | if isinstance(min_level_id, int):
136 | _backend.grid_encode_forward(
137 | inputs,
138 | embeddings,
139 | # embeddings_mask,
140 | offsets_list[min_level_id:max_level_id+1],
141 | resolutions_list[min_level_id:max_level_id],
142 | outputs,
143 | N, num_dim, n_features, n_levels_calc, 0, Rb, PV,
144 | dy_dx,
145 | binary_vxl,
146 | None
147 | )
148 | else:
149 | _backend.grid_encode_forward(
150 | inputs,
151 | embeddings,
152 | # embeddings_mask,
153 | offsets_list,
154 | resolutions_list,
155 | outputs,
156 | N, num_dim, n_features, n_levels_calc, 0, Rb, PV,
157 | dy_dx,
158 | binary_vxl,
159 | min_level_id
160 | )
161 |
162 |         # permute back to [N, n_levels * n_features], i.e., [N_rays, n_levels (e.g., 16) * channels (e.g., 2)]
163 | outputs = outputs.permute(1, 0, 2).reshape(N, n_levels_calc * n_features)
164 |
165 | ctx.save_for_backward(inputs, embeddings, offsets_list, resolutions_list, dy_dx, binary_vxl)
166 |         ctx.dims = [N, num_dim, n_features, n_levels_calc, min_level_id, max_level_id, Rb, PV] # TODO: should min_level_id be saved separately as a tensor?
167 |
168 | return outputs
169 |
170 | @staticmethod
171 | #@once_differentiable
172 | @custom_bwd
173 | def backward(ctx, grad):
174 |
175 | inputs, embeddings, offsets_list, resolutions_list, dy_dx, binary_vxl = ctx.saved_tensors
176 | N, num_dim, n_features, n_levels_calc, min_level_id, max_level_id, Rb, PV = ctx.dims
177 |
178 | # grad: [N, n_levels * n_features] --> [n_levels, N, n_features]
179 | grad = grad.view(N, n_levels_calc, n_features).permute(1, 0, 2).contiguous()
180 |
181 | grad_embeddings = torch.zeros_like(embeddings)
182 |
183 | if dy_dx is not None:
184 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
185 | else:
186 | grad_inputs = None
187 |
188 | if isinstance(min_level_id, int):
189 | _backend.grid_encode_backward(
190 | grad,
191 | inputs,
192 | embeddings,
193 | # embeddings_mask,
194 | offsets_list[min_level_id:max_level_id+1],
195 | resolutions_list[min_level_id:max_level_id],
196 | grad_embeddings,
197 | N, num_dim, n_features, n_levels_calc, 0, Rb,
198 | dy_dx,
199 | grad_inputs,
200 | binary_vxl,
201 | None
202 | )
203 | else:
204 | _backend.grid_encode_backward(
205 | grad,
206 | inputs,
207 | embeddings,
208 | # embeddings_mask,
209 | offsets_list,
210 | resolutions_list,
211 | grad_embeddings,
212 | N, num_dim, n_features, n_levels_calc, 0, Rb,
213 | dy_dx,
214 | grad_inputs,
215 | binary_vxl,
216 | min_level_id
217 | )
218 |
219 | if dy_dx is not None:
220 | grad_inputs = grad_inputs.to(inputs.dtype)
221 |
222 | return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None
223 | grid_encode = _grid_encode.apply
224 | class GridEncoder(nn.Module):
225 | def __init__(self,
226 | num_dim=3,
227 | n_features=2,
228 | resolutions_list=(16, 23, 32, 46, 64, 92, 128, 184, 256, 368, 512, 736),
229 | log2_hashmap_size=19,
230 | ste_binary = True,
231 | ste_multistep = False,
232 | add_noise = False,
233 | Q = 1
234 | ):
235 | super().__init__()
236 |
237 | resolutions_list = torch.tensor(resolutions_list).to(torch.int)
238 | n_levels = resolutions_list.numel()
239 |
240 | self.num_dim = num_dim # coord dims, 2 or 3
241 | self.n_levels = n_levels # num levels, each level multiply resolution by 2
242 | self.n_features = n_features # encode channels per level
243 | self.log2_hashmap_size = log2_hashmap_size
244 | self.output_dim = n_levels * n_features
245 | self.ste_binary = ste_binary
246 | self.ste_multistep = ste_multistep
247 | self.add_noise = add_noise
248 | self.Q = Q
249 |
250 | # allocate parameters
251 | offsets_list = []
252 | offset = 0
253 | self.max_params = 2 ** log2_hashmap_size
254 | for i in range(n_levels):
255 | resolution = resolutions_list[i].item()
256 | params_in_level = min(self.max_params, resolution ** num_dim) # limit max number
257 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
258 | offsets_list.append(offset)
259 | offset += params_in_level
260 | offsets_list.append(offset)
261 | offsets_list = torch.from_numpy(np.array(offsets_list, dtype=np.int32))
262 | self.register_buffer('offsets_list', offsets_list)
263 | self.register_buffer('resolutions_list', resolutions_list)
264 |
265 | self.n_params = offsets_list[-1] * n_features
266 |
267 | # parameters
268 | self.params = nn.Parameter(torch.empty(offset, n_features))
269 |
270 | self.reset_parameters()
271 |
272 | self.n_output_dims = n_levels * n_features
273 |
274 | def reset_parameters(self):
275 | std = 1e-4
276 | self.params.data.uniform_(-std, std)
277 |
278 | def __repr__(self):
279 | return f"GridEncoder: num_dim={self.num_dim} n_levels={self.n_levels} n_features={self.n_features} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.n_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.params.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"
280 |
281 | def forward(self, inputs, min_level_id=None, max_level_id=None, test_phase=False, outspace_params=None, binary_vxl=None, PV=0):
282 | # inputs: [..., num_dim], normalized real world positions in [0, 1]
283 | # return: [..., n_levels * n_features]
284 |
285 | prefix_shape = list(inputs.shape[:-1])
286 | inputs = inputs.view(-1, self.num_dim)
287 |
288 | if outspace_params is not None:
289 | params = nn.Parameter(outspace_params)
290 | else:
291 | params = self.params
292 |
293 | if self.ste_binary:
294 | embeddings = STE_binary.apply(params)
295 | # embeddings = params
296 | elif (self.add_noise and not test_phase):
297 | embeddings = params + (torch.rand_like(params) - 0.5) * (1 / self.Q)
298 | elif (self.ste_multistep) or (self.add_noise and test_phase):
299 | embeddings = STE_multistep.apply(params, self.Q)
300 | else:
301 | embeddings = params
302 | # embeddings = embeddings * 0 # for ablation
303 |
304 | min_level_id = 0 if min_level_id is None else max(min_level_id, 0)
305 | max_level_id = self.n_levels if max_level_id is None else min(max_level_id, self.n_levels)
306 | n_levels_calc = max_level_id - min_level_id
307 |
308 | outputs = grid_encode(inputs, embeddings, self.offsets_list, self.resolutions_list, inputs.requires_grad, min_level_id, n_levels_calc, binary_vxl, PV)
309 | outputs = outputs.view(prefix_shape + [n_levels_calc * self.n_features])
310 |
311 | return outputs
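# A minimal usage sketch (illustrative only; it assumes the _gridencoder CUDA extension is
# built and a CUDA device is available):
#
#   enc = GridEncoder(num_dim=3, n_features=2, log2_hashmap_size=13).cuda()
#   xyz = torch.rand(1024, 3, device='cuda')   # normalized coordinates in [0, 1]
#   feats = enc(xyz)                           # -> [1024, n_levels * n_features]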
312 |
--------------------------------------------------------------------------------
/utils/encodings_cuda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import arithmetic
5 |
6 | chunk_size_cuda = 10000
7 |
8 | class STE_multistep(torch.autograd.Function):
9 | @staticmethod
10 | def forward(ctx, input, Q, input_mean=None):
11 | Q_round = torch.round(input / Q)
12 | Q_q = Q_round * Q
13 | return Q_q
14 | @staticmethod
15 | def backward(ctx, grad_output):
16 | return grad_output, None
17 |
18 | def get_binary_vxl_size(binary_vxl):
19 | # binary_vxl: {0, 1}
20 | # assert torch.unique(binary_vxl).mean() == 0.5
21 | ttl_num = binary_vxl.numel()
22 |
23 | pos_num = torch.sum(binary_vxl)
24 | neg_num = ttl_num - pos_num
25 |
26 | Pg = pos_num / ttl_num # + 1e-6
27 | Pg = torch.clamp(Pg, min=1e-6, max=1-1e-6)
28 | pos_prob = Pg
29 | neg_prob = (1 - Pg)
30 | pos_bit = pos_num * (-torch.log2(pos_prob))
31 | neg_bit = neg_num * (-torch.log2(neg_prob))
32 | ttl_bit = pos_bit + neg_bit
33 | ttl_bit += 32 # Pg
34 | # print('binary_vxl:', Pg.item(), ttl_bit.item(), ttl_num, pos_num.item(), neg_num.item())
35 | return Pg, ttl_bit, ttl_bit.item()/8.0/1024/1024, ttl_num
36 |
37 |
38 | def encoder_factorized_chunk(x, lower_func, Q:float = 1, file_name='tmp.b', chunk_size=1000_0000):
39 | # should be with 2 dimensions
40 | # lower_func: xxx._logits_cumulative
41 | assert file_name.endswith('.b')
42 | assert len(x.shape) == 2
43 | N = x.shape[0]
44 | chunks = int(np.ceil(N / chunk_size))
45 | bit_len_list = []
46 | for c in range(chunks):
47 | bit_len = encoder_factorized(
48 | x=x[c * chunk_size:c * chunk_size + chunk_size],
49 | lower_func=lower_func,
50 | Q=Q,
51 | file_name=file_name.replace('.b', f'_{str(c)}.b'),
52 | )
53 | bit_len_list.append(bit_len)
54 | return sum(bit_len_list)
55 |
56 |
57 | def encoder_factorized(x, lower_func, Q:float = 1, file_name='tmp.b'):
58 |     '''
59 |     Why the sample range ends at int(max_value.item()) + 1 rather than + 1 + 1:
60 |     first +1: Python's range() excludes its end value, so +1 is needed to include max_value.
61 |     second +1: only needed if the cdf were evaluated directly, because the samples would then be shifted by -0.5
62 |     and the upper bound would have to reach max_value + 0.5.
63 |
64 |     Here the second +1 is not added, because the cdf is built by accumulating the pmf instead of being evaluated directly.
65 |
66 |     example here ("`" marks sample-0.5 positions, "|" marks sample positions):
67 |     ` ` ` ` ` ` ` ` ` ` ` ` `
68 |     lkl_lower | | | |  lkl_upper | | | |  ->  cdf_lower | | | |
69 |
70 |     example elsewhere ("`" marks sample-0.5 positions, "|" marks sample positions):
71 |     ` ` ` ` `
72 |     cdf_lower | | | |
73 |
74 |     '''
75 | # should be with 2 dimensions
76 | # lower_func: xxx._logits_cumulative
77 | assert file_name.endswith('.b')
78 | assert len(x.shape) == 2
79 | x_int_round = torch.round(x / Q) # [100]
80 | max_value = x_int_round.max()
81 | min_value = x_int_round.min()
82 | samples = torch.tensor(range(int(min_value.item()), int(max_value.item()) + 1)).to(
83 | torch.float).to(x.device) # from min_value to max_value+1. shape = [max_value+1 - min_value]
84 | samples = samples.unsqueeze(0).unsqueeze(0).repeat(x.shape[-1], 1, 1) # [256, 1, max_value+1 - min_value]
85 | # lower_func: [C, 1, N]
86 | lower = lower_func((samples - 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value]
87 | upper = lower_func((samples + 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value]
88 | sign = -torch.sign(torch.add(lower, upper))
89 | sign = sign.detach()
90 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) # [256, 1, max_value+1 - min_value]
91 | cdf = torch.cumsum(pmf, dim=-1)
92 | lower = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [256, 1, max_value+1+1 - min_value]
93 | lower = lower.permute(1, 0, 2).contiguous().repeat(x.shape[0], 1, 1) # [100, 256, max_value+1+1 - min_value]
94 | x_int_round_idx = (x_int_round - min_value).to(torch.int16)
95 | x_int_round_idx = x_int_round_idx.view(-1) # [100*256]
96 | lower = lower.view(x_int_round_idx.shape[0], -1) # [100*256, max_value+1+1 - min_value]
97 | lower = torch.clamp(lower, min=0.0, max=1.0)
98 | assert (x_int_round_idx.to(torch.int32) == x_int_round.view(-1) - min_value).all()
99 |
100 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode(
101 | x_int_round_idx,
102 | lower,
103 | chunk_size_cuda,
104 | int(lower.shape[0]),
105 | int(lower.shape[1])
106 | )
107 | cnt_bytes = cnt_torch.cpu().numpy().tobytes()
108 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes()
109 | len_cnt_bytes = len(cnt_bytes)
110 | with open(file_name, 'wb') as fout:
111 | fout.write(min_value.to(torch.float32).cpu().numpy().tobytes())
112 | fout.write(max_value.to(torch.float32).cpu().numpy().tobytes())
113 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes())
114 | fout.write(cnt_bytes)
115 | fout.write(byte_stream_bytes)
116 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes)) * 8 + 32 * 3
117 | return bit_len
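# The file written above is laid out as: min_value (float32), max_value (float32),
# len(cnt_bytes) (int32), the per-chunk count array, then the arithmetic-coded byte stream;
# decoder_factorized below reads these fields back in the same order.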
118 |
119 | def decoder_factorized_chunk(lower_func, Q, N_len, dim, file_name='tmp.b', device='cuda', chunk_size=1000_0000):
120 | assert file_name.endswith('.b')
121 | chunks = int(np.ceil(N_len / chunk_size))
122 | x_c_list = []
123 | for c in range(chunks):
124 | x_c = decoder_factorized(
125 | lower_func=lower_func,
126 | Q=Q,
127 | N_len=min(chunk_size, N_len-c*chunk_size),
128 | dim=dim,
129 | file_name=file_name.replace('.b', f'_{str(c)}.b'),
130 | device=device,
131 | )
132 | x_c_list.append(x_c)
133 | x_c_list = torch.cat(x_c_list, dim=0)
134 | return x_c_list
135 |
136 |
137 | def decoder_factorized(lower_func, Q, N_len, dim, file_name='tmp.b', device='cuda'):
138 | assert file_name.endswith('.b')
139 |
140 | with open(file_name, 'rb') as fin:
141 | min_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda")
142 | max_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda")
143 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0]
144 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda")
145 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda")
146 |
147 | samples = torch.tensor(range(int(min_value.item()), int(max_value.item()) + 1)).to(
148 | torch.float).to(device) # from min_value to max_value+1. shape = [max_value+1 - min_value]
149 | samples = samples.unsqueeze(0).unsqueeze(0).repeat(dim, 1, 1) # [256, 1, max_value+1 - min_value]
150 |
151 | # lower_func: [C, 1, N]
152 | lower = lower_func((samples - 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value]
153 | upper = lower_func((samples + 0.5) * Q, stop_gradient=False) # [256, 1, max_value+1 - min_value]
154 | sign = -torch.sign(torch.add(lower, upper))
155 | sign = sign.detach()
156 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) # [256, 1, max_value+1 - min_value]
157 | cdf = torch.cumsum(pmf, dim=-1)
158 | lower = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [256, 1, max_value+1+1 - min_value]
159 | lower = lower.permute(1, 0, 2).contiguous().repeat(N_len, 1, 1) # [100, 256, max_value+1+1 - min_value]
160 | lower = lower.view(N_len*dim, -1) # [100*256, max_value+1+1 - min_value]
161 | lower = torch.clamp(lower, min=0.0, max=1.0)
162 |
163 | sym_out = arithmetic.arithmetic_decode(
164 | lower,
165 | byte_stream_torch,
166 | cnt_torch,
167 | chunk_size_cuda,
168 | int(lower.shape[0]),
169 | int(lower.shape[1])
170 | ).to(device).to(torch.float32)
171 | x = sym_out + min_value
172 | x = x * Q
173 | x = x.reshape(N_len, dim)
174 | return x
175 |
176 |
177 | def encoder_gaussian_mixed_chunk(x, mean_list, scale_list, prob_list, Q, file_name='tmp.b', chunk_size=1000_0000):
178 | assert file_name.endswith('.b')
179 | assert len(x.shape) == 1
180 | x_view = x.view(-1)
181 | mean_list_view = [mean.view(-1) for mean in mean_list]
182 | scale_list_view = [scale.view(-1) for scale in scale_list]
183 | prob_list_view = [prob.view(-1) for prob in prob_list]
184 | assert x_view.shape[0]==mean_list_view[0].shape[0]==scale_list_view[0].shape[0]==prob_list_view[0].shape[0]
185 | N = x_view.shape[0]
186 | chunks = int(np.ceil(N/chunk_size))
187 | Is_Q_tensor = isinstance(Q, torch.Tensor)
188 | if Is_Q_tensor: Q_view = Q.view(-1)
189 | bit_len_list = []
190 | for c in range(chunks):
191 | bit_len = encoder_gaussian_mixed(
192 | x=x_view[c*chunk_size:c*chunk_size + chunk_size],
193 | mean_list=[mean[c*chunk_size:c*chunk_size + chunk_size] for mean in mean_list_view],
194 | scale_list=[scale[c*chunk_size:c*chunk_size + chunk_size] for scale in scale_list_view],
195 | prob_list=[prob[c*chunk_size:c*chunk_size + chunk_size] for prob in prob_list_view],
196 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q,
197 | file_name=file_name.replace('.b', f'_{str(c)}.b'),
198 | )
199 | bit_len_list.append(bit_len)
200 | return sum(bit_len_list)
201 |
202 |
203 | def encoder_gaussian_mixed(x, mean_list, scale_list, prob_list, Q, file_name='tmp.b'):
204 | # should be with single dimension
205 | assert file_name.endswith('.b')
206 | assert len(x.shape) == 1
207 | if not isinstance(Q, torch.Tensor):
208 | Q = torch.tensor([Q], dtype=x.dtype, device=x.device).repeat(x.shape[0])
209 | assert x.shape == mean_list[0].shape == scale_list[0].shape == prob_list[0].shape == Q.shape, f'{x.shape}, {mean_list[0].shape}, {scale_list[0].shape}, {prob_list[0].shape}, {Q.shape}'
210 | x_int_round = torch.round(x / Q) # [100]
211 | max_value = x_int_round.max()
212 | min_value = x_int_round.min()
213 | lower_all = int(0)
214 | for (mean, scale, prob) in zip(mean_list, scale_list, prob_list):
215 | lower = arithmetic.calculate_cdf(
216 | mean,
217 | scale,
218 | Q,
219 | min_value,
220 | max_value
221 | ) * prob.unsqueeze(-1)
222 | if isinstance(lower_all, int):
223 | lower_all = lower
224 | else:
225 | lower_all += lower
226 | lower = torch.clamp(lower_all, min=0.0, max=1.0)
227 | del mean
228 | del scale
229 | del prob
230 |
231 | x_int_round_idx = (x_int_round - min_value).to(torch.int16)
232 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode(
233 | x_int_round_idx,
234 | lower,
235 | chunk_size_cuda,
236 | int(lower.shape[0]),
237 | int(lower.shape[1])
238 | )
239 | cnt_bytes = cnt_torch.cpu().numpy().tobytes()
240 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes()
241 | len_cnt_bytes = len(cnt_bytes)
242 | with open(file_name, 'wb') as fout:
243 | fout.write(min_value.to(torch.float32).cpu().numpy().tobytes())
244 | fout.write(max_value.to(torch.float32).cpu().numpy().tobytes())
245 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes())
246 | fout.write(cnt_bytes)
247 | fout.write(byte_stream_bytes)
248 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes))*8 + 32 * 3
249 | return bit_len
250 |
251 |
252 | def decoder_gaussian_mixed_chunk(mean_list, scale_list, prob_list, Q, file_name='tmp.b', chunk_size=1000_0000):
253 | assert file_name.endswith('.b')
254 | mean_list_view = [mean.view(-1) for mean in mean_list]
255 | scale_list_view = [scale.view(-1) for scale in scale_list]
256 | prob_list_view = [prob.view(-1) for prob in prob_list]
257 | N = mean_list_view[0].shape[0]
258 | chunks = int(np.ceil(N/chunk_size))
259 | Is_Q_tensor = isinstance(Q, torch.Tensor)
260 | if Is_Q_tensor: Q_view = Q.view(-1)
261 | x_c_list = []
262 | for c in range(chunks):
263 | x_c = decoder_gaussian_mixed(
264 | mean_list=[mean[c*chunk_size:c*chunk_size + chunk_size] for mean in mean_list_view],
265 | scale_list=[scale[c*chunk_size:c*chunk_size + chunk_size] for scale in scale_list_view],
266 | prob_list=[prob[c*chunk_size:c*chunk_size + chunk_size] for prob in prob_list_view],
267 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q,
268 | file_name=file_name.replace('.b', f'_{str(c)}.b'),
269 | )
270 | x_c_list.append(x_c)
271 | x_c_list = torch.cat(x_c_list, dim=0).type_as(mean_list[0])
272 | return x_c_list
273 |
274 |
275 | def decoder_gaussian_mixed(mean_list, scale_list, prob_list, Q, file_name='tmp.b'):
276 | assert file_name.endswith('.b')
277 | m0 = mean_list[0]
278 | if not isinstance(Q, torch.Tensor):
279 | Q = torch.tensor([Q], dtype=m0.dtype, device=m0.device).repeat(m0.shape[0])
280 | assert mean_list[0].shape == scale_list[0].shape == prob_list[0].shape == Q.shape
281 |
282 | with open(file_name, 'rb') as fin:
283 | min_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda")
284 | max_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda")
285 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0]
286 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda")
287 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda")
288 |
289 | lower_all = int(0)
290 | for (mean, scale, prob) in zip(mean_list, scale_list, prob_list):
291 | lower = arithmetic.calculate_cdf(
292 | mean,
293 | scale,
294 | Q,
295 | min_value,
296 | max_value
297 | ) * prob.unsqueeze(-1)
298 | if isinstance(lower_all, int):
299 | lower_all = lower
300 | else:
301 | lower_all += lower
302 | lower = torch.clamp(lower_all, min=0.0, max=1.0)
303 |
304 | sym_out = arithmetic.arithmetic_decode(
305 | lower,
306 | byte_stream_torch,
307 | cnt_torch,
308 | chunk_size_cuda,
309 | int(lower.shape[0]),
310 | int(lower.shape[1])
311 | ).to(mean.device).to(torch.float32)
312 | x = sym_out + min_value
313 | x = x * Q
314 | return x
315 |
316 |
317 | def encoder_gaussian_chunk(x, mean, scale, Q, file_name='tmp.b', chunk_size=1000_0000):
318 | assert file_name.endswith('.b')
319 | assert len(x.shape) == 1
320 | x_view = x.view(-1)
321 | mean_view = mean.view(-1)
322 | scale_view = scale.view(-1)
323 | N = x_view.shape[0]
324 | chunks = int(np.ceil(N/chunk_size))
325 | Is_Q_tensor = isinstance(Q, torch.Tensor)
326 | if Is_Q_tensor: Q_view = Q.view(-1)
327 | bit_len_list = []
328 | for c in range(chunks):
329 | bit_len = encoder_gaussian(
330 | x=x_view[c*chunk_size:c*chunk_size + chunk_size],
331 | mean=mean_view[c*chunk_size:c*chunk_size + chunk_size],
332 | scale=scale_view[c*chunk_size:c*chunk_size + chunk_size],
333 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q,
334 | file_name=file_name.replace('.b', f'_{str(c)}.b'),
335 | )
336 | bit_len_list.append(bit_len)
337 | return sum(bit_len_list)
338 |
339 |
340 | def encoder_gaussian(x, mean, scale, Q, file_name='tmp.b'):
341 | # should be single dimension
342 | assert file_name.endswith('.b')
343 | assert len(x.shape) == 1
344 | if not isinstance(Q, torch.Tensor):
345 | Q = torch.tensor([Q], dtype=mean.dtype, device=mean.device).repeat(mean.shape[0])
346 | x_int_round = torch.round(x / Q) # [100]
347 | max_value = x_int_round.max()
348 | min_value = x_int_round.min()
349 |
350 | lower = arithmetic.calculate_cdf(
351 | mean,
352 | scale,
353 | Q,
354 | min_value,
355 | max_value
356 | )
357 |
358 | x_int_round_idx = (x_int_round - min_value).to(torch.int16)
359 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode(
360 | x_int_round_idx,
361 | lower,
362 | chunk_size_cuda,
363 | int(lower.shape[0]),
364 | int(lower.shape[1])
365 | )
366 | cnt_bytes = cnt_torch.cpu().numpy().tobytes()
367 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes()
368 | len_cnt_bytes = len(cnt_bytes)
369 | with open(file_name, 'wb') as fout:
370 | fout.write(min_value.to(torch.float32).cpu().numpy().tobytes())
371 | fout.write(max_value.to(torch.float32).cpu().numpy().tobytes())
372 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes())
373 | fout.write(cnt_bytes)
374 | fout.write(byte_stream_bytes)
375 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes))*8 + 32 * 3
376 | return bit_len
377 |
378 | def decoder_gaussian_chunk(mean, scale, Q, file_name='tmp.b', chunk_size=1000_0000):
379 | assert file_name.endswith('.b')
380 | mean_view = mean.view(-1)
381 | scale_view = scale.view(-1)
382 | N = mean_view.shape[0]
383 | chunks = int(np.ceil(N/chunk_size))
384 | Is_Q_tensor = isinstance(Q, torch.Tensor)
385 | if Is_Q_tensor: Q_view = Q.view(-1)
386 | x_c_list = []
387 | for c in range(chunks):
388 | x_c = decoder_gaussian(
389 | mean=mean_view[c*chunk_size:c*chunk_size + chunk_size],
390 | scale=scale_view[c*chunk_size:c*chunk_size + chunk_size],
391 | Q=Q_view[c*chunk_size:c*chunk_size + chunk_size] if Is_Q_tensor else Q,
392 | file_name=file_name.replace('.b', f'_{str(c)}.b'),
393 | )
394 | x_c_list.append(x_c)
395 | x_c_list = torch.cat(x_c_list, dim=0).type_as(mean)
396 | return x_c_list
397 |
398 |
399 | def decoder_gaussian(mean, scale, Q, file_name='tmp.b'):
400 | # should be single dimension
401 | assert file_name.endswith('.b')
402 | assert len(mean.shape) == 1
403 | assert mean.shape == scale.shape
404 | if not isinstance(Q, torch.Tensor):
405 | Q = torch.tensor([Q], dtype=mean.dtype, device=mean.device).repeat(mean.shape[0])
406 |
407 | with open(file_name, 'rb') as fin:
408 | min_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda")
409 | max_value = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy(), device="cuda")
410 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0]
411 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda")
412 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda")
413 |
414 | lower = arithmetic.calculate_cdf(
415 | mean,
416 | scale,
417 | Q,
418 | min_value,
419 | max_value
420 | )
421 |
422 | sym_out = arithmetic.arithmetic_decode(
423 | lower,
424 | byte_stream_torch,
425 | cnt_torch,
426 | chunk_size_cuda,
427 | int(lower.shape[0]),
428 | int(lower.shape[1])
429 | ).to(mean.device).to(torch.float32)
430 | x = sym_out + min_value
431 | x = x * Q
432 | return x
433 |
434 |
435 | def encoder(x, file_name='tmp.b'):
436 | # x: 0 or 1
437 | assert file_name[-2:] == '.b'
438 | x = x.detach().view(-1)
439 | p = torch.zeros_like(x).to(torch.float32)
440 | prob_1 = x.sum() / x.numel()
441 | p[...] = prob_1
442 | p_u = 1 - p.unsqueeze(-1)
443 | p_0 = torch.zeros_like(p_u)
444 | p_1 = torch.ones_like(p_u)
445 | # Encode to bytestream.
446 | output_cdf = torch.cat([p_0, p_u, p_1], dim=-1)
447 | sym = torch.floor(x).to(torch.int16)
448 | (byte_stream_torch, cnt_torch) = arithmetic.arithmetic_encode(
449 | sym,
450 | output_cdf,
451 | chunk_size_cuda,
452 | int(output_cdf.shape[0]),
453 | int(output_cdf.shape[1])
454 | )
455 | cnt_bytes = cnt_torch.cpu().numpy().tobytes()
456 | byte_stream_bytes = byte_stream_torch.cpu().numpy().tobytes()
457 | len_cnt_bytes = len(cnt_bytes)
458 | with open(file_name, 'wb') as fout:
459 | fout.write(prob_1.to(torch.float32).cpu().numpy().tobytes())
460 | fout.write(np.array([len_cnt_bytes]).astype(np.int32).tobytes())
461 | fout.write(cnt_bytes)
462 | fout.write(byte_stream_bytes)
463 | bit_len = (len(byte_stream_bytes) + len(cnt_bytes)) * 8 + 32 * 2
464 | return bit_len
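# For this binary stream the per-symbol CDF has three columns [0, 1 - prob_1, 1], i.e.
# P(0) = 1 - prob_1 and P(1) = prob_1, so only prob_1 (one float32), the chunk counts and
# the arithmetic-coded payload need to be stored.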
465 |
466 |
467 | def decoder(N_len, file_name='tmp.b', device='cuda'):
468 | assert file_name[-2:] == '.b'
469 |
470 | with open(file_name, 'rb') as fin:
471 | prob_1 = torch.tensor(np.frombuffer(fin.read(4), dtype=np.float32).copy())
472 | len_cnt_bytes = np.frombuffer(fin.read(4), dtype=np.int32)[0]
473 | cnt_torch = torch.tensor(np.frombuffer(fin.read(len_cnt_bytes), dtype=np.int32).copy(), device="cuda")
474 | byte_stream_torch = torch.tensor(np.frombuffer(fin.read(), dtype=np.uint8).copy(), device="cuda")
475 | p = torch.zeros(size=[N_len], dtype=torch.float32, device="cuda")
476 | p[...] = prob_1
477 | p_u = 1 - p.unsqueeze(-1)
478 | p_0 = torch.zeros_like(p_u)
479 | p_1 = torch.ones_like(p_u)
480 |     # Rebuild the same CDF used by the encoder.
481 |     output_cdf = torch.cat([p_0, p_u, p_1], dim=-1)
482 |
483 |     # Decode from bytestream.
484 | sym_out = arithmetic.arithmetic_decode(
485 | output_cdf,
486 | byte_stream_torch,
487 | cnt_torch,
488 | chunk_size_cuda,
489 | int(output_cdf.shape[0]),
490 | int(output_cdf.shape[1])
491 | )
492 | return sym_out
493 |
494 |
--------------------------------------------------------------------------------
/utils/entropy_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as nnf
4 | import numpy as np
5 | from torch.distributions.uniform import Uniform
6 | from utils.encodings import use_clamp
7 |
8 | class Entropy_gaussian_clamp(nn.Module):
9 | def __init__(self, Q=1):
10 | super(Entropy_gaussian_clamp, self).__init__()
11 | self.Q = Q
12 | def forward(self, x, mean, scale, Q=None):
13 | if Q is None:
14 | Q = self.Q
15 | if use_clamp:
16 | x_mean = x.mean()
17 | x_min = x_mean - 15_000 * Q
18 | x_max = x_mean + 15_000 * Q
19 | x = torch.clamp(x, min=x_min.detach(), max=x_max.detach())
20 | scale = torch.clamp(scale, min=1e-9)
21 | m1 = torch.distributions.normal.Normal(mean, scale)
22 | lower = m1.cdf(x - 0.5*Q)
23 | upper = m1.cdf(x + 0.5*Q)
24 | likelihood = torch.abs(upper - lower)
25 | likelihood = Low_bound.apply(likelihood)
26 | bits = -torch.log2(likelihood)
27 | return bits
28 |
29 |
30 | class Entropy_gaussian(nn.Module):
31 | def __init__(self, Q=1):
32 | super(Entropy_gaussian, self).__init__()
33 | self.Q = Q
34 | def forward(self, x, mean, scale, Q=None, x_mean=None):
35 | if Q is None:
36 | Q = self.Q
37 | if use_clamp:
38 | if x_mean is None:
39 | x_mean = x.mean()
40 | x_min = x_mean - 15_000 * Q
41 | x_max = x_mean + 15_000 * Q
42 | x = torch.clamp(x, min=x_min.detach(), max=x_max.detach())
43 | scale = torch.clamp(scale, min=1e-9)
44 | m1 = torch.distributions.normal.Normal(mean, scale)
45 | lower = m1.cdf(x - 0.5*Q)
46 | upper = m1.cdf(x + 0.5*Q)
47 | likelihood = torch.abs(upper - lower)
48 | likelihood = Low_bound.apply(likelihood)
49 | bits = -torch.log2(likelihood)
50 | return bits
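# The per-element rate is the negative log-likelihood of the quantization bin under a
# Gaussian: bits = -log2( Phi((x + Q/2 - mean)/scale) - Phi((x - Q/2 - mean)/scale) ),
# with Low_bound keeping the likelihood (and hence the log) away from zero.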
51 |
52 |
53 | class Entropy_bernoulli(nn.Module):
54 | def __init__(self):
55 | super().__init__()
56 | def forward(self, x, p):
57 | # p = torch.sigmoid(p)
58 | p = torch.clamp(p, min=1e-6, max=1 - 1e-6)
59 | pos_mask = (1 + x) / 2.0 # 1 -> 1, -1 -> 0
60 | neg_mask = (1 - x) / 2.0 # -1 -> 1, 1 -> 0
61 | pos_prob = p
62 | neg_prob = 1 - p
63 | param_bit = -torch.log2(pos_prob) * pos_mask + -torch.log2(neg_prob) * neg_mask
64 | return param_bit
65 |
66 |
67 | class Entropy_factorized(nn.Module):
68 | def __init__(self, channel=32, init_scale=10, filters=(3, 3, 3), likelihood_bound=1e-6,
69 | tail_mass=1e-9, optimize_integer_offset=True, Q=1):
70 | super(Entropy_factorized, self).__init__()
71 | self.filters = tuple(int(t) for t in filters)
72 | self.init_scale = float(init_scale)
73 | self.likelihood_bound = float(likelihood_bound)
74 | self.tail_mass = float(tail_mass)
75 | self.optimize_integer_offset = bool(optimize_integer_offset)
76 | self.Q = Q
77 | if not 0 < self.tail_mass < 1:
78 | raise ValueError(
79 | "`tail_mass` must be between 0 and 1")
80 | filters = (1,) + self.filters + (1,)
81 | scale = self.init_scale ** (1.0 / (len(self.filters) + 1))
82 | self._matrices = nn.ParameterList([])
83 | self._bias = nn.ParameterList([])
84 | self._factor = nn.ParameterList([])
85 | for i in range(len(self.filters) + 1):
86 | init = np.log(np.expm1(1.0 / scale / filters[i + 1]))
87 | self.matrix = nn.Parameter(torch.FloatTensor(
88 | channel, filters[i + 1], filters[i]))
89 | self.matrix.data.fill_(init)
90 | self._matrices.append(self.matrix)
91 | self.bias = nn.Parameter(
92 | torch.FloatTensor(channel, filters[i + 1], 1))
93 | noise = np.random.uniform(-0.5, 0.5, self.bias.size())
94 | noise = torch.FloatTensor(noise)
95 | self.bias.data.copy_(noise)
96 | self._bias.append(self.bias)
97 | if i < len(self.filters):
98 | self.factor = nn.Parameter(
99 | torch.FloatTensor(channel, filters[i + 1], 1))
100 | self.factor.data.fill_(0.0)
101 | self._factor.append(self.factor)
102 |
103 | def _logits_cumulative(self, logits, stop_gradient):
104 | for i in range(len(self.filters) + 1):
105 | matrix = nnf.softplus(self._matrices[i])
106 | if stop_gradient:
107 | matrix = matrix.detach()
108 | # print('dqnwdnqwdqwdqwf:', matrix.shape, logits.shape)
109 | logits = torch.matmul(matrix, logits)
110 | bias = self._bias[i]
111 | if stop_gradient:
112 | bias = bias.detach()
113 | logits += bias
114 | if i < len(self._factor):
115 | factor = nnf.tanh(self._factor[i])
116 | if stop_gradient:
117 | factor = factor.detach()
118 | logits += factor * nnf.tanh(logits)
119 | return logits
120 |
121 | def forward(self, x, Q=None):
122 | # x: [N, C], quantized
123 | if Q is None:
124 | Q = self.Q
125 | else:
126 | Q = Q.permute(1, 0).contiguous()
127 | x = x.permute(1, 0).contiguous() # [C, N]
128 | # print('dqwdqwdqwdqwfqwf:', x.shape, Q.shape)
129 | lower = self._logits_cumulative(x - 0.5*(1/Q), stop_gradient=False)
130 | upper = self._logits_cumulative(x + 0.5*(1/Q), stop_gradient=False)
131 | sign = -torch.sign(torch.add(lower, upper))
132 | sign = sign.detach()
133 | likelihood = torch.abs(
134 | nnf.sigmoid(sign * upper) - nnf.sigmoid(sign * lower))
135 | likelihood = Low_bound.apply(likelihood)
136 | bits = -torch.log2(likelihood) # [C, N]
137 | bits = bits.permute(1, 0).contiguous()
138 | return bits
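# Entropy_factorized appears to follow the fully factorized entropy model common in learned
# compression (Balle et al.): _logits_cumulative parameterizes a monotone per-channel CDF,
# and the rate is -log2 of the probability mass each quantization bin (half-width 0.5/Q in
# this forward pass) receives under that CDF.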
139 |
140 |
141 | class Low_bound(torch.autograd.Function):
142 | @staticmethod
143 | def forward(ctx, x):
144 | ctx.save_for_backward(x)
145 | x = torch.clamp(x, min=1e-6)
146 | return x
147 |
148 | @staticmethod
149 | def backward(ctx, g):
150 | x, = ctx.saved_tensors
151 | grad1 = g.clone()
152 | grad1[x < 1e-6] = 0
153 |         # keep the gradient where x >= 1e-6 or where the gradient would push x upward (g < 0)
154 |         pass_through_if = torch.logical_or(x >= 1e-6, g < 0.0)
155 |         t = pass_through_if.to(g.dtype)
156 | return grad1 * t
157 |
158 |
159 | class UniverseQuant(torch.autograd.Function):
160 | @staticmethod
161 | def forward(ctx, x):
162 | #b = np.random.uniform(-1,1)
163 | b = 0
164 | uniform_distribution = Uniform(-0.5*torch.ones(x.size())
165 | * (2**b), 0.5*torch.ones(x.size())*(2**b)).sample().cuda()
166 | return torch.round(x+uniform_distribution)-uniform_distribution
167 |
168 | @staticmethod
169 | def backward(ctx, g):
170 |
171 | return g
172 |
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 |
18 | def inverse_sigmoid(x):
19 | return torch.log(x/(1-x))
20 |
21 | def PILtoTorch(pil_image, resolution):
22 | resized_image_PIL = pil_image.resize(resolution)
23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
24 | if len(resized_image.shape) == 3:
25 | return resized_image.permute(2, 0, 1)
26 | else:
27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28 |
29 | def get_expon_lr_func_mine(
30 | lr_init, lr_final, lr_delay_mult=1.0, max_steps=1000000
31 | ):
32 | def helper(step):
33 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
34 | # Disable this parameter
35 | return 0.0
36 | delay_rate = 1.0
37 |
38 | if step < 10000:
39 | t = np.clip((step - 0) / (10000 - 0), 0, 1)
40 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
41 | else:
42 | t = np.clip((step - 10000) / (30000 - 10000), 0, 1)
43 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
44 |
45 | return delay_rate * log_lerp
46 |
47 | return helper
48 |
49 | def get_expon_lr_func(
50 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000, step_sub=0,
51 | ):
52 | """
53 | Copied from Plenoxels
54 |
55 | Continuous learning rate decay function. Adapted from JaxNeRF
56 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
57 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
58 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
59 | function of lr_delay_mult, such that the initial learning rate is
60 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
61 | to the normal learning rate when steps>lr_delay_steps.
62 | :param conf: config subtree 'lr' or similar
63 | :param max_steps: int, the number of steps during optimization.
64 | :return HoF which takes step as input
65 | """
66 |
67 | def helper(step):
68 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
69 | # Disable this parameter
70 | return 0.0
71 | if lr_delay_steps > 0:
72 | # A kind of reverse cosine decay.
73 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
74 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
75 | )
76 | else:
77 | delay_rate = 1.0
78 | t = np.clip((step-step_sub) / (max_steps-step_sub), 0, 1)
79 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
80 | return delay_rate * log_lerp
81 |
82 | return helper
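# Ignoring the delay term, the returned schedule is log-linear:
#   t = clip((step - step_sub) / (max_steps - step_sub), 0, 1)
#   lr(step) = exp((1 - t) * ln(lr_init) + t * ln(lr_final))
# A minimal usage sketch (values illustrative only):
#   sched = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=30_000)
#   lr = sched(10_000)   # decays exponentially from 1e-2 toward 1e-4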
83 |
84 | def strip_lowerdiag(L):
85 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
86 |
87 | uncertainty[:, 0] = L[:, 0, 0]
88 | uncertainty[:, 1] = L[:, 0, 1]
89 | uncertainty[:, 2] = L[:, 0, 2]
90 | uncertainty[:, 3] = L[:, 1, 1]
91 | uncertainty[:, 4] = L[:, 1, 2]
92 | uncertainty[:, 5] = L[:, 2, 2]
93 | return uncertainty
94 |
95 | def strip_symmetric(sym):
96 | return strip_lowerdiag(sym)
97 |
98 | def build_rotation(r):
99 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
100 |
101 | q = r / norm[:, None]
102 |
103 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
104 |
105 | r = q[:, 0]
106 | x = q[:, 1]
107 | y = q[:, 2]
108 | z = q[:, 3]
109 |
110 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
111 | R[:, 0, 1] = 2 * (x*y - r*z)
112 | R[:, 0, 2] = 2 * (x*z + r*y)
113 | R[:, 1, 0] = 2 * (x*y + r*z)
114 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
115 | R[:, 1, 2] = 2 * (y*z - r*x)
116 | R[:, 2, 0] = 2 * (x*z - r*y)
117 | R[:, 2, 1] = 2 * (y*z + r*x)
118 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
119 | return R
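# build_rotation converts quaternions given in (w, x, y, z) order (normalized above) into
# 3x3 rotation matrices via the standard quaternion-to-matrix formula.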
120 |
121 | def build_scaling_rotation(s, r):
122 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
123 | R = build_rotation(r)
124 |
125 | L[:,0,0] = s[:,0]
126 | L[:,1,1] = s[:,1]
127 | L[:,2,2] = s[:,2]
128 |
129 | L = R @ L
130 | return L
131 |
132 | def safe_state(silent):
133 | old_f = sys.stdout
134 | class F:
135 | def __init__(self, silent):
136 | self.silent = silent
137 |
138 | def write(self, x):
139 | if not self.silent:
140 | if x.endswith("\n"):
141 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
142 | else:
143 | old_f.write(x)
144 |
145 | def flush(self):
146 | old_f.flush()
147 |
148 | sys.stdout = F(silent)
149 |
150 | random.seed(0)
151 | np.random.seed(0)
152 | torch.manual_seed(0)
153 | torch.cuda.set_device(torch.device("cuda:0"))
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points : np.array
19 | colors : np.array
20 | normals : np.array
21 |
22 | def geom_transform_points(points, transf_matrix):
23 | P, _ = points.shape
24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
25 | points_hom = torch.cat([points, ones], dim=1)
26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
27 |
28 | denom = points_out[..., 3:] + 0.0000001
29 | return (points_out[..., :3] / denom).squeeze(dim=0)
30 |
31 | def getWorld2View(R, t):
32 | Rt = np.zeros((4, 4))
33 | Rt[:3, :3] = R.transpose()
34 | Rt[:3, 3] = t
35 | Rt[3, 3] = 1.0
36 | return np.float32(Rt)
37 |
38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
39 | Rt = np.zeros((4, 4))
40 | Rt[:3, :3] = R.transpose()
41 | Rt[:3, 3] = t
42 | Rt[3, 3] = 1.0
43 |
44 | C2W = np.linalg.inv(Rt)
45 | cam_center = C2W[:3, 3]
46 | cam_center = (cam_center + translate) * scale
47 | C2W[:3, 3] = cam_center
48 | Rt = np.linalg.inv(C2W)
49 | return np.float32(Rt)
50 |
51 | def getProjectionMatrix(znear, zfar, fovX, fovY):
52 | tanHalfFovY = math.tan((fovY / 2))
53 | tanHalfFovX = math.tan((fovX / 2))
54 |
55 | top = tanHalfFovY * znear
56 | bottom = -top
57 | right = tanHalfFovX * znear
58 | left = -right
59 |
60 | P = torch.zeros(4, 4)
61 |
62 | z_sign = 1.0
63 |
64 | P[0, 0] = 2.0 * znear / (right - left)
65 | P[1, 1] = 2.0 * znear / (top - bottom)
66 | P[0, 2] = (right + left) / (right - left)
67 | P[1, 2] = (top + bottom) / (top - bottom)
68 | P[3, 2] = z_sign
69 | P[2, 2] = z_sign * zfar / (zfar - znear)
70 | P[2, 3] = -(zfar * znear) / (zfar - znear)
71 | return P
72 |
73 | def fov2focal(fov, pixels):
74 | return pixels / (2 * math.tan(fov / 2))
75 |
76 | def focal2fov(focal, pixels):
77 | return 2*math.atan(pixels/(2*focal))
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 | def mse(img1, img2):
15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16 |
17 | def psnr(img1, img2):
18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
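# Images are assumed to lie in [0, 1], so the peak value is 1.0 and
# PSNR = 20 * log10(1 / sqrt(MSE)) = -10 * log10(MSE), computed per image in the batch.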
20 |
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 | def l1_loss(network_output, gt):
18 | return torch.abs((network_output - gt)).mean()
19 |
20 | def l2_loss(network_output, gt):
21 | return ((network_output - gt) ** 2).mean()
22 |
23 | def gaussian(window_size, sigma):
24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 | return gauss / gauss.sum()
26 |
27 | def create_window(window_size, channel):
28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31 | return window
32 |
33 | def ssim(img1, img2, window_size=11, size_average=True):
34 | channel = img1.size(-3)
35 | window = create_window(window_size, channel)
36 |
37 | if img1.is_cuda:
38 | window = window.cuda(img1.get_device())
39 | window = window.type_as(img1)
40 |
41 | return _ssim(img1, img2, window, window_size, channel, size_average)
42 |
43 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46 |
47 | mu1_sq = mu1.pow(2)
48 | mu2_sq = mu2.pow(2)
49 | mu1_mu2 = mu1 * mu2
50 |
51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54 |
55 | C1 = 0.01 ** 2
56 | C2 = 0.03 ** 2
57 |
58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59 |
60 | if size_average:
61 | return ssim_map.mean()
62 | else:
63 | return ssim_map.mean(1).mean(1).mean(1)
64 |
65 |
--------------------------------------------------------------------------------
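Editor's note: a hedged sketch (not part of the repository) showing how the L1 and SSIM terms above are typically combined into a photometric training loss; the training script defines the actual weighting, and lambda_dssim = 0.2 here is only illustrative.

    import torch
    from utils.loss_utils import l1_loss, ssim

    render = torch.rand(1, 3, 128, 128, requires_grad=True)  # illustrative rendered image
    gt = torch.rand(1, 3, 128, 128)                          # illustrative ground truth

    lambda_dssim = 0.2  # illustrative weight
    loss = (1.0 - lambda_dssim) * l1_loss(render, gt) + lambda_dssim * (1.0 - ssim(render, gt))
    loss.backward()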
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
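Editor's note: a small sanity-check sketch (not part of the repository) for the degree-0 path of eval_sh and the RGB2SH / SH2RGB round trip. Shapes and values are illustrative.

    import torch
    from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

    rgb = torch.rand(10, 3)                    # 10 illustrative colours in [0, 1]
    sh_dc = RGB2SH(rgb)                        # DC coefficients, shape (10, 3)
    assert torch.allclose(SH2RGB(sh_dc), rgb)  # exact round trip

    # eval_sh expects coefficients shaped [..., C, (deg + 1) ** 2]; at degree 0 the
    # view direction is unused and the result is simply C0 * sh_dc = rgb - 0.5
    dirs = torch.randn(10, 3)
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    out = eval_sh(0, sh_dc.unsqueeze(-1), dirs)
    assert torch.allclose(out + 0.5, rgb)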
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 |     # Creates a directory; equivalent to using mkdir -p on the command line.
18 | try:
19 | makedirs(folder_path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == EEXIST and path.isdir(folder_path):
22 | pass
23 | else:
24 | raise
25 |
26 | def searchForMaxIteration(folder):
27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 | return max(saved_iters)
29 |
--------------------------------------------------------------------------------
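Editor's note: a minimal sketch (not part of the repository) of the two helpers above. The iteration_<N> folder names mirror the pattern searchForMaxIteration parses; the actual checkpoint layout is defined elsewhere in the codebase.

    import os
    import tempfile
    from utils.system_utils import mkdir_p, searchForMaxIteration

    root = tempfile.mkdtemp()
    for it in (7000, 30000):
        mkdir_p(os.path.join(root, f"iteration_{it}"))  # idempotent, like mkdir -p
    print(searchForMaxIteration(root))                   # 30000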
/utils/visualize_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Tuple, Type
2 |
3 | import cv2
4 | import numpy as np
5 | import torch as th
6 | import torch.nn.functional as F
7 |
8 |
9 | def add_label_centered(
10 | img: np.ndarray,
11 | text: str,
12 | font_scale: float = 1.0,
13 | thickness: int = 2,
14 | alignment: str = "top",
15 | color: Tuple[int, int, int] = (0, 255, 0),
16 | ) -> np.ndarray:
17 | font = cv2.FONT_HERSHEY_SIMPLEX
18 | textsize = cv2.getTextSize(text, font, font_scale, thickness=thickness)[0]
19 | img = img.astype(np.uint8).copy()
20 |
21 | if alignment == "top":
22 | cv2.putText(
23 | img,
24 | text,
25 | ((img.shape[1] - textsize[0]) // 2, 50),
26 | font,
27 | font_scale,
28 | color,
29 | thickness=thickness,
30 | lineType=cv2.LINE_AA,
31 | )
32 | elif alignment == "bottom":
33 | cv2.putText(
34 | img,
35 | text,
36 | ((img.shape[1] - textsize[0]) // 2, img.shape[0] - textsize[1]),
37 | font,
38 | font_scale,
39 | color,
40 | thickness=thickness,
41 | lineType=cv2.LINE_AA,
42 | )
43 | else:
44 | raise ValueError("Unknown text alignment")
45 |
46 | return img
47 |
48 | def tensor2rgbjet(
49 | tensor: th.Tensor, x_max: Optional[float] = None, x_min: Optional[float] = None
50 | ) -> np.ndarray:
51 | return cv2.applyColorMap(tensor2rgb(tensor, x_max=x_max, x_min=x_min), cv2.COLORMAP_JET)
52 |
53 |
54 | def tensor2rgb(
55 | tensor: th.Tensor, x_max: Optional[float] = None, x_min: Optional[float] = None
56 | ) -> np.ndarray:
57 | x = tensor.data.cpu().numpy()
58 | if x_min is None:
59 | x_min = x.min()
60 | if x_max is None:
61 | x_max = x.max()
62 |
63 | gain = 255 / np.clip(x_max - x_min, 1e-3, None)
64 | x = (x - x_min) * gain
65 | x = x.clip(0.0, 255.0)
66 | x = x.astype(np.uint8)
67 | return x
68 |
69 |
70 | def tensor2image(
71 | tensor: th.Tensor,
72 | x_max: Optional[float] = 1.0,
73 | x_min: Optional[float] = 0.0,
74 | mode: str = "rgb",
75 | mask: Optional[th.Tensor] = None,
76 | label: Optional[str] = None,
77 | ) -> np.ndarray:
78 |
79 | tensor = tensor.detach()
80 |
81 | # Apply mask
82 | if mask is not None:
83 | tensor = tensor * mask
84 |
85 | if len(tensor.size()) == 2:
86 | tensor = tensor[None]
87 |
88 | # Make three channel image
89 | assert len(tensor.size()) == 3, tensor.size()
90 | n_channels = tensor.shape[0]
91 | if n_channels == 1:
92 | tensor = tensor.repeat(3, 1, 1)
93 | elif n_channels != 3:
94 | raise ValueError(f"Unsupported number of channels {n_channels}.")
95 |
96 | # Convert to display format
97 | img = tensor.permute(1, 2, 0)
98 |
99 | if mode == "rgb":
100 | img = tensor2rgb(img, x_max=x_max, x_min=x_min)
101 | elif mode == "jet":
102 | # `cv2.applyColorMap` assumes input format in BGR
103 | img[:, :, :3] = img[:, :, [2, 1, 0]]
104 | img = tensor2rgbjet(img, x_max=x_max, x_min=x_min)
105 | # convert back to rgb
106 | img[:, :, :3] = img[:, :, [2, 1, 0]]
107 | else:
108 | raise ValueError(f"Unsupported mode {mode}.")
109 |
110 | if label is not None:
111 | img = add_label_centered(img, label)
112 |
113 | return img
114 |
115 | # d: b x 1 x H x W
116 | # screenCoords: b x 2 x H x W
117 | # focal: b x 2 x 2
118 | # princpt: b x 2
119 | # out: b x 3 x H x W
120 | def depthImgToPosCam_Batched(d, screenCoords, focal, princpt):
121 | p = screenCoords - princpt[:, :, None, None]
122 | x = (d * p[:, 0:1, :, :]) / focal[:, 0:1, 0, None, None]
123 | y = (d * p[:, 1:2, :, :]) / focal[:, 1:2, 1, None, None]
124 | return th.cat([x, y, d], dim=1)
125 |
126 | # p: b x 3 x H x W
127 | # out: b x 3 x H x W
128 | def computeNormalsFromPosCam_Batched(p):
129 | p = F.pad(p, (1, 1, 1, 1), "replicate")
130 | d0 = p[:, :, 2:, 1:-1] - p[:, :, :-2, 1:-1]
131 | d1 = p[:, :, 1:-1, 2:] - p[:, :, 1:-1, :-2]
132 | n = th.cross(d0, d1, dim=1)
133 | norm = th.norm(n, dim=1, keepdim=True)
134 | norm = norm + 1e-5
135 |     norm[norm < 1e-5] = 1  # Cannot backprop through this
136 | return -n / norm
137 |
138 | def visualize_normal(inputs, depth_p):
139 | # Normals
140 | uv = th.stack(
141 | th.meshgrid(
142 | th.arange(depth_p.shape[2]), th.arange(depth_p.shape[1]), indexing="xy"
143 | ),
144 | dim=0,
145 | )[None].float().cuda()
146 | position = depthImgToPosCam_Batched(
147 | depth_p[None, ...], uv, inputs["focal"], inputs["princpt"]
148 | )
149 | normal = 0.5 * (computeNormalsFromPosCam_Batched(position) + 1.0)
150 | normal = normal[0, [2, 1, 0], :, :] # legacy code assumes BGR format
151 | normal_p = tensor2image(normal, label="normal_p")
152 |
153 | return normal_p
--------------------------------------------------------------------------------
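Editor's note: a minimal sketch (not part of the repository) showing how tensor2image might be used to dump a colour-mapped, labelled depth map to disk. The tensor size and file name are illustrative; tensor2image returns an RGB uint8 array, so the channels are flipped before cv2.imwrite.

    import cv2
    import torch as th
    from utils.visualize_utils import tensor2image

    depth = th.rand(1, 120, 160)  # illustrative single-channel depth map
    vis = tensor2image(depth, x_max=None, x_min=None, mode="jet", label="depth")
    cv2.imwrite("depth_vis.png", cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))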