├── .dockerignore ├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE.md ├── README.md ├── add_files.bash ├── arguments └── __init__.py ├── assets ├── better.png ├── logo_graphdeco.png ├── logo_inria.png ├── logo_mpi.png ├── logo_mpi.svg ├── logo_uca.png ├── select.png ├── teaser.png └── worse.png ├── convert.py ├── docker_train.py ├── docker_train.sh ├── docker_visualize.py ├── environment.yml ├── extract_metadata.py ├── full_eval.py ├── gaussian_renderer ├── __init__.py ├── ever.py ├── fast_renderer.py └── network_gui.py ├── host_render_server.py ├── install.bash ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── measure_fps.py ├── metrics.py ├── model_output ├── cameras.json ├── cfg_args └── input.ply ├── notebooks ├── demo_images │ ├── 3dgs.png │ ├── ea.png │ ├── es.png │ ├── ga.png │ ├── gs.png │ ├── os.png │ ├── side_view.3dgs.png │ ├── tris.png │ └── ts.png ├── render_rotating_room.ipynb ├── render_sibr_paths.ipynb ├── render_size_animation.ipynb ├── simple_render.ipynb ├── visualize_mlp.ipynb └── visualize_prims.ipynb ├── partial_eval.py ├── render.py ├── requirements.txt ├── resize_images.py ├── run.sh ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── contractions.py ├── dataset_readers.py ├── gaussian_model.py └── sphere_init.py ├── sibr_patch.patch ├── train.py └── utils ├── __init__.py ├── cam_util.py ├── camera_utils.py ├── camera_utils_zipnerf.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── math.py ├── safe_math.py ├── sh_utils.py ├── stepfun.py └── system_utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | 10 | ims/ 11 | iters/ 12 | splits/ 13 | *.mp4 14 | eval*/ 15 | *.so 16 | *.lock 17 | usace/* 18 | **/.slangtorch_cache/* 19 | *.orig 20 | optix/* 21 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "SIBR_viewers"] 5 | path = SIBR_viewers 6 | url = https://gitlab.inria.fr/sibr/sibr_core.git 7 | [submodule "ever"] 8 | path = ever 9 | url = https://github.com/google/ever.git 10 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an NVIDIA CUDA base image that includes development libraries 2 | FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 3 | 4 | # Non-interactive mode for apt-get 5 | ENV DEBIAN_FRONTEND=noninteractive 6 | 7 | 8 | # Set an environment variable for OptiX installation. 
9 | # Adjust this to wherever you've placed OptiX inside the container or mount at runtime: 10 | ENV OptiX_INSTALL_DIR=/opt/OptiX_7.4 11 | 12 | # ------------------------------------------------------ 13 | # 1) Install System Dependencies 14 | # ------------------------------------------------------ 15 | RUN apt-get update && apt-get install -y --no-install-recommends \ 16 | wget \ 17 | git \ 18 | cmake \ 19 | unzip \ 20 | build-essential \ 21 | libglew-dev \ 22 | libassimp-dev \ 23 | libboost-all-dev \ 24 | libgtk-3-dev \ 25 | libopencv-dev \ 26 | libglfw3-dev \ 27 | libavdevice-dev \ 28 | libavcodec-dev \ 29 | libeigen3-dev \ 30 | libxxf86vm-dev \ 31 | libembree-dev \ 32 | # libabsl-dev \ 33 | libcgal-dev \ 34 | libglm-dev \ 35 | && rm -rf /var/lib/apt/lists/* 36 | 37 | # ------------------------------------------------------ 38 | # 2) Install a Miniconda / Conda environment 39 | # - We use Miniconda3 as an example here. 40 | # ------------------------------------------------------ 41 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 42 | bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \ 43 | rm Miniconda3-latest-Linux-x86_64.sh 44 | 45 | # Make conda available and create environment 46 | ENV PATH="/opt/conda/bin:${PATH}" 47 | RUN conda update -n base -c defaults conda && \ 48 | conda create -n ever python=3.10 -y && \ 49 | conda clean -ya 50 | 51 | # By default, we activate conda env inside container with a script or using ENV: 52 | SHELL ["/bin/bash", "-c"] 53 | RUN echo "conda activate ever" >> ~/.bashrc 54 | 55 | # ------------------------------------------------------ 56 | # 3) (Optional) Install Slang 57 | # Replace this section with the correct steps to install or build Slang from source if needed. 58 | # ------------------------------------------------------ 59 | # RUN git clone --recursive https://github.com/shader-slang/slang.git /opt/slang && \ 60 | # cd /opt/slang && \ 61 | # # Example: build from source; replace with actual Slang build instructions 62 | # mkdir build && cd build && \ 63 | # cmake -DCMAKE_BUILD_TYPE=Release .. && \ 64 | # make -j"$(nproc)" && \ 65 | # make install 66 | 67 | RUN wget https://github.com/shader-slang/slang/releases/download/v2025.6.1/slang-2025.6.1-linux-x86_64.zip && \ 68 | mkdir slang_install && \ 69 | cd slang_install && \ 70 | unzip ../slang-2025.6.1-linux-x86_64.zip && \ 71 | cp bin/* /usr/bin/ 72 | 73 | # Clone, build, and install abseil-cpp. 74 | RUN git clone https://github.com/abseil/abseil-cpp.git /tmp/abseil-cpp && \ 75 | cd /tmp/abseil-cpp && \ 76 | mkdir build && cd build && \ 77 | cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && \ 78 | make -j$(nproc) && \ 79 | make install && \ 80 | ldconfig && \ 81 | rm -rf /tmp/abseil-cpp 82 | 83 | # ------------------------------------------------------ 84 | # 4) Install Python packages (within the 'ever' env) 85 | # ------------------------------------------------------ 86 | RUN source activate ever && \ 87 | # Adjust the PyTorch install line for your specific CUDA version if needed 88 | pip3 install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 && \ 89 | pip3 install --no-cache-dir cmake 90 | 91 | # ------------------------------------------------------ 92 | # 5) Final Container Setup 93 | # ------------------------------------------------------ 94 | # We'll define /ever_training as our working directory but 95 | # won't copy any code here. 
We'll rely on runtime mounting. 96 | WORKDIR / 97 | 98 | COPY ./requirements.txt / 99 | 100 | RUN cd / && \ 101 | source activate ever && \ 102 | pip install -r requirements.txt 103 | 104 | COPY optix/ /opt/OptiX_7.4 105 | COPY . /ever_training 106 | 107 | RUN ls /opt/OptiX_7.4 108 | 109 | ENV TORCH_CUDA_ARCH_LIST="5.0;6.0;6.1;7.0;7.5;8.0;8.6" 110 | ENV CUDAARCHS="50 60 61 70 75 80 86" 111 | ENV LD_LIBRARY_PATH="/slang_install/lib/" 112 | 113 | WORKDIR /ever_training 114 | RUN source activate ever && \ 115 | rm -r ever/build && \ 116 | bash install.bash 117 | 118 | # Expose any ports needed for training or viewer 119 | EXPOSE 6009 120 | 121 | # By default, just start a shell in the 'ever' environment 122 | CMD ["/bin/bash", "-c", "source activate ever && exec bash"] 123 | 124 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. 
Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exact Volumetric Ellipsoid Rendering for Real-time View Synthesis 2 | This is the repository with changes to 3DGS's training code to use the EVER rendering method. 3 | 4 | Ever is a method for real-time differentiable emission-only volume rendering. Unlike recent 5 | rasterization based approach by 3D Gaussian Splatting (3DGS), our primitive based representation 6 | allows for exact volume rendering, rather than alpha compositing 3D Gaussian billboards. As such, 7 | unlike 3DGS our formulation does not suffer from popping artifacts and view dependent density, but 8 | still achieves frame rates of ∼30 FPS at 720p on an NVIDIA RTX4090. Because our approach is built 9 | upon ray tracing it supports rendering techniques such as defocus blur and camera distortion (e.g. 10 | such as from fisheye cameras), which are difficult to achieve by rasterization. 
We show that our 11 | method has higher performance and fewer blending issues than 3DGS and other subsequent works, 12 | especially on the challenging large-scale scenes from the Zip-NeRF dataset where it achieves SOTA 13 | results among real-time techniques. 14 | 15 | Datasets: 16 | [mipnerf360pt1](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip), 17 | [mipnerf360pt2](https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip), 18 | [zipnerf](https://smerf-3d.github.io/) 19 | 20 | `zipnerf-undistorted` is used for evaluation against 3DGS. 21 | 22 | More details can be found in our [paper](https://arxiv.org/abs/2410.01804) or at our [website](https://half-potato.gitlab.io/posts/ever/) 23 | 24 |
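For reference, the two mip-NeRF 360 archives linked above can be downloaded and unpacked along these lines (the target directory is illustrative; the Zip-NeRF scenes are obtained via the linked project page instead):
```
# Illustrative only: fetch and unpack both mip-NeRF 360 archives
mkdir -p datasets/360 && cd datasets/360
wget http://storage.googleapis.com/gresearch/refraw360/360_v2.zip
wget https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip
unzip 360_v2.zip && unzip 360_extra_scenes.zip
```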
## BibTeX

```
@misc{mai2024everexactvolumetricellipsoid,
      title={EVER: Exact Volumetric Ellipsoid Rendering for Real-time View Synthesis},
      author={Alexander Mai and Peter Hedman and George Kopanas and Dor Verbin and David Futschik and Qiangeng Xu and Falko Kuester and Jon Barron and Yinda Zhang},
      year={2024},
      eprint={2410.01804},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2410.01804},
}
```
32 | 33 | 34 | ## Quick Install 35 | 36 | If you wish to skip the install, use this command to train: 37 | 38 | ``` 39 | git clone --recursive https://github.com/half-potato/ever_training 40 | python docker_train.py -s -m ... 41 | ``` 42 | The docker train script should have the same arguments as `train.py`. To visualize, do: 43 | ``` 44 | bash docker_visualize.py 6006 127.0.0.1 45 | ``` 46 | and in a different terminal: 47 | ``` 48 | cd SIBR_viewers 49 | ./install/bin/SIBR_remoteGaussian_app --ip 127.0.0.1 --port 6009 50 | ``` 51 | 52 | 53 | ### Dependencies 54 | - OptiX 7.4, which must be downloaded from NVIDIA's [website](https://developer.nvidia.com/designworks/optix/downloads/legacy). This is downloaded and placed somewhere on your computer, then use `export OptiX_INSTALL_DIR=...` to set the variable to that location. 55 | - [*SlangD*](https://github.com/shader-slang/slang). We recommend using the latest version you can, as they have fixed quite a few bugs. 56 | We can install the rest of the dependencies as follows: 57 | ``` 58 | sudo apt install -y libglew-dev libassimp-dev libboost-all-dev libgtk-3-dev libopencv-dev libglfw3-dev libavdevice-dev libavcodec-dev libeigen3-dev libxxf86vm-dev libembree-dev libabsl-dev libcgal-dev libglm-dev 59 | conda env create --name ever python==3.10 60 | conda activate ever 61 | conda install pip 62 | # adjust for cuda version 63 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 64 | ``` 65 | 66 | 67 | For Manjaro, Arch, or other rolling release, run the following: 68 | ``` 69 | export CXX=/usr/bin/g++-11 CC=/usr/bin/gcc-11 70 | ``` 71 | 72 | If you have multiple cuda versions, make sure to specify which one to use using the following command: 73 | ``` 74 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.6/lib64 75 | export PATH=$PATH:/usr/local/cuda-12.6/bin 76 | ``` 77 | 78 | Now, download the files and run `bash install.bash` 79 | ``` 80 | git clone --recursive https://github.com/half-potato/ever_training 81 | cd ever_training 82 | pip install -r requirements.txt 83 | bash install.bash 84 | ``` 85 | If you get a bunch of compilation errors, it could be that you need to run the export line for the CXX and CC versions. 86 | 87 | We can now train using the following command: 88 | ``` 89 | python train.py -s 90 | ``` 91 | 92 | Tested on Manjaro and Ubuntu Linux 22.04. 93 | 94 | ### Evaluation 95 | By default, the trained models use all available images in the dataset. To train them while withholding a test set for evaluation, use the ```--eval``` flag. This way, you can render training/test sets and produce error metrics as follows: 96 | ```shell 97 | python train.py -s --eval # Train with train/test split --images (images_4 for 360 outdoor, images_2 for 360 indoor) 98 | python render.py -m # Generate renderings 99 | python metrics.py -m # Compute error metrics on renderings 100 | ``` 101 | To run training on a dataset with images with varying exposure levels, we must first extract the metadata from the images. This can be done with this command: 102 | ``` 103 | python extract_metadata.py $DATASET/images $DATASET/metadata.json 104 | ``` 105 | 106 | If you want to evaluate our [pre-trained models](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/pretrained/models.zip), you will have to download the corresponding source data sets and indicate their location to ```render.py``` with an additional ```--source_path/-s``` flag. 
Note: The pre-trained models were created with the release codebase. This codebase has been cleaned up and includes bugfixes, hence the metrics you get from evaluating them will differ from those in the paper. 107 | ```shell 108 | python render.py -m <path to pretrained model> -s <path to dataset> 109 | python metrics.py -m <path to pretrained model> 110 | ``` 111 | These have the same arguments as 3DGS. 112 | 113 | To run the full benchmark, use the following command: 114 | ```shell 115 | python full_eval.py -m360 $NERF_DATASETS/360 -zn $NERF_DATASETS/zipnerf_ud -db $NERF_DATASETS/db -tat $NERF_DATASETS/tandt --output_path eval 116 | ``` 117 | 118 | ## Interactive Viewers 119 | For all viewing purposes, we rely on the [SIBR](https://sibr.gitlabpages.inria.fr/) remote viewer. The training script hosts a server for viewing training progress, and we provide `host_render_server.py` for viewing trained models. 120 | 121 | The viewer can then be run as follows: 122 | ``` 123 | python host_render_server.py -m $TRAINED_MODEL_LOCATION -s $SCENE_LOCATION --port $PORT --ip $IP 124 | ``` 125 | The `$SCENE_LOCATION` only needs to be provided if viewing on a different machine than the one the model was trained on. `$IP` is for viewing on remote machines and defaults to `127.0.0.1`; `$PORT` defaults to 6009. Once the render server has been hosted, we can run the SIBR remote viewer in a separate terminal and connect to it. 126 | Then, on a different terminal, run: 127 | ``` 128 | ./install/bin/SIBR_remoteGaussian_app --ip $IP --port $PORT 129 | ``` 130 | 131 | 132 |
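If the trained model lives on a remote machine, one option (plain SSH port forwarding, nothing specific to this repository) is to tunnel the render server's port and point the viewer at the tunnel; the hostname and model path below are placeholders:
```
# On the remote machine (placeholder model path):
python host_render_server.py -m /path/to/trained_model --port 6009 --ip 127.0.0.1
# On the local machine: forward the port, then attach the SIBR viewer through the tunnel
ssh -N -L 6009:127.0.0.1:6009 user@remote.host &
./install/bin/SIBR_remoteGaussian_app --ip 127.0.0.1 --port 6009
```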
### Primary Command Line Arguments for Network Viewer

#### --path / -s
Argument to override model's path to source dataset.
#### --ip
IP to use for connection to a running training script.
#### --port
Port to use for connection to a running training script.
#### --rendering-size
Takes two space separated numbers to define the resolution at which network rendering occurs, ```1200``` width by default.
Note that to enforce an aspect that differs from the input images, you need ```--force-aspect-ratio``` too.
#### --load_images
Flag to load source dataset images to be displayed in the top view for each camera.
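For reference, the snippet below sketches a training run on a scene with varying exposure, as discussed in the Evaluation section: metadata extraction followed by a GLO-enabled `train.py` invocation using the same flags that `full_eval.py` passes for its Zip-NeRF scenes. `$DATASET` and the output path are placeholders.
```
# Sketch only: mirrors the GLO/Zip-NeRF flags used in full_eval.py
python extract_metadata.py $DATASET/images $DATASET/metadata.json
python train.py -s $DATASET -m eval/alameda --eval \
    --enable_GLO --glo_lr 0 --checkpoint_iterations 7000 30000 \
    -r 1 --images images_2 \
    --position_lr_init 4e-5 --position_lr_final 4e-7 \
    --percent_dense 0.0005 --tmin 0
```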
148 | -------------------------------------------------------------------------------- /add_files.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cp ever/new_files/*.py . 3 | cp -r ever/new_files/notebooks . 4 | cp ever/new_files/scene/* scene/ 5 | cp ever/new_files/gaussian_renderer/* gaussian_renderer/ 6 | cp ever/new_files/utils/* utils/ 7 | 8 | cd SIBR_viewers 9 | git apply ../ever/new_files/sibr_patch.patch 10 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | import os 14 | import sys 15 | from argparse import ArgumentParser, Namespace 16 | 17 | 18 | class GroupParams: 19 | pass 20 | 21 | 22 | class ParamGroup: 23 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 24 | group = parser.add_argument_group(name) 25 | for key, value in vars(self).items(): 26 | shorthand = False 27 | if key.startswith("_"): 28 | shorthand = True 29 | key = key[1:] 30 | t = type(value) 31 | value = value if not fill_none else None 32 | if shorthand: 33 | if t == bool: 34 | group.add_argument( 35 | "--" + key, ("-" + key[0:1]), default=value, action="store_true" 36 | ) 37 | elif t == list or t == tuple: 38 | group.add_argument( 39 | "--" + key, 40 | ("-" + key[0:1]), 41 | nargs="+", 42 | type=type(value[0]), 43 | default=value, 44 | ) 45 | else: 46 | group.add_argument( 47 | "--" + key, ("-" + key[0:1]), default=value, type=t 48 | ) 49 | else: 50 | if t == bool: 51 | group.add_argument("--" + key, default=value, action="store_true") 52 | elif t == list or t == tuple: 53 | group.add_argument( 54 | "--" + key, nargs="+", type=type(value[0]), default=value 55 | ) 56 | else: 57 | group.add_argument("--" + key, default=value, type=t) 58 | 59 | def extract(self, args): 60 | group = GroupParams() 61 | for arg in vars(args).items(): 62 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 63 | setattr(group, arg[0], arg[1]) 64 | return group 65 | 66 | 67 | class ModelParams(ParamGroup): 68 | def __init__(self, parser, sentinel=False): 69 | self.sh_degree = 3 70 | self._source_path = "" 71 | self._model_path = "" 72 | self._images = "images" 73 | self._resolution = -1 74 | self._white_background = False 75 | self.data_device = "cpu" 76 | self.render_spline = False 77 | self.use_neural_network = False 78 | self.eval = False 79 | self.num_additional_pts = 10000 80 | self.additional_size_multi = 1.0 81 | self.num_spline_frames = 480 82 | self.glo_latent_dim = 64 83 | self.max_opacity = 0.99 84 | self.tmin = 0.2 85 | 86 | super().__init__(parser, "Loading Parameters", sentinel) 87 | 88 | def extract(self, args): 89 | g = super().extract(args) 90 | g.source_path = os.path.abspath(g.source_path) 91 | return g 92 | 93 | 94 | class PipelineParams(ParamGroup): 95 | def __init__(self, parser): 96 | self.convert_SHs_python = True 97 | self.compute_cov3D_python = False 98 | self.enable_GLO = False 99 | self.debug = False 100 | super().__init__(parser, "Pipeline Parameters") 101 | 102 | class OptimizationParams(ParamGroup): 103 | def __init__(self, parser): 104 | 
self.iterations = 30_000 105 | 106 | self.betas = [0.9, 0.999] 107 | 108 | self.position_lr_final = 4e-7 #0.0000004 109 | self.position_lr_delay_mult = 0.01 110 | self.position_lr_max_steps = 30_000 111 | self.position_lr_init = 4e-5 #0.00004 112 | 113 | self.glo_lr = 0.00 114 | self.glo_network_lr = 0.00005 115 | 116 | self.feature_lr = 0.0025 117 | self.feature_rest_lr = 0.00025 118 | self.bg_lr = 0.0 119 | self.opacity_lr = 0.0125 120 | self.scaling_lr = 0.005 121 | self.rotation_lr = 0.001 122 | self.min_opacity = 0.005 123 | self.min_split_opacity = 0.01 124 | self.percent_dense = 0.0025 125 | self.lambda_dssim = 0.2 126 | 127 | self.lambda_anisotropic = 1e-1 128 | self.lambda_distortion = 0 129 | self.sh_up_interval = 2000 130 | 131 | self.densification_interval = 200 132 | self.opacity_reset_interval = 300000 133 | self.densify_from_iter = 1500 134 | self.densify_until_iter = 16_000 135 | 136 | self.densify_grad_threshold: float = 2.5e-7 137 | 138 | self.clone_grad_threshold: float = 1e-1 139 | 140 | self.center_pixel = False 141 | self.fallback_xy_grad = False 142 | 143 | self.random_background = False 144 | super().__init__(parser, "Optimization Parameters") 145 | 146 | 147 | def get_combined_args(parser: ArgumentParser): 148 | cmdlne_string = sys.argv[1:] 149 | cfgfile_string = "Namespace()" 150 | args_cmdline = parser.parse_args(cmdlne_string) 151 | 152 | try: 153 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 154 | print("Looking for config file in", cfgfilepath) 155 | with open(cfgfilepath) as cfg_file: 156 | print("Config file found: {}".format(cfgfilepath)) 157 | cfgfile_string = cfg_file.read() 158 | except TypeError: 159 | print("Config file not found at") 160 | pass 161 | args_cfgfile = eval(cfgfile_string) 162 | 163 | merged_dict = vars(args_cfgfile).copy() 164 | for k, v in vars(args_cmdline).items(): 165 | if v != None: 166 | merged_dict[k] = v 167 | return Namespace(**merged_dict) 168 | -------------------------------------------------------------------------------- /assets/better.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/better.png -------------------------------------------------------------------------------- /assets/logo_graphdeco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/logo_graphdeco.png -------------------------------------------------------------------------------- /assets/logo_inria.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/logo_inria.png -------------------------------------------------------------------------------- /assets/logo_mpi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/logo_mpi.png -------------------------------------------------------------------------------- /assets/logo_uca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/logo_uca.png 
-------------------------------------------------------------------------------- /assets/select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/select.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/teaser.png -------------------------------------------------------------------------------- /assets/worse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/assets/worse.png -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | print(feat_extracton_cmd) 44 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 45 | exit(exit_code) 46 | 47 | ## Feature matching 48 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 49 | --database_path " + args.source_path + "/distorted/database.db \ 50 | --SiftMatching.use_gpu " + str(use_gpu) 51 | exit_code = os.system(feat_matching_cmd) 52 | if exit_code != 0: 53 | logging.error(f"Feature matching failed with code {exit_code}. 
Exiting.") 54 | exit(exit_code) 55 | 56 | ### Bundle adjustment 57 | # The default Mapper tolerance is unnecessarily large, 58 | # decreasing it speeds up bundle adjustment steps. 59 | mapper_cmd = (colmap_command + " mapper \ 60 | --database_path " + args.source_path + "/distorted/database.db \ 61 | --image_path " + args.source_path + "/input \ 62 | --output_path " + args.source_path + "/distorted/sparse \ 63 | --Mapper.ba_global_function_tolerance=0.000001") 64 | exit_code = os.system(mapper_cmd) 65 | if exit_code != 0: 66 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 67 | exit(exit_code) 68 | 69 | ### Image undistortion 70 | ## We need to undistort our images into ideal pinhole intrinsics. 71 | img_undist_cmd = (colmap_command + " image_undistorter \ 72 | --image_path " + args.source_path + "/input \ 73 | --input_path " + args.source_path + "/distorted/sparse/0 \ 74 | --output_path " + args.source_path + "\ 75 | --output_type COLMAP") 76 | exit_code = os.system(img_undist_cmd) 77 | if exit_code != 0: 78 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 79 | exit(exit_code) 80 | 81 | files = os.listdir(args.source_path + "/sparse") 82 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 83 | # Copy each file from the source directory to the destination directory 84 | for file in files: 85 | if file == '0': 86 | continue 87 | source_file = os.path.join(args.source_path, "sparse", file) 88 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 89 | shutil.move(source_file, destination_file) 90 | 91 | if(args.resize): 92 | print("Copying and resizing...") 93 | 94 | # Resize images. 95 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 97 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 98 | # Get the list of files in the source directory 99 | files = os.listdir(args.source_path + "/images") 100 | # Copy each file from the source directory to the destination directory 101 | for file in files: 102 | source_file = os.path.join(args.source_path, "images", file) 103 | 104 | destination_file = os.path.join(args.source_path, "images_2", file) 105 | shutil.copy2(source_file, destination_file) 106 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 107 | if exit_code != 0: 108 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 109 | exit(exit_code) 110 | 111 | destination_file = os.path.join(args.source_path, "images_4", file) 112 | shutil.copy2(source_file, destination_file) 113 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 114 | if exit_code != 0: 115 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 116 | exit(exit_code) 117 | 118 | destination_file = os.path.join(args.source_path, "images_8", file) 119 | shutil.copy2(source_file, destination_file) 120 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 121 | if exit_code != 0: 122 | logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") 123 | exit(exit_code) 124 | 125 | print("Done.") 126 | -------------------------------------------------------------------------------- /docker_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import sys 4 | import argparse 5 | import subprocess 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser( 9 | description="Run a Docker container for training with dataset & output mounted." 10 | ) 11 | # Named argument -s / --scene => dataset directory 12 | parser.add_argument( 13 | "-s", "--scene", 14 | required=True, 15 | help="Path to the dataset directory (e.g. /data/nerf_datasets/zipnerf_ud/london)." 16 | ) 17 | # Named argument -m / --model_path => output directory 18 | parser.add_argument( 19 | "-m", "--model_path", 20 | default="./model_output", 21 | help="Path to the model/output directory." 22 | ) 23 | # Optional port/ip 24 | parser.add_argument( 25 | "--port", 26 | default="6009", 27 | help="Port to map inside the container. Defaults to 6009." 28 | ) 29 | parser.add_argument( 30 | "--ip", 31 | default="127.0.0.1", 32 | help="IP to bind the port to. Defaults to 127.0.0.1." 33 | ) 34 | 35 | # Use parse_known_args to capture any extra flags (unknown) 36 | # that we want to forward to train.py: 37 | known_args, unknown_args = parser.parse_known_args() 38 | 39 | # Now build the docker command: 40 | docker_cmd = [ 41 | "docker", "run", "--rm", "--gpus", "all", 42 | "-v", "/tmp/NVIDIA:/tmp/NVIDIA", 43 | # "--user", "$(id -u):$(id -g)", 44 | "-e", "NVIDIA_DRIVER_CAPABILITIES=graphics,compute,utility", 45 | # Mount the scene/dataset directory and model_path 46 | "-v", f"{known_args.scene}:/data/dataset", 47 | "-v", f"{known_args.model_path}:/data/output", 48 | # Also mount the current directory for your code 49 | #"-v", f"{os.getcwd()}:/ever_training2", 50 | # Port mapping 51 | "-p", f"{known_args.ip}:{known_args.port}:{known_args.port}", 52 | "halfpotato/ever:latest", 53 | "bash", "-c", 54 | ( 55 | "source activate ever && " 56 | # "cd /ever_training2 && " 57 | # "rm -r ever && " 58 | # "cp -r /ever_training/ever . 
&& " 59 | # "$@" references extra arguments from the final "_" placeholder 60 | "python train.py -s /data/dataset -m /data/output \"$@\"" 61 | ), 62 | "_" # Placeholder for extra arguments 63 | ] 64 | 65 | # Append the unknown_args so train.py sees them 66 | docker_cmd += unknown_args 67 | print(docker_cmd) 68 | 69 | print("Running:", " ".join(docker_cmd)) # For debugging 70 | subprocess.run(docker_cmd, check=True) 71 | 72 | if __name__ == "__main__": 73 | main() 74 | 75 | -------------------------------------------------------------------------------- /docker_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # docker_train.sh 3 | # 4 | # Usage: bash docker_train.sh [additional train.py flags] 5 | # 6 | # This script mounts the dataset and output directories to /data/dataset and /data/output 7 | # inside the Docker container and then runs: 8 | # python train.py -s /data/dataset -m /data/output 9 | 10 | if [ "$#" -lt 2 ]; then 11 | echo "Usage: $0 [additional train.py flags]" 12 | exit 1 13 | fi 14 | 15 | DATASET_DIR="$1" 16 | OUTPUT_DIR="$2" 17 | PORT="${3:-6009}" 18 | IP="${4:-127.0.0.1}" 19 | shift 4 20 | 21 | docker run --rm --gpus all \ 22 | -v /tmp/NVIDIA:/tmp/NVIDIA \ 23 | -e NVIDIA_DRIVER_CAPABILITIES=graphics,compute,utility \ 24 | -v "$DATASET_DIR":/data/dataset \ 25 | -v "$OUTPUT_DIR":/data/output \ 26 | -v "$(pwd)":/ever_training2 \ 27 | -p "$IP:$PORT:$PORT" \ 28 | halfpotato/ever:latest \ 29 | bash -c 'source activate ever && cd /ever_training2 && rm -r ever && cp -r /ever_training/ever . && python train.py -s /data/dataset -m /data/output "$@"' _ "$@" 30 | 31 | -------------------------------------------------------------------------------- /docker_visualize.py: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # docker_visualize.sh 3 | # 4 | # Usage: 5 | # bash docker_visualize.sh [additional host_render_server.py flags] 6 | # 7 | # This script mounts the model and scene directories into the container at /data/trained_model and /data/scene, 8 | # exposes the specified port on the given IP, and then executes: 9 | # python host_render_server.py -m /data/trained_model -s /data/scene --port --ip [extra flags] 10 | 11 | if [ "$#" -lt 4 ]; then 12 | echo "Usage: $0 [additional host_render_server.py flags]" 13 | exit 1 14 | fi 15 | 16 | TRAINED_MODEL_LOCATION="$1" 17 | SCENE_LOCATION="$2" 18 | PORT="${3:-6009}" 19 | IP="${4:-127.0.0.1}" 20 | shift 4 21 | 22 | docker run --rm --gpus all -it \ 23 | -v /tmp/NVIDIA:/tmp/NVIDIA \ 24 | -e NVIDIA_DRIVER_CAPABILITIES=graphics,compute,utility \ 25 | -v "$TRAINED_MODEL_LOCATION":/data/trained_model \ 26 | -v "$SCENE_LOCATION":/data/scene \ 27 | -p "$IP:$PORT:$PORT" \ 28 | -v "$(pwd)":/ever_training2 \ 29 | ever \ 30 | bash -c "source activate ever && cd /ever_training2 && rm -r ever && cp -r /ever_training/ever . 
&& python host_render_server.py -m /data/trained_model -s /data/scene --port $PORT --ip $IP $*" 31 | 32 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ever 2 | channels: 3 | - conda-forge 4 | - nvidia 5 | - pytorch 6 | - pytorch3d 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_kmp_llvm 10 | - absl-py=2.1.0=pyhd8ed1ab_0 11 | - alsa-lib=1.2.10=hd590300_0 12 | - aom=3.7.1=h59595ed_0 13 | - asttokens=2.4.0=pyhd8ed1ab_0 14 | - attr=2.5.1=h166bdaf_1 15 | - attrs=23.2.0=pyh71513ae_0 16 | - backcall=0.2.0=pyh9f0ad1d_0 17 | - backports=1.0=pyhd8ed1ab_3 18 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 19 | - binutils_impl_linux-64=2.40=hf600244_0 20 | - binutils_linux-64=2.40=hbdbef99_2 21 | - blas=2.116=mkl 22 | - blas-devel=3.9.0=16_linux64_mkl 23 | - brotli=1.1.0=hd590300_1 24 | - brotli-bin=1.1.0=hd590300_1 25 | - brotli-python=1.1.0=py311hb755f60_1 26 | - bzip2=1.0.8=h7f98852_4 27 | - c-ares=1.21.0=hd590300_0 28 | - ca-certificates=2023.7.22=hbcca054_0 29 | - cairo=1.18.0=h3faef2a_0 30 | - cattrs=23.2.3=pyhd8ed1ab_0 31 | - certifi=2023.7.22=pyhd8ed1ab_0 32 | - charset-normalizer=3.3.0=pyhd8ed1ab_0 33 | - colorama=0.4.6=pyhd8ed1ab_0 34 | - comm=0.1.4=pyhd8ed1ab_0 35 | - contourpy=1.1.1=py311h9547e67_1 36 | - cuda-cudart=12.1.105=0 37 | - cuda-cupti=12.1.105=0 38 | - cuda-libraries=12.1.0=0 39 | - cuda-nvrtc=12.1.105=0 40 | - cuda-nvtx=12.1.105=0 41 | - cuda-opencl=12.3.52=0 42 | - cuda-runtime=12.1.0=0 43 | - cuda-version=11.8=h70ddcb2_2 44 | - cudatoolkit=11.8.0=h4ba93d1_12 45 | - cudnn=8.8.0.121=h838ba91_3 46 | - cycler=0.12.1=pyhd8ed1ab_0 47 | - dataclasses=0.8=pyhc8e2a94_3 48 | - dav1d=1.2.1=hd590300_0 49 | - dbus=1.13.6=h5008d03_3 50 | - debugpy=1.8.0=py311hb755f60_1 51 | - decorator=5.1.1=pyhd8ed1ab_0 52 | - docstring-to-markdown=0.14=pyhd8ed1ab_0 53 | - exceptiongroup=1.1.3=pyhd8ed1ab_0 54 | - executing=1.2.0=pyhd8ed1ab_0 55 | - expat=2.5.0=hcb278e6_1 56 | - ffmpeg=6.1.0=gpl_h402741f_101 57 | - filelock=3.12.4=pyhd8ed1ab_0 58 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 59 | - font-ttf-inconsolata=3.000=h77eed37_0 60 | - font-ttf-source-code-pro=2.038=h77eed37_0 61 | - font-ttf-ubuntu=0.83=hab24e00_0 62 | - fontconfig=2.14.2=h14ed4e7_0 63 | - fonts-conda-ecosystem=1=0 64 | - fonts-conda-forge=1=0 65 | - fonttools=4.43.1=py311h459d7ec_0 66 | - freeglut=3.2.2=hac7e632_2 67 | - freetype=2.12.1=h267a509_2 68 | - fribidi=1.0.10=h36c2ea0_0 69 | - fvcore=0.1.5.post20221221=pyhd8ed1ab_0 70 | - gcc_impl_linux-64=13.2.0=h338b0a0_3 71 | - gcc_linux-64=13.2.0=h112eaf3_2 72 | - gettext=0.21.1=h27087fc_0 73 | - glib=2.78.0=hfc55251_0 74 | - glib-tools=2.78.0=hfc55251_0 75 | - gmp=6.3.0=h59595ed_0 76 | - gmpy2=2.1.2=py311h6a5fa03_1 77 | - gnutls=3.7.9=hb077bed_0 78 | - graphite2=1.3.13=h58526e2_1001 79 | - grpcio=1.58.1=py311ha6695c7_2 80 | - gst-plugins-base=1.22.6=h8e1006c_2 81 | - gstreamer=1.22.6=h98fc4e7_2 82 | - gxx_impl_linux-64=13.2.0=h338b0a0_3 83 | - gxx_linux-64=13.2.0=hc53e3bf_2 84 | - harfbuzz=8.2.1=h3d44ed6_0 85 | - hdf5=1.14.2=nompi_h4f84152_100 86 | - icecream=2.1.3=pyhd8ed1ab_0 87 | - icu=73.2=h59595ed_0 88 | - idna=3.4=pyhd8ed1ab_0 89 | - imageio=2.31.5=pyh8c1a49c_0 90 | - importlib-metadata=6.8.0=pyha770c72_0 91 | - importlib_metadata=6.8.0=hd8ed1ab_0 92 | - iopath=0.1.9=pyhd8ed1ab_0 93 | - ipykernel=6.26.0=pyhf8b6a83_0 94 | - ipympl=0.9.3=pyhd8ed1ab_0 95 | - ipython=8.16.1=pyh0d859eb_0 96 | - 
ipython_genutils=0.2.0=py_1 97 | - ipywidgets=8.1.1=pyhd8ed1ab_0 98 | - jasper=4.1.0=he6dfbbe_0 99 | - jax=0.4.19=pyhd8ed1ab_0 100 | - jaxlib=0.4.19=cuda118py311h045a74e_200 101 | - jedi=0.19.1=pyhd8ed1ab_0 102 | - jedi-language-server=0.41.2=pyhd8ed1ab_0 103 | - jinja2=3.1.2=pyhd8ed1ab_1 104 | - jupyter_client=8.5.0=pyhd8ed1ab_0 105 | - jupyter_core=5.4.0=py311h38be061_0 106 | - jupyterlab_widgets=3.0.9=pyhd8ed1ab_0 107 | - kernel-headers_linux-64=2.6.32=he073ed8_16 108 | - keyutils=1.6.1=h166bdaf_0 109 | - kiwisolver=1.4.5=py311h9547e67_1 110 | - kornia=0.7.0=pyhd8ed1ab_0 111 | - krb5=1.21.2=h659d440_0 112 | - lame=3.100=h166bdaf_1003 113 | - lcms2=2.15=hb7c19ff_3 114 | - ld_impl_linux-64=2.40=h41732ed_0 115 | - lerc=4.0.0=h27087fc_0 116 | - libabseil=20230802.1=cxx17_h59595ed_0 117 | - libaec=1.1.2=h59595ed_1 118 | - libass=0.17.1=h8fe9dca_1 119 | - libblas=3.9.0=16_linux64_mkl 120 | - libbrotlicommon=1.1.0=hd590300_1 121 | - libbrotlidec=1.1.0=hd590300_1 122 | - libbrotlienc=1.1.0=hd590300_1 123 | - libcap=2.69=h0f662aa_0 124 | - libcblas=3.9.0=16_linux64_mkl 125 | - libclang=15.0.7=default_h7634d5b_3 126 | - libclang13=15.0.7=default_h9986a30_3 127 | - libcublas=12.1.0.26=0 128 | - libcufft=11.0.2.4=0 129 | - libcufile=1.8.0.34=0 130 | - libcups=2.3.3=h4637d8d_4 131 | - libcurand=10.3.4.52=0 132 | - libcurl=8.4.0=hca28451_0 133 | - libcusolver=11.4.4.55=0 134 | - libcusparse=12.0.2.55=0 135 | - libdeflate=1.19=hd590300_0 136 | - libdrm=2.4.114=h166bdaf_0 137 | - libedit=3.1.20191231=he28a2e2_2 138 | - libev=4.33=hd590300_2 139 | - libevent=2.1.12=hf998b51_1 140 | - libexpat=2.5.0=hcb278e6_1 141 | - libffi=3.4.2=h7f98852_5 142 | - libflac=1.4.3=h59595ed_0 143 | - libgcc-devel_linux-64=13.2.0=ha9c7c90_103 144 | - libgcc-ng=13.2.0=h807b86a_2 145 | - libgcrypt=1.10.1=h166bdaf_0 146 | - libgfortran-ng=13.2.0=h69a702a_2 147 | - libgfortran5=13.2.0=ha4646dd_2 148 | - libglib=2.78.0=hebfc3b9_0 149 | - libglu=9.0.0=hac7e632_1003 150 | - libgomp=13.2.0=h807b86a_2 151 | - libgpg-error=1.47=h71f35ed_0 152 | - libgrpc=1.58.1=he06187c_2 153 | - libhwloc=2.9.3=default_h554bfaf_1009 154 | - libiconv=1.17=h166bdaf_0 155 | - libidn2=2.3.4=h166bdaf_0 156 | - libjpeg-turbo=3.0.0=hd590300_1 157 | - liblapack=3.9.0=16_linux64_mkl 158 | - liblapacke=3.9.0=16_linux64_mkl 159 | - libllvm15=15.0.7=h5cf9203_3 160 | - libnghttp2=1.52.0=h61bc06f_0 161 | - libnpp=12.0.2.50=0 162 | - libnsl=2.0.1=hd590300_0 163 | - libnvjitlink=12.1.105=0 164 | - libnvjpeg=12.1.1.14=0 165 | - libogg=1.3.4=h7f98852_1 166 | - libopencv=4.8.1=py311h60c0964_4 167 | - libopenvino=2023.1.0=h59595ed_2 168 | - libopenvino-auto-batch-plugin=2023.1.0=h59595ed_2 169 | - libopenvino-auto-plugin=2023.1.0=h59595ed_2 170 | - libopenvino-hetero-plugin=2023.1.0=h59595ed_2 171 | - libopenvino-intel-cpu-plugin=2023.1.0=h59595ed_2 172 | - libopenvino-intel-gpu-plugin=2023.1.0=h59595ed_2 173 | - libopenvino-ir-frontend=2023.1.0=h59595ed_2 174 | - libopenvino-onnx-frontend=2023.1.0=h59595ed_2 175 | - libopenvino-paddle-frontend=2023.1.0=h59595ed_2 176 | - libopenvino-pytorch-frontend=2023.1.0=h59595ed_2 177 | - libopenvino-tensorflow-frontend=2023.1.0=h59595ed_2 178 | - libopenvino-tensorflow-lite-frontend=2023.1.0=h59595ed_2 179 | - libopus=1.3.1=h7f98852_1 180 | - libpciaccess=0.17=h166bdaf_0 181 | - libpng=1.6.39=h753d276_0 182 | - libpq=16.0=hfc447b1_1 183 | - libprotobuf=4.24.3=hf27288f_1 184 | - libre2-11=2023.06.02=h7a70373_0 185 | - libsanitizer=13.2.0=h7e041cc_3 186 | - libsndfile=1.2.2=hc60ed4a_1 187 | - libsodium=1.0.18=h36c2ea0_1 188 | - 
libsqlite=3.43.2=h2797004_0 189 | - libssh2=1.11.0=h0841786_0 190 | - libstdcxx-devel_linux-64=13.2.0=ha9c7c90_103 191 | - libstdcxx-ng=13.2.0=h7e041cc_2 192 | - libsystemd0=254=h3516f8a_0 193 | - libtasn1=4.19.0=h166bdaf_0 194 | - libtiff=4.6.0=ha9c0a0a_2 195 | - libunistring=0.9.10=h7f98852_0 196 | - libuuid=2.38.1=h0b41bf4_0 197 | - libva=2.20.0=hd590300_0 198 | - libvorbis=1.3.7=h9c3ff4c_0 199 | - libvpx=1.13.1=h59595ed_0 200 | - libwebp-base=1.3.2=hd590300_0 201 | - libxcb=1.15=h0b41bf4_0 202 | - libxkbcommon=1.6.0=h5d7e998_0 203 | - libxml2=2.11.5=h232c23b_1 204 | - libzlib=1.2.13=hd590300_5 205 | - llvm-openmp=15.0.7=h0cdce71_0 206 | - lsprotocol=2023.0.1=pyhd8ed1ab_0 207 | - lz4-c=1.9.4=hcb278e6_0 208 | - markdown=3.5.2=pyhd8ed1ab_0 209 | - markupsafe=2.1.3=py311h459d7ec_1 210 | - matplotlib=3.8.0=py311h38be061_2 211 | - matplotlib-base=3.8.0=py311h54ef318_2 212 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 213 | - mkl=2022.1.0=h84fe81f_915 214 | - mkl-devel=2022.1.0=ha770c72_916 215 | - mkl-include=2022.1.0=h84fe81f_915 216 | - ml_dtypes=0.3.1=py311h320fe9a_2 217 | - mpc=1.3.1=hfe3b2da_0 218 | - mpfr=4.2.0=hb012696_0 219 | - mpg123=1.32.3=h59595ed_0 220 | - mpmath=1.3.0=pyhd8ed1ab_0 221 | - munkres=1.1.4=pyh9f0ad1d_0 222 | - mysql-common=8.0.33=hf1915f5_5 223 | - mysql-libs=8.0.33=hca2cd23_5 224 | - nccl=2.19.3.1=h6103f9b_0 225 | - ncurses=6.4=hcb278e6_0 226 | - nest-asyncio=1.5.8=pyhd8ed1ab_0 227 | - nettle=3.9.1=h7ab15ed_0 228 | - networkx=3.2=pyhd8ed1ab_1 229 | - nspr=4.35=h27087fc_0 230 | - nss=3.94=h1d7d5a4_0 231 | - numpy=1.26.0=py311h64a7726_0 232 | - ocl-icd=2.3.1=h7f98852_0 233 | - ocl-icd-system=1.0.0=1 234 | - openh264=2.3.1=hcb278e6_2 235 | - openjpeg=2.5.0=h488ebb8_3 236 | - openssl=3.1.3=hd590300_0 237 | - opt-einsum=3.3.0=hd8ed1ab_2 238 | - opt_einsum=3.3.0=pyhc1e730c_2 239 | - p11-kit=0.24.1=hc5aa10d_0 240 | - packaging=23.2=pyhd8ed1ab_0 241 | - parso=0.8.3=pyhd8ed1ab_0 242 | - pcre2=10.40=hc3806b6_0 243 | - pexpect=4.8.0=pyh1a96a4e_2 244 | - pickleshare=0.7.5=py_1003 245 | - pillow=10.1.0=py311ha6c5da5_0 246 | - pip=23.3=pyhd8ed1ab_0 247 | - pixman=0.42.2=h59595ed_0 248 | - platformdirs=3.11.0=pyhd8ed1ab_0 249 | - ply=3.11=py_1 250 | - plyfile=1.0.1=pyhd8ed1ab_0 251 | - portalocker=2.8.2=py311h38be061_1 252 | - prompt-toolkit=3.0.39=pyha770c72_0 253 | - prompt_toolkit=3.0.39=hd8ed1ab_0 254 | - protobuf=4.24.3=py311h46cbc50_1 255 | - psutil=5.9.5=py311h459d7ec_1 256 | - pthread-stubs=0.4=h36c2ea0_1001 257 | - ptyprocess=0.7.0=pyhd3deb0d_0 258 | - pugixml=1.14=h59595ed_0 259 | - pulseaudio-client=16.1=hb77b528_5 260 | - pure_eval=0.2.2=pyhd8ed1ab_0 261 | - py-opencv=4.8.1=py311hcd063c8_4 262 | - pygls=1.3.0=pyhd8ed1ab_0 263 | - pygments=2.16.1=pyhd8ed1ab_0 264 | - pyparsing=3.1.1=pyhd8ed1ab_0 265 | - pyqt=5.15.9=py311hf0fb5b6_5 266 | - pyqt5-sip=12.12.2=py311hb755f60_5 267 | - pysocks=1.7.1=pyha2e5f31_6 268 | - python=3.11.6=hab00c5b_0_cpython 269 | - python-dateutil=2.8.2=pyhd8ed1ab_0 270 | - python_abi=3.11=4_cp311 271 | - pytorch=2.1.0=py3.11_cuda12.1_cudnn8.9.2_0 272 | - pytorch-cuda=12.1=ha16c6d3_5 273 | - pytorch-mutex=1.0=cuda 274 | - pytorch3d=0.7.5=py311_cu121_pyt210 275 | - pyyaml=6.0.1=py311h459d7ec_1 276 | - pyzmq=25.1.1=py311h34ded2d_2 277 | - qt-main=5.15.8=h82b777d_17 278 | - re2=2023.06.02=h2873b5e_0 279 | - readline=8.2=h8228510_1 280 | - requests=2.31.0=pyhd8ed1ab_0 281 | - scipy=1.11.3=py311h64a7726_1 282 | - setuptools=68.2.2=pyhd8ed1ab_0 283 | - sip=6.7.12=py311hb755f60_0 284 | - six=1.16.0=pyh6c4a22f_0 285 | - snappy=1.1.10=h9fff704_0 286 | - 
stack_data=0.6.2=pyhd8ed1ab_0 287 | - svt-av1=1.7.0=h59595ed_0 288 | - sympy=1.12=pypyh9d50eac_103 289 | - sysroot_linux-64=2.12=he073ed8_16 290 | - tabulate=0.9.0=pyhd8ed1ab_1 291 | - tbb=2021.10.0=h00ab1b0_2 292 | - tensorboard=2.16.2=pyhd8ed1ab_0 293 | - tensorboard-data-server=0.7.0=py311h63ff55d_1 294 | - termcolor=2.4.0=pyhd8ed1ab_0 295 | - tk=8.6.13=h2797004_0 296 | - toml=0.10.2=pyhd8ed1ab_0 297 | - tomli=2.0.1=pyhd8ed1ab_0 298 | - torchaudio=2.1.0=py311_cu121 299 | - torchtriton=2.1.0=py311 300 | - torchvision=0.16.0=py311_cu121 301 | - tornado=6.3.3=py311h459d7ec_1 302 | - tqdm=4.66.1=pyhd8ed1ab_0 303 | - traitlets=5.12.0=pyhd8ed1ab_0 304 | - typing-extensions=4.8.0=hd8ed1ab_0 305 | - typing_extensions=4.8.0=pyha770c72_0 306 | - tzdata=2023c=h71feb2d_0 307 | - urllib3=2.0.7=pyhd8ed1ab_0 308 | - wcwidth=0.2.8=pyhd8ed1ab_0 309 | - werkzeug=3.0.1=pyhd8ed1ab_0 310 | - wheel=0.41.2=pyhd8ed1ab_0 311 | - widgetsnbextension=4.0.9=pyhd8ed1ab_0 312 | - x264=1!164.3095=h166bdaf_2 313 | - x265=3.5=h924138e_3 314 | - xcb-util=0.4.0=hd590300_1 315 | - xcb-util-image=0.4.0=h8ee46fc_1 316 | - xcb-util-keysyms=0.4.0=h8ee46fc_1 317 | - xcb-util-renderutil=0.3.9=hd590300_1 318 | - xcb-util-wm=0.4.1=h8ee46fc_1 319 | - xkeyboard-config=2.40=hd590300_0 320 | - xorg-fixesproto=5.0=h7f98852_1002 321 | - xorg-inputproto=2.3.2=h7f98852_1002 322 | - xorg-kbproto=1.0.7=h7f98852_1002 323 | - xorg-libice=1.1.1=hd590300_0 324 | - xorg-libsm=1.2.4=h7391055_0 325 | - xorg-libx11=1.8.7=h8ee46fc_0 326 | - xorg-libxau=1.0.11=hd590300_0 327 | - xorg-libxdmcp=1.1.3=h7f98852_0 328 | - xorg-libxext=1.3.4=h0b41bf4_2 329 | - xorg-libxfixes=5.0.3=h7f98852_1004 330 | - xorg-libxi=1.7.10=h7f98852_0 331 | - xorg-libxrender=0.9.11=hd590300_0 332 | - xorg-renderproto=0.11.1=h7f98852_1002 333 | - xorg-xextproto=7.3.0=h0b41bf4_1003 334 | - xorg-xf86vidmodeproto=2.3.1=h7f98852_1002 335 | - xorg-xproto=7.0.31=h7f98852_1007 336 | - xz=5.2.6=h166bdaf_0 337 | - yacs=0.1.8=pyhd8ed1ab_0 338 | - yaml=0.2.5=h7f98852_2 339 | - zeromq=4.3.5=h59595ed_0 340 | - zipp=3.17.0=pyhd8ed1ab_0 341 | - zlib=1.2.13=hd590300_5 342 | - zstd=1.5.5=hfc55251_0 343 | 344 | -------------------------------------------------------------------------------- /extract_metadata.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from pathlib import Path 16 | from argparse import ArgumentParser, Namespace 17 | import pickle 18 | from PIL import Image, ExifTags, TiffImagePlugin 19 | import subprocess 20 | import json 21 | from tqdm import tqdm 22 | 23 | parser = ArgumentParser(description="Extract metadata parameters") 24 | parser.add_argument('images', type=Path) 25 | parser.add_argument('output', type=Path) 26 | parser.add_argument('--exif_tool_path', type=str, default="exiftool") 27 | args = parser.parse_args() 28 | 29 | metadatas = { 30 | } 31 | 32 | iso_tags = ["Sony ISO", "ISO"] 33 | exposure_tags = ["Sony Exposure Time 2", "Exposure Time"] 34 | aperature_tags = ["FNumber", "Sony F Number 2"] 35 | 36 | def get_value(data, tags): 37 | vs = [data[t] for t in tags if t in data] 38 | return vs[0] if len(vs) > 0 else -1 39 | 40 | for path in tqdm(args.images.iterdir()): 41 | try: 42 | img = Image.open(path) 43 | except: 44 | print(path, " is not an image") 45 | 46 | process = subprocess.Popen( 47 | [args.exif_tool_path,str(path)], 48 | stdout=subprocess.PIPE, 49 | stderr=subprocess.STDOUT, 50 | universal_newlines=True) 51 | exif = { 52 | ExifTags.TAGS[k]: float(v) if isinstance(v, TiffImagePlugin.IFDRational) else v 53 | for k, v in img._getexif().items() 54 | if k in ExifTags.TAGS 55 | } 56 | for tag in process.stdout: 57 | line = tag.strip().split(':') 58 | exif[line[0].strip()] = line[-1].strip() 59 | data = dict( 60 | iso=get_value(exif, iso_tags), 61 | exposure=get_value(exif, exposure_tags), 62 | aperature=get_value(exif, aperature_tags), 63 | ) 64 | # print(exif) 65 | metadatas[path.name] = data 66 | 67 | with args.output.open("w") as f: 68 | json.dump(metadatas, f) 69 | 70 | -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | GLO_SCENES = ["alameda"] 16 | 17 | # mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 18 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 19 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 20 | # mipnerf360_indoor_scenes = ["counter", "kitchen", "bonsai"] 21 | tanks_and_temples_scenes = ["truck", "train"] 22 | deep_blending_scenes = ["drjohnson", "playroom"] 23 | zipnerf_scenes = ["alameda", "nyc", "london", "berlin"] 24 | 25 | 26 | parser = ArgumentParser(description="Full evaluation script parameters") 27 | parser.add_argument("--skip_training", action="store_true") 28 | parser.add_argument("--skip_rendering", action="store_true") 29 | parser.add_argument("--skip_metrics", action="store_true") 30 | parser.add_argument("--output_path", default="./eval") 31 | args, _ = parser.parse_known_args() 32 | 33 | parser.add_argument('--mipnerf360', "-m360", default='', type=str) 34 | parser.add_argument("--tanksandtemples", "-tat", default='', type=str) 35 | parser.add_argument("--deepblending", "-db", default='', type=str) 36 | parser.add_argument("--zipnerf", "-zn", default='', type=str) 37 | parser.add_argument("--skip_360_indoor", action='store_true') 38 | parser.add_argument("--skip_360_outdoor", action='store_true') 39 | parser.add_argument("--port", default=6009, type=int) 40 | parser.add_argument("--additional_args", default="", type=str) 41 | args = parser.parse_args() 42 | 43 | if args.skip_360_outdoor: 44 | mipnerf360_outdoor_scenes = [] 45 | if args.skip_360_indoor: 46 | mipnerf360_indoor_scenes = [] 47 | if len(args.mipnerf360) == 0: 48 | mipnerf360_indoor_scenes = [] 49 | mipnerf360_outdoor_scenes = [] 50 | if len(args.tanksandtemples) == 0: 51 | tanks_and_temples_scenes = [] 52 | if len(args.deepblending) == 0: 53 | deep_blending_scenes = [] 54 | if len(args.zipnerf) == 0: 55 | zipnerf_scenes = [] 56 | 57 | all_scenes = [] 58 | all_scenes.extend(mipnerf360_outdoor_scenes) 59 | all_scenes.extend(mipnerf360_indoor_scenes) 60 | all_scenes.extend(tanks_and_temples_scenes) 61 | all_scenes.extend(deep_blending_scenes) 62 | all_scenes.extend(zipnerf_scenes) 63 | 64 | if not args.skip_training: 65 | common_args = f" --quiet --eval --test_iterations -1 --port {args.port} {args.additional_args}" 66 | for scene in mipnerf360_outdoor_scenes: 67 | source = args.mipnerf360 + "/" + scene 68 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 69 | for scene in mipnerf360_indoor_scenes: 70 | source = args.mipnerf360 + "/" + scene 71 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 72 | for scene in tanks_and_temples_scenes: 73 | source = args.tanksandtemples + "/" + scene 74 | os.system( 75 | "python train.py -s " 76 | + source 77 | + " -m " 78 | + args.output_path 79 | + "/" 80 | + scene 81 | + common_args 82 | ) 83 | for scene in deep_blending_scenes: 84 | source = args.deepblending + "/" + scene 85 | os.system( 86 | "python train.py -s " 87 | + source 88 | + " -m " 89 | + args.output_path 90 | + "/" 91 | + scene 92 | + common_args 93 | ) 94 | for scene in zipnerf_scenes: 95 | glo_args = "" 96 | if scene in GLO_SCENES: 97 | glo_args = " --enable_GLO --glo_lr 0 --checkpoint_iterations 7000 30000 " 98 | source = args.zipnerf + "/" + scene 99 | os.system( 
100 | "python train.py -s " 101 | + source 102 | + " -m " 103 | + args.output_path 104 | + "/" 105 | + scene 106 | + common_args 107 | + glo_args 108 | + " -r 1 --images images_2 " 109 | + " --position_lr_init 4e-5 --position_lr_final 4e-7 " 110 | + " --percent_dense 0.0005 --tmin 0" 111 | ) 112 | 113 | if not args.skip_rendering: 114 | all_sources = [] 115 | for scene in mipnerf360_outdoor_scenes: 116 | all_sources.append(args.mipnerf360 + "/" + scene) 117 | for scene in mipnerf360_indoor_scenes: 118 | all_sources.append(args.mipnerf360 + "/" + scene) 119 | for scene in tanks_and_temples_scenes: 120 | all_sources.append(args.tanksandtemples + "/" + scene) 121 | for scene in deep_blending_scenes: 122 | all_sources.append(args.deepblending + "/" + scene) 123 | 124 | common_args = " --quiet --eval --skip_train" 125 | for scene, source in zip(all_scenes, all_sources): 126 | # glo_args = "" 127 | # if scene in GLO_SCENES: 128 | # glo_args = f" --checkpoint {os.path.join(args.output_path,scene,'chkpnt7000.pth')} " 129 | # os.system( 130 | # "python render.py --iteration 7000 -s " 131 | # + source 132 | # + " -m " 133 | # + args.output_path 134 | # + "/" 135 | # + scene 136 | # + common_args 137 | # + glo_args 138 | # ) 139 | glo_args = "" 140 | if scene in GLO_SCENES: 141 | glo_args = f" --checkpoint {os.path.join(args.output_path,scene,'chkpnt30000.pth')} " 142 | os.system( 143 | "python render.py --iteration 30000 -s " 144 | + source 145 | + " -m " 146 | + args.output_path 147 | + "/" 148 | + scene 149 | + common_args 150 | + glo_args 151 | ) 152 | 153 | if not args.skip_metrics: 154 | scenes_string = "" 155 | for scene in all_scenes: 156 | scenes_string += '"' + args.output_path + "/" + scene + '" ' 157 | 158 | os.system("python metrics.py -m " + scenes_string) 159 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | # from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = pc.get_xyz 54 | means2D = screenspace_points 55 | opacity = pc.get_opacity 56 | 57 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 58 | # scaling / rotation by the rasterizer. 59 | scales = None 60 | rotations = None 61 | cov3D_precomp = None 62 | if pipe.compute_cov3D_python: 63 | cov3D_precomp = pc.get_covariance(scaling_modifier) 64 | else: 65 | scales = pc.get_scaling 66 | rotations = pc.get_rotation 67 | 68 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 69 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 70 | shs = None 71 | colors_precomp = None 72 | if override_color is None: 73 | if pipe.convert_SHs_python: 74 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 75 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 76 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 77 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 78 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 79 | else: 80 | shs = pc.get_features 81 | else: 82 | colors_precomp = override_color 83 | 84 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 85 | rendered_image, radii = rasterizer( 86 | means3D = means3D, 87 | means2D = means2D, 88 | shs = shs, 89 | colors_precomp = colors_precomp, 90 | opacities = opacity, 91 | scales = scales, 92 | rotations = rotations, 93 | cov3D_precomp = cov3D_precomp) 94 | 95 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 96 | # They will be excluded from value updates used in the splitting criteria. 97 | return {"render": rendered_image, 98 | "viewspace_points": screenspace_points, 99 | "visibility_filter" : radii > 0, 100 | "radii": radii} 101 | -------------------------------------------------------------------------------- /gaussian_renderer/ever.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import math 18 | 19 | from scene.gaussian_model import GaussianModel 20 | from ever.splinetracers.fast_ellipsoid_splinetracer import trace_rays 21 | # from splinetracer.splinetracers.ellipsoid_splinetracer import trace_rays 22 | MAX_ITERS = 400 23 | from ever.eval_sh import eval_sh as eval_sh2 24 | from utils.sh_utils import eval_sh, RGB2SH, SH2RGB 25 | from kornia import create_meshgrid 26 | import numpy as np 27 | from icecream import ic 28 | from scene.dataset_readers import ProjectionType 29 | from utils import camera_utils_zipnerf 30 | 31 | def get_ray_directions(H, W, focal, center=None, random=True): 32 | """ 33 | Get ray directions for all pixels in camera coordinate. 34 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 35 | ray-tracing-generating-camera-rays/standard-coordinate-systems 36 | Inputs: 37 | H, W, focal: image height, width and focal length 38 | Outputs: 39 | directions: (H, W, 3), the direction of the rays in camera coordinate 40 | """ 41 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]# + 0.5 42 | if random: 43 | grid = grid + torch.rand_like(grid) 44 | else: 45 | grid = grid + 0.5 46 | 47 | i, j = grid.unbind(-1) 48 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 49 | # see https://github.com/bmild/nerf/issues/24 50 | cent = center if center is not None else [W / 2, H / 2] 51 | directions = torch.stack( 52 | [(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1 53 | ) # (H, W, 3) 54 | 55 | return directions 56 | 57 | def get_rays(directions, c2w): 58 | """ 59 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 
60 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 61 | ray-tracing-generating-camera-rays/standard-coordinate-systems 62 | Inputs: 63 | directions: (H, W, 3) precomputed ray directions in camera coordinate 64 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 65 | Outputs: 66 | rays_o: (H*W, 3), the origin of the rays in world coordinate 67 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 68 | """ 69 | # Rotate ray directions from camera coordinate to the world coordinate 70 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 71 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 72 | # The origin of all rays is the camera origin in world coordinate 73 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 74 | 75 | rays_d = rays_d.view(-1, 3) 76 | rays_o = rays_o.view(-1, 3) 77 | 78 | return rays_o, rays_d 79 | 80 | def camera2rays_full(view, **kwargs): 81 | w = view.image_width # // 4 82 | h = view.image_height # // 4 83 | # y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') 84 | device = torch.device('cuda') 85 | 86 | x, y = torch.meshgrid(torch.arange(w, device=device), torch.arange(h, device=device), indexing='xy') 87 | 88 | fx = 0.5 * w / np.tan(0.5 * view.FoVx) # original focal length 89 | fy = 0.5 * h / np.tan(0.5 * view.FoVy) # original focal length 90 | pixtocams = torch.eye(3, device=device) 91 | pixtocams[0, 0] = 1/fx 92 | pixtocams[1, 1] = 1/fy 93 | pixtocams[0, 2] = -w/2/fx 94 | pixtocams[1, 2] = -h/2/fy 95 | 96 | T = torch.linalg.inv(view.world_view_transform.T).to(device) 97 | origins, _, directions, _, _ = camera_utils_zipnerf.pixels_to_rays( 98 | x.reshape(-1), y.reshape(-1), 99 | pixtocams.reshape(1, 3, 3), 100 | T[:3].reshape(1, 3, 4), 101 | camtype=view.model, 102 | distortion_params=view.distortion_params, 103 | xnp=torch 104 | ) 105 | origins = origins.float().cuda().contiguous() 106 | directions = directions.float().cuda().contiguous() 107 | # ic(camera2rays(view)[1]) 108 | # ic(directions) 109 | return origins, directions 110 | 111 | def camera2rays(view, **kwargs): 112 | w = view.image_width 113 | h = view.image_height 114 | 115 | fx = 0.5 * w / math.tan(0.5 * view.FoVx) # original focal length 116 | fy = 0.5 * h / math.tan(0.5 * view.FoVy) # original focal length 117 | 118 | directions = get_ray_directions(h, w, [fx, fy], **kwargs).cuda() # (h, w, 3) 119 | directions = (directions / torch.norm(directions, dim=-1, keepdim=True)) 120 | 121 | T = torch.linalg.inv(view.world_view_transform.T.cuda()) 122 | rays_o, rays_d = get_rays( 123 | directions, 124 | T, 125 | ) # both (h*w, 3) 126 | rays_o = (rays_o).contiguous() 127 | return rays_o, rays_d 128 | 129 | def splinerender( 130 | view, 131 | pc: GaussianModel, 132 | pipe, 133 | bg_color: torch.Tensor, 134 | scaling_modifier=1.0, 135 | override_color=None, 136 | random=False, 137 | tmin=None, 138 | tmax=1e7, 139 | ): 140 | device = pc.get_xyz.device 141 | if view.model == ProjectionType.PERSPECTIVE: 142 | rays_o, rays_d = camera2rays(view, random=random) 143 | else: 144 | rays_o, rays_d = camera2rays_full(view, random=False) 145 | 146 | means2D = torch.zeros_like(pc.get_xyz[..., :2]) 147 | means2D.requires_grad = True 148 | 149 | w = view.image_width # // 4 150 | h = view.image_height # // 4 151 | 152 | fx = 0.5 * w / np.tan(0.5 * view.FoVx) # original focal length 153 | fy = 0.5 * h / np.tan(0.5 * view.FoVy) # original focal length 154 | K = torch.tensor([ 155 | [fx, 0, w/2, 0], 156 | [0, fy, h/2, 0], 157 | 
[0, 0, 1, 0], 158 | ], device="cuda").float() 159 | invK = torch.tensor([ 160 | [1/fx, 0, -w/2/fx], 161 | [0, 1/fy, -h/2/fy], 162 | [0, 0, 1], 163 | [0, 0, 0], 164 | ], device="cuda").float() 165 | device = "cuda" 166 | 167 | wct = view.world_view_transform.cuda().float() 168 | full_wct = torch.eye(4, device="cuda") 169 | full_wct[:, :3] = wct @ K.T 170 | 171 | shs = pc.get_features 172 | # shs[:, (pc.active_sh_degree+1)**2:] = 0 173 | # ic(shs.shape, shs[:, :(pc.active_sh_degree+1)**2].shape) 174 | if pipe.enable_GLO: 175 | if view.glo_vector is not None: 176 | glo_vector = view.glo_vector 177 | else: 178 | glo_vector = torch.zeros((1, 64), device='cuda') 179 | shs = pc.glo_network( 180 | glo_vector.reshape(1, -1), shs.reshape(shs.shape[0], -1) 181 | ).reshape(shs.shape) 182 | 183 | cam_pos = view.camera_center.to(device) 184 | T = torch.linalg.inv(wct.T) 185 | v = T[:3, 2] 186 | net_color = eval_sh2(pc.get_xyz, shs, cam_pos, pc.active_sh_degree) 187 | # ic(net_color, SH2RGB(features)) 188 | net_color = torch.nn.functional.softplus(net_color, beta=10) 189 | features = RGB2SH(net_color).reshape(-1, 1, 3) 190 | 191 | per_point_2d_filter_scale = torch.zeros(pc._xyz.shape[0], device=pc._xyz.device) 192 | 193 | if trace_rays.uses_density: 194 | scales, density = pc.get_scale_and_density_for_rendering(per_point_2d_filter_scale, scaling_modifier) 195 | else: 196 | scales, density = pc.get_scale_and_opacity_for_rendering(per_point_2d_filter_scale, scaling_modifier) 197 | tmin = pc.tmin if tmin is None else tmin 198 | out, extras = trace_rays( 199 | pc.get_xyz, 200 | scales, 201 | pc.get_rotation, 202 | density, 203 | features, 204 | rays_o, 205 | rays_d, 206 | tmin, 207 | tmax, 208 | 100, 209 | means2D, 210 | full_wct.reshape(1, 4, 4), 211 | max_iters=MAX_ITERS, 212 | return_extras=True, 213 | ) 214 | 215 | torch.cuda.synchronize() 216 | radii = torch.ones_like(means2D[..., 0]) 217 | 218 | rendered_image = out[:, :3].T.reshape(3, view.image_height, view.image_width) 219 | num_pixels = (extras['touch_count'] // 2) 220 | 221 | # aspect_ratio = scales.max(dim=-1).values / scales.min(dim=-1).values 222 | side_length = (num_pixels).float().sqrt() #/ aspect_ratio # mul by 2 to get to rect, then sqrt 223 | radii = side_length / 2 * np.sqrt(2) * 2.5 * 5 224 | 225 | return { 226 | "render": rendered_image, 227 | "viewspace_points": means2D, 228 | "visibility_filter": num_pixels >= 4, 229 | "touch_count": extras['touch_count'], 230 | "radii": radii, # match gaussian radius 231 | "iters": extras["iters"].reshape(view.image_height, view.image_width), 232 | "opacity": out[:, 3].reshape(-1, 1), 233 | "distortion_loss": out[:, 4].reshape(-1, 1), 234 | } 235 | 236 | -------------------------------------------------------------------------------- /gaussian_renderer/fast_renderer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
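The splinerender function in gaussian_renderer/ever.py above is the main ray-traced render path: it builds rays for the camera (perspective via get_ray_directions/get_rays, distorted models via camera2rays_full), evaluates per-primitive colors from the SH features (optionally modulated by the GLO network), and hands everything to trace_rays. Before the interactive FastRenderer below, here is a minimal, hypothetical usage sketch; it mirrors the model-loading pattern used in measure_fps.py, but the script itself and the model directory passed on the command line are illustrative placeholders, not part of the repository.

# Hypothetical sketch: render one training view with splinerender (illustration only).
# Run as: python sketch_render_one.py -m <trained model dir>   (path is a placeholder)
import torch
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from scene import Scene, GaussianModel
from gaussian_renderer.ever import splinerender

parser = ArgumentParser(description="splinerender usage sketch")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
args = get_combined_args(parser)  # merges CLI args with the cfg_args saved next to the model

dataset = model.extract(args)      # cfg_args supplies source_path; pass -s if the dataset moved
pipe = pipeline.extract(args)      # assumes GLO is disabled, as it is by default
gaussians = GaussianModel(dataset.sh_degree, dataset.use_neural_network, dataset.max_opacity)
scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)  # loads the latest point_cloud.ply

background = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda")
view = scene.getTrainCameras()[0]
with torch.no_grad():
    out = splinerender(view, gaussians, pipe, background, random=False)
rendered = out["render"]   # (3, H, W) image tensor
alpha = out["opacity"]     # accumulated opacity per ray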
15 | import torch 16 | import math 17 | import numpy as np 18 | from scene import Scene 19 | import os 20 | import math 21 | from tqdm import tqdm 22 | from os import makedirs 23 | from gaussian_renderer.ever import get_ray_directions, get_rays, camera2rays_full 24 | import torchvision 25 | from utils.general_utils import safe_state 26 | from argparse import ArgumentParser 27 | from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams 28 | from gaussian_renderer import GaussianModel 29 | import time 30 | 31 | from utils.sh_utils import eval_sh, RGB2SH, SH2RGB 32 | from ever.splinetracers.fast_ellipsoid_splinetracer import sp 33 | from ever.eval_sh import eval_sh as eval_sh2 34 | from utils.graphics_utils import in_screen_from_ndc, project_points, visible_depth_from_camspace, fov2focal 35 | from scene.dataset_readers import ProjectionType 36 | 37 | 38 | MAX_ITERS = 200 39 | 40 | class FastRenderer: 41 | def __init__(self, view, pc, enable_GLO): 42 | self.device = pc.get_xyz.device 43 | self.enable_GLO = enable_GLO 44 | w = view.image_width 45 | h = view.image_height 46 | 47 | fx = 0.5 * w / math.tan(0.5 * view.FoVx) # original focal length 48 | fy = 0.5 * h / math.tan(0.5 * view.FoVy) # original focal length 49 | 50 | directions = get_ray_directions(h, w, [fx, fy], random=False).cuda() # (h, w, 3) 51 | self.directions = (directions / torch.norm(directions, dim=-1, keepdim=True)) 52 | self.otx = sp.OptixContext(torch.device("cuda:0")) 53 | self.prims = sp.Primitives(self.device) 54 | self.mean = pc.get_xyz.contiguous() 55 | self.quat = pc.get_rotation.contiguous() 56 | self.pc = pc 57 | 58 | per_point_2d_filter_scale = torch.zeros(self.pc._xyz.shape[0], device=self.device) 59 | self.per_point_2d_filter_scale = 1 60 | self.scales, self.density = pc.get_scale_and_density_for_rendering(self.per_point_2d_filter_scale, 1.0) 61 | 62 | color = self.get_color(view) 63 | half_attribs = torch.cat([self.mean, self.scales, self.quat], dim=1).half().contiguous() 64 | self.prims.add_primitives(self.mean, self.scales, self.quat, half_attribs, self.density, color) 65 | self.gas = sp.GAS(self.otx, self.device, self.prims, True, False, True) 66 | self.forward = sp.Forward(self.otx, self.device, self.prims, False) 67 | 68 | def set_camera(self, view): 69 | if view.model != ProjectionType.PERSPECTIVE: 70 | rays_o, rays_d = camera2rays_full(view, random=False) 71 | self.directions = rays_d 72 | else: 73 | 74 | w = view.image_width 75 | h = view.image_height 76 | 77 | fx = 0.5 * w / math.tan(0.5 * view.FoVx) # original focal length 78 | fy = 0.5 * h / math.tan(0.5 * view.FoVy) # original focal length 79 | 80 | directions = get_ray_directions(h, w, [fx, fy], random=False).cuda() # (h, w, 3) 81 | self.directions = (directions / torch.norm(directions, dim=-1, keepdim=True)) 82 | 83 | 84 | def get_color(self, view): 85 | shs = self.pc.get_features 86 | # shs[:, (self.pc.active_sh_degree+1)**2:] = 0 87 | if self.enable_GLO: 88 | if view.glo_vector is not None: 89 | glo_vector = view.glo_vector 90 | else: 91 | glo_vector = torch.zeros((1, 64), device='cuda') 92 | shs = self.pc.glo_network( 93 | glo_vector.reshape(1, -1), shs.reshape(shs.shape[0], -1) 94 | ).reshape(shs.shape) 95 | 96 | cam_pos = view.camera_center.to(self.device) 97 | # wct = view.world_view_transform.cuda().float() 98 | # T = torch.linalg.inv(wct.T) 99 | # v = T[:3, 2] 100 | net_color = eval_sh2(self.pc.get_xyz, shs, cam_pos, self.pc.active_sh_degree) 101 | # ic(net_color, SH2RGB(features)) 102 | net_color = 
torch.nn.functional.softplus(net_color, beta=10) 103 | features = RGB2SH(net_color).reshape(-1, 1, 3) 104 | return features.contiguous() 105 | 106 | def get_rays(self, view): 107 | T = torch.linalg.inv(view.world_view_transform.T.cuda()) 108 | rays_o, rays_d = get_rays( 109 | self.directions, 110 | T, 111 | ) # both (h*w, 3) 112 | rays_o = (rays_o).contiguous() 113 | return rays_o, rays_d 114 | 115 | def trace_rays(self, rayo, rayd, view, tmin, tmax): 116 | color = self.get_color(view) 117 | # prims = sp.Primitives(self.device) 118 | # half_attribs = torch.cat([self.mean, self.scales, self.quat], dim=1).half().contiguous() 119 | # prims.add_primitives(self.mean, self.scales, self.quat, half_attribs, self.density, color) 120 | # self.forward = sp.Forward(self.otx, self.device, prims, False) 121 | 122 | self.prims.set_features(color) 123 | self.forward.update_model(self.prims) 124 | 125 | out = self.forward.trace_rays(self.gas, rayo, rayd, tmin, tmax, MAX_ITERS, 1000) 126 | return out 127 | 128 | def render(self, 129 | view, 130 | pc, 131 | bg_color: torch.Tensor, 132 | tmin=None, 133 | scaling_modifier=1.0): 134 | rays_o, rays_d = self.get_rays(view) 135 | out = self.trace_rays(rays_o, rays_d, view, self.pc.tmin if tmin is None else tmin, 1e7) 136 | iters = out['saved'].iters 137 | rendered_image = out['color'][:, :3].T.reshape(3, view.image_height, view.image_width) 138 | return rendered_image 139 | 140 | 141 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /host_render_server.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
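network_gui.py above speaks a small length-prefixed protocol with the remote viewer: each request is a 4-byte little-endian length followed by a UTF-8 JSON payload whose fields are unpacked in receive(), and the reply is the raw H*W*3 uint8 image followed by a length-prefixed verification string (the host render server that follows sends its dataset source path there). Below is a minimal, hypothetical client sketch under those assumptions; the resolution, FoV values, and camera matrices are placeholders, not values the repository ships.

# Hypothetical client for the network_gui protocol (illustration only; all values are placeholders).
import json
import socket

identity = [1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0]
msg = {
    "resolution_x": 640, "resolution_y": 360,
    "train": False,
    "fov_y": 0.8, "fov_x": 1.2,
    "z_near": 0.01, "z_far": 100.0,
    "shs_python": False, "rot_scale_python": False,
    "keep_alive": True, "scaling_modifier": 1.0,
    "view_matrix": identity,                # placeholder 4x4 pose, flattened
    "view_projection_matrix": identity,     # placeholder 4x4 view-projection, flattened
}

with socket.create_connection(("127.0.0.1", 6009)) as s:
    payload = json.dumps(msg).encode("utf-8")
    s.sendall(len(payload).to_bytes(4, "little"))  # 4-byte little-endian length prefix
    s.sendall(payload)

    expected = msg["resolution_x"] * msg["resolution_y"] * 3  # uint8 RGB reply
    buf = b""
    while len(buf) < expected:
        buf += s.recv(expected - len(buf))
    vlen = int.from_bytes(s.recv(4), "little")
    verify = s.recv(vlen).decode("ascii")  # server echoes its source path as a sanity check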
15 | 16 | import os 17 | import torch 18 | from random import randint 19 | from utils.loss_utils import l1_loss, ssim 20 | import sys 21 | from scene import Scene, GaussianModel 22 | from utils.general_utils import safe_state 23 | import uuid 24 | from tqdm import tqdm 25 | from utils.image_utils import psnr 26 | from argparse import ArgumentParser, Namespace 27 | from arguments import ModelParams, PipelineParams, OptimizationParams 28 | from icecream import ic 29 | import random 30 | import math 31 | import cv2 32 | import json 33 | import traceback 34 | from utils.system_utils import searchForMaxIteration 35 | import time 36 | from gaussian_renderer.fast_renderer import FastRenderer 37 | from gaussian_renderer import render, network_gui 38 | 39 | # renderFunc = splinerender 40 | # renderFunc = render 41 | from scene.dataset_readers import ProjectionType 42 | 43 | def convert_to_float(frac_str): 44 | try: 45 | return float(frac_str) 46 | except ValueError: 47 | num, denom = frac_str.split('/') 48 | try: 49 | leading, num = num.split(' ') 50 | whole = float(leading) 51 | except ValueError: 52 | whole = 0 53 | frac = float(num) / float(denom) 54 | return whole - frac if whole < 0 else whole + frac 55 | 56 | PREVIEW_RES_FACTOR = 1 57 | 58 | try: 59 | from torch.utils.tensorboard import SummaryWriter 60 | TENSORBOARD_FOUND = True 61 | except ImportError: 62 | TENSORBOARD_FOUND = False 63 | 64 | def set_glo_vector(viewpoint_cam, gaussians, camera_inds): 65 | camera_ind = camera_inds[viewpoint_cam.uid] 66 | viewpoint_cam.glo_vector = torch.cat( 67 | [gaussians.glo[camera_ind], torch.tensor([ 68 | math.log( 69 | viewpoint_cam.iso * viewpoint_cam.exposure / 1000), 70 | ], device=gaussians.glo.device) 71 | ] 72 | ) 73 | 74 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 75 | first_iter = 0 76 | gaussians = GaussianModel(dataset.sh_degree, dataset.use_neural_network, dataset.max_opacity) 77 | if checkpoint: 78 | (model_params, first_iter) = torch.load(checkpoint) 79 | gaussians.restore(model_params, opt) 80 | else: 81 | load_iteration = -1 82 | if load_iteration == -1: 83 | loaded_iter = searchForMaxIteration(os.path.join(dataset.model_path, "point_cloud")) 84 | else: 85 | loaded_iter = load_iteration 86 | print("Loading trained model at iteration {}".format(loaded_iter)) 87 | gaussians.load_ply(os.path.join(dataset.model_path, 88 | "point_cloud", 89 | "iteration_" + str(loaded_iter), 90 | "point_cloud.ply")) 91 | # gaussians.load_ply("output/a5911cf7-0/point_cloud/iteration_30000/point_cloud.ply") 92 | # gaussians.load_ply("/home/amai/Downloads/point_cloud.ply") 93 | # gaussians.load_ply("/home/amai/3DGS/output/20e2f33c-e/point_cloud/iteration_30000/point_cloud.ply", legacy_compat=True) 94 | # gaussians.load_ply("/home/amai/gaussian-splatting/output/242678df-0/point_cloud/iteration_30000/point_cloud.ply") 95 | # gaussians.training_setup(opt) 96 | 97 | if pipe.enable_GLO: 98 | metadata_path = os.path.join(dataset.source_path, "metadata.json") 99 | with open(metadata_path, "r") as f: 100 | metadata = json.load(f) 101 | 102 | first_metadata = metadata[list(metadata.keys())[0]] 103 | 104 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 105 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 106 | 107 | iter_start = torch.cuda.Event(enable_timing = True) 108 | iter_end = torch.cuda.Event(enable_timing = True) 109 | 110 | 111 | gaussians.training_setup(opt) 112 | torch.cuda.empty_cache() 113 
| st = time.time() 114 | 115 | 116 | renderer = None 117 | 118 | while True: 119 | if network_gui.conn == None: 120 | network_gui.try_connect() 121 | while network_gui.conn != None: 122 | try: 123 | net_image_bytes = None 124 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 125 | custom_cam, do_training, _, _, keep_alive, scaling_modifer = network_gui.receive() 126 | if custom_cam != None: 127 | if pipe.enable_GLO: 128 | custom_cam.glo_vector = torch.cat( 129 | [gaussians.glo[0], torch.tensor([ 130 | math.log( 131 | float(first_metadata['iso']) * convert_to_float(first_metadata['exposure']) / 1000), 132 | ], device=gaussians.glo.device) 133 | ] 134 | ) 135 | # custom_cam.model = viewpoint_cam.model 136 | # custom_cam.distortion_params = viewpoint_cam.distortion_params 137 | # custom_cam.model=ProjectionType.FISHEYE 138 | custom_cam.model=ProjectionType.PERSPECTIVE 139 | # custom_cam.glo_vector = viewpoint_cam.glo_vector 140 | image_width = custom_cam.image_width 141 | image_height = custom_cam.image_height 142 | custom_cam.image_width = image_width // PREVIEW_RES_FACTOR 143 | custom_cam.image_height = image_height // PREVIEW_RES_FACTOR 144 | 145 | if renderer is None: 146 | renderer = FastRenderer(custom_cam, gaussians, pipe.enable_GLO) 147 | 148 | renderer.set_camera(custom_cam) 149 | 150 | st = time.time() 151 | net_image = renderer.render(custom_cam, pipe, background) 152 | 153 | # net_image = renderFunc(custom_cam, gaussians, pipe, background, scaling_modifer, random=False, tmin=0)["render"] 154 | print(1/(time.time()-st)) 155 | net_image = (torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy() 156 | net_image = cv2.resize(net_image, (image_width, image_height)) 157 | # ic(net_image.shape, net_image.dtype) 158 | net_image_bytes = memoryview(net_image) 159 | network_gui.send(net_image_bytes, dataset.source_path) 160 | torch.cuda.empty_cache() 161 | except Exception as e: 162 | print(traceback.format_exc()) 163 | network_gui.conn = None 164 | 165 | 166 | if __name__ == "__main__": 167 | # Set up command line argument parser 168 | parser = ArgumentParser(description="Training script parameters") 169 | lp = ModelParams(parser) 170 | op = OptimizationParams(parser) 171 | pp = PipelineParams(parser) 172 | parser.add_argument('--ip', type=str, default="127.0.0.1") 173 | parser.add_argument('--port', type=int, default=6009) 174 | parser.add_argument('--debug_from', type=int, default=-1) 175 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 176 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) 177 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 178 | parser.add_argument("--quiet", action="store_true") 179 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 180 | parser.add_argument("--start_checkpoint", type=str, default = None) 181 | args = parser.parse_args(sys.argv[1:]) 182 | args.save_iterations.append(args.iterations) 183 | # args.checkpoint_iterations.append(args.iterations) 184 | 185 | # Initialize system state (RNG) 186 | safe_state(args.quiet) 187 | 188 | # Start GUI server, configure and run training 189 | network_gui.init(args.ip, args.port) 190 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 191 | # training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, 
args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 192 | training(lp.extract(args), op.extract(args), pp.extract(args), args.save_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 193 | 194 | # All done 195 | 196 | -------------------------------------------------------------------------------- /install.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cp ever/new_files/*.py . 4 | cp -r ever/new_files/notebooks . 5 | cp ever/new_files/scene/* scene/ 6 | cp ever/new_files/gaussian_renderer/* gaussian_renderer/ 7 | cp ever/new_files/utils/* utils/ 8 | 9 | git apply ../ever/new_files/sibr_patch.patch 10 | 11 | # Build splinetracer 12 | mkdir ever/build 13 | cd ever/build 14 | # CXX=/usr/bin/g++-11 CC=/usr/bin/gcc-11 cmake -DOptiX_INSTALL_DIR=$OptiX_INSTALL_DIR -D_GLIBCXX_USE_CXX11_ABI=1 .. 15 | # CXX=$CXX CC=$CC cmake -DOptiX_INSTALL_DIR=$OptiX_INSTALL_DIR .. 16 | CXX=$CXX CC=$CC cmake -DOptiX_INSTALL_DIR=$OptiX_INSTALL_DIR -DCMAKE_CUDA_ARCHITECTURES="50;60;61;70;75;80;86" .. 17 | make -j8 18 | cd ../.. 19 | 20 | pip install -e submodules/simple-knn 21 | 22 | # SIBR Viewer 23 | cd SIBR_viewers 24 | cmake -Bbuild . -DCMAKE_BUILD_TYPE=Release 25 | cmake --build build -j24 --target install 26 | cd ../.. 27 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | 
self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /measure_fps.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
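The lpipsPyTorch package above exposes a single lpips() helper built on the LPIPS module; metrics.py later in this repository calls it with VGG features and inputs rescaled from [0, 1] to [-1, 1]. A small, hypothetical usage sketch of that pattern follows; the image paths are placeholders.

# Hypothetical snippet: LPIPS between a rendering and its ground truth (paths are placeholders).
import torch
import torchvision.transforms.functional as tf
from PIL import Image
from lpipsPyTorch import lpips

render = tf.to_tensor(Image.open("render.png"))[:3].unsqueeze(0).cuda()  # (1, 3, H, W) in [0, 1]
gt = tf.to_tensor(Image.open("gt.png"))[:3].unsqueeze(0).cuda()

with torch.no_grad():
    # The pretrained backbones expect inputs roughly in [-1, 1], hence the rescaling.
    score = lpips(2 * render - 1, 2 * gt - 1, net_type="vgg")
print(float(score))  # lower is better; 'alex' is the faster default backbone, 'vgg' matches metrics.py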
15 | import torch 16 | import math 17 | import numpy as np 18 | from scene import Scene 19 | import os 20 | import math 21 | from tqdm import tqdm 22 | from os import makedirs 23 | from gaussian_renderer import get_ray_directions, get_rays 24 | import torchvision 25 | from utils.general_utils import safe_state 26 | from argparse import ArgumentParser 27 | from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams 28 | from gaussian_renderer import GaussianModel 29 | from gaussian_renderer.fast_renderer import FastRenderer 30 | import time 31 | 32 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 33 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 34 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 35 | 36 | makedirs(render_path, exist_ok=True) 37 | makedirs(gts_path, exist_ok=True) 38 | 39 | renderer = FastRenderer(views[0], gaussians, pipeline.enable_GLO) 40 | 41 | # warmup 42 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 43 | camera_inds = {view.uid: i for i, view in enumerate(views)} 44 | camera_ind = camera_inds[view.uid] 45 | #view.glo_vector = gaussians.glo[camera_ind] 46 | if gaussians.glo is not None: 47 | view.glo_vector = torch.cat( 48 | [gaussians.glo[camera_ind], torch.tensor([ 49 | math.log( 50 | view.iso * view.exposure / 1000), 51 | ], device=gaussians.glo.device) 52 | ] 53 | ) 54 | rendering = renderer.render(view, pipeline, background) 55 | 56 | fps = [] 57 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 58 | # if idx != 424: 59 | # continue 60 | camera_inds = {view.uid: i for i, view in enumerate(views)} 61 | camera_ind = camera_inds[view.uid] 62 | #view.glo_vector = gaussians.glo[camera_ind] 63 | if gaussians.glo is not None: 64 | view.glo_vector = torch.cat( 65 | [gaussians.glo[camera_ind], torch.tensor([ 66 | math.log( 67 | view.iso * view.exposure / 1000), 68 | ], device=gaussians.glo.device) 69 | ] 70 | ) 71 | st = time.time() 72 | rendering = renderer.render(view, pipeline, background) 73 | fps.append(time.time() - st) 74 | # print(time.time() - st, view.image_width, view.image_height) 75 | gt = view.original_image[0:3, :, :] 76 | # torchvision.utils.save_image(frendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 77 | # torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 78 | print(1/np.mean(fps)) 79 | 80 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, checkpoint, opt): 81 | with torch.no_grad(): 82 | gaussians = GaussianModel(dataset.sh_degree, dataset.use_neural_network, dataset.max_opacity) 83 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 84 | if checkpoint: 85 | (model_params, first_iter) = torch.load(checkpoint) 86 | gaussians.restore(model_params, opt) 87 | # if dataset.enable_mip_splatting: 88 | # gaussians.enable_mip_splatting( 89 | # dataset.low_pass_2d_kernel_size, dataset.low_pass_3d_kernel_size) 90 | # gaussians.update_low_pass_filter(scene.getTrainCameras()) 91 | 92 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 93 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 94 | 95 | if not skip_train: 96 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 97 | 98 | if not skip_test: 99 | render_set(dataset.model_path, 
"test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 100 | 101 | if __name__ == "__main__": 102 | # Set up command line argument parser 103 | parser = ArgumentParser(description="Testing script parameters") 104 | model = ModelParams(parser, sentinel=True) 105 | op = OptimizationParams(parser) 106 | pipeline = PipelineParams(parser) 107 | parser.add_argument("--iteration", default=-1, type=int) 108 | parser.add_argument("--skip_train", action="store_true") 109 | parser.add_argument("--skip_test", action="store_true") 110 | parser.add_argument("--quiet", action="store_true") 111 | parser.add_argument("--checkpoint", default=None) 112 | args = get_combined_args(parser) 113 | print("Rendering " + args.model_path) 114 | 115 | # Initialize system state (RNG) 116 | safe_state(args.quiet) 117 | args.checkpoint = None 118 | 119 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.checkpoint, op.extract(args)) 120 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in sorted(os.listdir(renders_dir)): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | 71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 72 | ssims.append(ssim(renders[idx], gts[idx])) 73 | psnrs.append(psnr(renders[idx], gts[idx])) 74 | lpipss.append(lpips(2*renders[idx]-1, 2*gts[idx]-1, net_type='vgg')) 75 | # print(image_names[idx], 
psnrs[-1], ssims[-1], lpipss[-1]) 76 | 77 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 78 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 79 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 80 | print("") 81 | 82 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 83 | "PSNR": torch.tensor(psnrs).mean().item(), 84 | "LPIPS": torch.tensor(lpipss).mean().item()}) 85 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 86 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 87 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 88 | 89 | with open(scene_dir + "/results.json", 'w') as fp: 90 | json.dump(full_dict[scene_dir], fp, indent=True) 91 | with open(scene_dir + "/per_view.json", 'w') as fp: 92 | json.dump(per_view_dict[scene_dir], fp, indent=True) 93 | except Exception as e: 94 | print(e) 95 | print("Unable to compute metrics for model", scene_dir) 96 | 97 | if __name__ == "__main__": 98 | device = torch.device("cuda:0") 99 | torch.cuda.set_device(device) 100 | 101 | # Set up command line argument parser 102 | parser = ArgumentParser(description="Training script parameters") 103 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 104 | args = parser.parse_args() 105 | evaluate(args.model_paths) 106 | -------------------------------------------------------------------------------- /model_output/cfg_args: -------------------------------------------------------------------------------- 1 | Namespace(sh_degree=3, source_path='/data/dataset', model_path='/data/output', images='images_4', resolution=-1, white_background=False, data_device='cpu', render_spline=False, use_neural_network=False, eval=True, num_additional_pts=10000, additional_size_multi=1.0, num_spline_frames=480, glo_latent_dim=64, max_opacity=0.99, tmin=0.2) -------------------------------------------------------------------------------- /model_output/input.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/model_output/input.ply -------------------------------------------------------------------------------- /notebooks/demo_images/3dgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/3dgs.png -------------------------------------------------------------------------------- /notebooks/demo_images/ea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/ea.png -------------------------------------------------------------------------------- /notebooks/demo_images/es.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/es.png -------------------------------------------------------------------------------- /notebooks/demo_images/ga.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/ga.png -------------------------------------------------------------------------------- /notebooks/demo_images/gs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/gs.png -------------------------------------------------------------------------------- /notebooks/demo_images/os.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/os.png -------------------------------------------------------------------------------- /notebooks/demo_images/side_view.3dgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/side_view.3dgs.png -------------------------------------------------------------------------------- /notebooks/demo_images/tris.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/tris.png -------------------------------------------------------------------------------- /notebooks/demo_images/ts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/half-potato/ever_training/532547481921e7fa757b4235ef706fa1fb32adb0/notebooks/demo_images/ts.png -------------------------------------------------------------------------------- /notebooks/render_sibr_paths.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/amai/gaussian-splatting-merge\n", 13 | "2059\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import os\n", 19 | "os.environ[\"CC\"] = \"/usr/bin/gcc-11\"\n", 20 | "os.environ[\"CXX\"] = \"/usr/bin/g++-11\"\n", 21 | "import struct\n", 22 | "import numpy as np\n", 23 | "from tqdm import tqdm\n", 24 | "\n", 25 | "from pathlib import Path\n", 26 | "import imageio\n", 27 | "import os\n", 28 | "import sys\n", 29 | "sys.path.append(str(Path(os.path.abspath('')).parent))\n", 30 | "print(str(Path(os.path.abspath('')).parent))\n", 31 | "import torch\n", 32 | "from gaussian_renderer import GaussianModel, splinerender, render\n", 33 | "from scene import Scene\n", 34 | "from scene.cameras import Camera, MiniCam\n", 35 | "from torch import nn\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "import imageio\n", 38 | "from pyquaternion import Quaternion\n", 39 | "from scene.dataset_readers import ProjectionType\n", 40 | "\n", 41 | "dataset = \"nyc\"\n", 42 | "# path = \"berlin5\"\n", 43 | "# dataset = \"train\"\n", 44 | "path = \"train5\"\n", 45 | "eval_path = \"/home/amai/gaussian-splatting-merge/eval.znf\"\n", 46 | "# eval_path = \"/home/amai/gaussian-splatting-merge/eval\"\n", 47 | "# output_path = Path(\"/data/popping_videos/ours0\") / dataset / \"images\"\n", 48 | "output_path = Path(f\"~/Videos/popping_paths/{path}/ours\")\n", 49 | "\n", 50 | "output_path = 
Path(f\"~/Videos/nyc1/ours/\")\n", 51 | "output_path.mkdir(parents=True, exist_ok=True)\n", 52 | "# f = open(f\"/data/video_paths/{dataset}/r1/path.path\", \"rb\")\n", 53 | "# f = open(f\"/home/amai/Videos/popping_paths/{path}.path\", \"rb\")\n", 54 | "f = open(f\"/home/amai/Videos/nyc1.path\", \"rb\")\n", 55 | "# f = open(f\"smooth.path\", \"rb\")\n", 56 | "data = f.read()\n", 57 | "N = int.from_bytes(data[:4])\n", 58 | "camera_size = 11\n", 59 | "\n", 60 | "cameras = np.array(struct.unpack(f'>{N*camera_size}f', data[4:])).reshape(N, -1)\n", 61 | "full_data = struct.unpack(f'>i{N*camera_size}f', data)\n", 62 | "N = full_data[0]\n", 63 | "cameras = np.array(full_data[1:]).reshape(N, -1)\n", 64 | "print(N)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 9, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "/home/amai/gaussian-splatting-merge/eval.znf/nyc\n", 77 | "Looking for config file in /home/amai/gaussian-splatting-merge/eval.znf/nyc/cfg_args\n", 78 | "Config file found: /home/amai/gaussian-splatting-merge/eval.znf/nyc/cfg_args\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "import shlex\n", 84 | "from argparse import ArgumentParser, Namespace\n", 85 | "from arguments import ModelParams, PipelineParams, OptimizationParams\n", 86 | "\n", 87 | "def get_combined_args(args_cmdline):\n", 88 | " cfgfile_string = \"Namespace()\"\n", 89 | "\n", 90 | " try:\n", 91 | " cfgfilepath = os.path.join(args_cmdline.model_path, \"cfg_args\")\n", 92 | " print(\"Looking for config file in\", cfgfilepath)\n", 93 | " with open(cfgfilepath) as cfg_file:\n", 94 | " print(\"Config file found: {}\".format(cfgfilepath))\n", 95 | " cfgfile_string = cfg_file.read()\n", 96 | " except TypeError:\n", 97 | " print(\"Config file not found at\")\n", 98 | " pass\n", 99 | " args_cfgfile = eval(cfgfile_string)\n", 100 | "\n", 101 | " merged_dict = vars(args_cfgfile).copy()\n", 102 | " for k, v in vars(args_cmdline).items():\n", 103 | " if v != None:\n", 104 | " merged_dict[k] = v\n", 105 | " return Namespace(**merged_dict)\n", 106 | "\n", 107 | "\n", 108 | "parser = ArgumentParser(description=\"Testing script parameters\")\n", 109 | "model = ModelParams(parser, sentinel=True)\n", 110 | "pipeline = PipelineParams(parser)\n", 111 | "args = parser.parse_args(shlex.split(f\"-m {Path(eval_path) / dataset} --images images_2 -r 1\"))\n", 112 | "print(args.model_path)\n", 113 | "args = get_combined_args(args)\n", 114 | "model = model.extract(args)\n", 115 | "# model.source_path = str(Path(\"/data/nerf_synthetic\") / dataset)\n", 116 | "# model.source_path = str(Path(\"/data/nerf_datasets/tandt/\") / dataset)\n", 117 | "model.source_path = str(Path(\"/data/nerf_datasets/zipnerf_ud\") / dataset)\n", 118 | "\n", 119 | "model.max_opacity = 0.99\n", 120 | "\n", 121 | "pipeline = pipeline.extract(args)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 10, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "Loading trained model at iteration 30000\n", 134 | "Reading camera 990/990\n", 135 | "Loading Training Cameras\n", 136 | "Loaded Train Cameras: 990\n", 137 | "Loaded Test Cameras: 0\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "gaussians = GaussianModel(model.sh_degree, model.max_opacity)\n", 143 | "scene = Scene(model, gaussians, load_iteration=-1, shuffle=False)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | 
"execution_count": 11, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "refcam = scene.getTrainCameras()[0]" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 12, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from gaussian_renderer.fast_renderer import FastRenderer\n", 162 | "\n", 163 | "renderer = FastRenderer(refcam, gaussians, False)\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 13, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "name": "stderr", 173 | "output_type": "stream", 174 | "text": [ 175 | "100%|██████████| 2059/2059 [11:43<00:00, 2.93it/s]\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "background = torch.tensor([0, 0, 0], dtype=torch.float32, device=\"cuda\")\n", 181 | "# width = 1280\n", 182 | "# height = 720\n", 183 | "# width = 1200\n", 184 | "# height = 667\n", 185 | "width = refcam.image_width\n", 186 | "height = refcam.image_height\n", 187 | "image = torch.ones((3, height, width), dtype=float)\n", 188 | "for i in tqdm(range(N)):\n", 189 | " T = cameras[i, :3]\n", 190 | " # xyzw\n", 191 | " quat = cameras[i, 3:7]\n", 192 | " R = Quaternion(x=quat[0], y=quat[1], z=quat[2], w=quat[3]).transformation_matrix\n", 193 | " R[:3, 3] = T\n", 194 | " transf = np.linalg.inv(R).T\n", 195 | " # print(R, transf)\n", 196 | " # transf[1, :] = -transf[1, :]\n", 197 | " # transf[2, :] = -transf[2, :]\n", 198 | " transf[:, 1] = -transf[:, 1]\n", 199 | " transf[:, 2] = -transf[:, 2]\n", 200 | "\n", 201 | " # R = transf[:3, :3]\n", 202 | " # T = transf[:3, 3]\n", 203 | " fovy = cameras[i, -4]\n", 204 | " fovx = cameras[i, -3]\n", 205 | " fovy = refcam.FoVy\n", 206 | " fovx = refcam.FoVx\n", 207 | " znear = cameras[i, -2]\n", 208 | " zfar = cameras[i, -1]\n", 209 | " # view = Camera(0, R, T, aspect*fovy/180*np.pi, fovy/180*np.pi, image, image, \"fake\", 0)\n", 210 | " world_view_transform = torch.as_tensor(transf).float()\n", 211 | " full_proj_transform = torch.as_tensor(transf).float()\n", 212 | " # fovx = 1.699109673500061\n", 213 | " # fovx = 1.7087104320526123\n", 214 | " # fovx = 1.399527668952942\n", 215 | " view = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)\n", 216 | " view.model = ProjectionType.PERSPECTIVE\n", 217 | " with torch.no_grad():\n", 218 | " # rendering = splinerender(view, gaussians, pipeline, background, random=False)[\"render\"]\n", 219 | " renderer.set_camera(view)\n", 220 | " rendering = renderer.render(view, pipeline, background)\n", 221 | " # rendering = splinerender(cam, gaussians, pipeline, background)[\"render\"]\n", 222 | " byte_rendering = (rendering.permute(1, 2, 0).cpu().numpy()*255).clip(min=0, max=255).astype(np.uint8)\n", 223 | " full_output_path = output_path / f\"{i:06d}.png\"\n", 224 | " imageio.imwrite(str(full_output_path), byte_rendering)\n", 225 | " # plt.imshow(byte_rendering)\n", 226 | " # plt.show()" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [] 256 | } 257 | ], 258 | "metadata": { 259 | 
"kernelspec": { 260 | "display_name": "Python 3", 261 | "language": "python", 262 | "name": "python3" 263 | }, 264 | "language_info": { 265 | "codemirror_mode": { 266 | "name": "ipython", 267 | "version": 3 268 | }, 269 | "file_extension": ".py", 270 | "mimetype": "text/x-python", 271 | "name": "python", 272 | "nbconvert_exporter": "python", 273 | "pygments_lexer": "ipython3", 274 | "version": "3.11.6" 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 2 279 | } 280 | -------------------------------------------------------------------------------- /notebooks/render_size_animation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/amai/gaussian-splatting-merge\n", 13 | "4018\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import os\n", 19 | "VERSION = 9\n", 20 | "os.environ[\"CC\"] = f\"/usr/bin/gcc-{VERSION}\"\n", 21 | "os.environ[\"CXX\"] = f\"/usr/bin/g++-{VERSION}\"\n", 22 | "import struct\n", 23 | "import numpy as np\n", 24 | "from tqdm import tqdm\n", 25 | "import mediapy\n", 26 | "\n", 27 | "from pathlib import Path\n", 28 | "import imageio\n", 29 | "import os\n", 30 | "import sys\n", 31 | "sys.path.append(str(Path(os.path.abspath('')).parent))\n", 32 | "print(str(Path(os.path.abspath('')).parent))\n", 33 | "import torch\n", 34 | "from gaussian_renderer import GaussianModel, render\n", 35 | "from gaussian_renderer.eevr import splinerender\n", 36 | "from scene import Scene\n", 37 | "from scene.cameras import Camera, MiniCam\n", 38 | "from torch import nn\n", 39 | "import matplotlib.pyplot as plt\n", 40 | "import imageio\n", 41 | "from pyquaternion import Quaternion\n", 42 | "from scene.dataset_readers import ProjectionType\n", 43 | "\n", 44 | "dataset = \"london\"\n", 45 | "# path = \"berlin5\"\n", 46 | "# dataset = \"train\"\n", 47 | "path = \"train5\"\n", 48 | "eval_path = \"/home/amai/gaussian-splatting-merge/eval\"\n", 49 | "# eval_path = \"/home/amai/gaussian-splatting-merge/eval\"\n", 50 | "# output_path = Path(\"/data/popping_videos/ours0\") / dataset / \"images\"\n", 51 | "output_path = Path(f\"~/Videos/popping_paths/{path}/ours\")\n", 52 | "\n", 53 | "output_path = Path(f\"~/Videos/{dataset}/ours/\")\n", 54 | "output_path.mkdir(parents=True, exist_ok=True)\n", 55 | "# f = open(f\"/data/video_paths/{dataset}/r1/path.path\", \"rb\")\n", 56 | "# f = open(f\"/home/amai/Videos/popping_paths/{path}.path\", \"rb\")\n", 57 | "f = open(f\"/home/amai/Videos/alameda.path\", \"rb\")\n", 58 | "# f = open(f\"smooth.path\", \"rb\")\n", 59 | "data = f.read()\n", 60 | "N = int.from_bytes(data[:4])\n", 61 | "camera_size = 11\n", 62 | "\n", 63 | "cameras = np.array(struct.unpack(f'>{N*camera_size}f', data[4:])).reshape(N, -1)\n", 64 | "full_data = struct.unpack(f'>i{N*camera_size}f', data)\n", 65 | "N = full_data[0]\n", 66 | "cameras = np.array(full_data[1:]).reshape(N, -1)\n", 67 | "print(N)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "/home/amai/gaussian-splatting-merge/eval/london\n", 80 | "Looking for config file in /home/amai/gaussian-splatting-merge/eval/london/cfg_args\n", 81 | "Config file found: /home/amai/gaussian-splatting-merge/eval/london/cfg_args\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 
86 | "import shlex\n", 87 | "from argparse import ArgumentParser, Namespace\n", 88 | "from arguments import ModelParams, PipelineParams, OptimizationParams\n", 89 | "\n", 90 | "def get_combined_args(args_cmdline):\n", 91 | " cfgfile_string = \"Namespace()\"\n", 92 | "\n", 93 | " try:\n", 94 | " cfgfilepath = os.path.join(args_cmdline.model_path, \"cfg_args\")\n", 95 | " print(\"Looking for config file in\", cfgfilepath)\n", 96 | " with open(cfgfilepath) as cfg_file:\n", 97 | " print(\"Config file found: {}\".format(cfgfilepath))\n", 98 | " cfgfile_string = cfg_file.read()\n", 99 | " except TypeError:\n", 100 | " print(\"Config file not found at\")\n", 101 | " pass\n", 102 | " args_cfgfile = eval(cfgfile_string)\n", 103 | "\n", 104 | " merged_dict = vars(args_cfgfile).copy()\n", 105 | " for k, v in vars(args_cmdline).items():\n", 106 | " if v != None:\n", 107 | " merged_dict[k] = v\n", 108 | " return Namespace(**merged_dict)\n", 109 | "\n", 110 | "\n", 111 | "parser = ArgumentParser(description=\"Testing script parameters\")\n", 112 | "model = ModelParams(parser, sentinel=True)\n", 113 | "pipeline = PipelineParams(parser)\n", 114 | "args = parser.parse_args(shlex.split(f\"-m {Path(eval_path) / dataset} --images images_2 -r 1\"))\n", 115 | "print(args.model_path)\n", 116 | "args = get_combined_args(args)\n", 117 | "model = model.extract(args)\n", 118 | "# model.source_path = str(Path(\"/data/nerf_synthetic\") / dataset)\n", 119 | "# model.source_path = str(Path(\"/data/nerf_datasets/tandt/\") / dataset)\n", 120 | "model.source_path = str(Path(\"/data/nerf_datasets/zipnerf_ud\") / dataset)\n", 121 | "\n", 122 | "model.max_opacity = 0.99\n", 123 | "\n", 124 | "pipeline = pipeline.extract(args)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 5, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stderr", 134 | "output_type": "stream", 135 | "text": [ 136 | "ic| self.max_opacity: 0.99\n" 137 | ] 138 | }, 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "Loading trained model at iteration 30000\n", 144 | "Reading camera 1874/1874\n", 145 | "Loading Training Cameras\n", 146 | "Loaded Train Cameras: 1639\n", 147 | "Loaded Test Cameras: 235\n" 148 | ] 149 | } 150 | ], 151 | "source": [ 152 | "gaussians = GaussianModel(model.sh_degree, model.max_opacity)\n", 153 | "scene = Scene(model, gaussians, load_iteration=-1, shuffle=False)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 6, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "refcam = scene.getTestCameras()[105]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "import cv2" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 9, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stderr", 181 | "output_type": "stream", 182 | "text": [ 183 | "200it [00:12, 15.52it/s]\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "background = torch.tensor([0, 0, 0], dtype=torch.float32, device=\"cuda\")\n", 189 | "# width = 1280\n", 190 | "# height = 720\n", 191 | "# width = 1200\n", 192 | "# height = 667\n", 193 | "width = refcam.image_width*4\n", 194 | "height = refcam.image_height*4\n", 195 | "image = torch.ones((3, height, width), dtype=float)\n", 196 | "frames = []\n", 197 | "for i, smod in tqdm(enumerate(torch.linspace(0.1, 1, 200))):\n", 198 | " with torch.no_grad():\n", 199 | " # rendering = 
renderer.render(view, pipeline, background)\n", 200 | " rendering = splinerender(refcam, gaussians, pipeline, background, scaling_modifier=smod)[\"render\"]\n", 201 | " byte_rendering = (rendering.permute(1, 2, 0).cpu().numpy()*255).clip(min=0, max=255).astype(np.uint8)\n", 202 | " byte_rendering = cv2.resize(byte_rendering, (refcam.image_width, refcam.image_height), interpolation=cv2.INTER_AREA)\n", 203 | " frames.append(byte_rendering)\n", 204 | " # full_output_path = output_path / f\"{i:06d}.png\"\n", 205 | " # imageio.imwrite(str(full_output_path), byte_rendering)\n", 206 | " # plt.imshow(byte_rendering)\n", 207 | " # plt.show()" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 10, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "frames = np.stack(frames, axis=0)\n", 217 | "frames = np.concatenate([\n", 218 | " frames,\n", 219 | " torch.as_tensor(frames[-1]).unsqueeze(0).expand(50, -1, -1, -1).numpy(),\n", 220 | " frames[::-1],\n", 221 | "])" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 13, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "mediapy.show_video(frames, bps=100000000)\n", 231 | "mediapy.write_video(\"size_animation.mp4\", frames, bps=100000000)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [] 247 | } 248 | ], 249 | "metadata": { 250 | "kernelspec": { 251 | "display_name": "Python 3", 252 | "language": "python", 253 | "name": "python3" 254 | }, 255 | "language_info": { 256 | "codemirror_mode": { 257 | "name": "ipython", 258 | "version": 3 259 | }, 260 | "file_extension": ".py", 261 | "mimetype": "text/x-python", 262 | "name": "python", 263 | "nbconvert_exporter": "python", 264 | "pygments_lexer": "ipython3", 265 | "version": "3.11.6" 266 | } 267 | }, 268 | "nbformat": 4, 269 | "nbformat_minor": 2 270 | } 271 | -------------------------------------------------------------------------------- /partial_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | # mipnerf360_outdoor_scenes = ["bicycle", "treehill"] 16 | mipnerf360_outdoor_scenes = ["treehill"] 17 | mipnerf360_indoor_scenes = ["counter"] 18 | tanks_and_temples_scenes = [] 19 | deep_blending_scenes = [] 20 | 21 | parser = ArgumentParser(description="Full evaluation script parameters") 22 | parser.add_argument("--skip_training", action="store_true") 23 | parser.add_argument("--skip_rendering", action="store_true") 24 | parser.add_argument("--skip_metrics", action="store_true") 25 | parser.add_argument("--output_path", default="./eval") 26 | args, _ = parser.parse_known_args() 27 | 28 | all_scenes = [] 29 | all_scenes.extend(mipnerf360_outdoor_scenes) 30 | all_scenes.extend(mipnerf360_indoor_scenes) 31 | all_scenes.extend(tanks_and_temples_scenes) 32 | all_scenes.extend(deep_blending_scenes) 33 | 34 | if not args.skip_training or not args.skip_rendering: 35 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 36 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 37 | parser.add_argument("--deepblending", "-db", required=True, type=str) 38 | parser.add_argument("--port", default=6009) 39 | parser.add_argument("--additional_args", default="", type=str) 40 | args = parser.parse_args() 41 | 42 | if not args.skip_training: 43 | # common_args = " --quiet --eval --test_iterations -1 --port 6008 --lambda_distortion 0" 44 | common_args = f" --eval --test_iterations -1 --port {args.port} {args.additional_args}" 45 | for scene in mipnerf360_outdoor_scenes: 46 | source = args.mipnerf360 + "/" + scene 47 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 48 | for scene in mipnerf360_indoor_scenes: 49 | source = args.mipnerf360 + "/" + scene 50 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 51 | for scene in tanks_and_temples_scenes: 52 | source = args.tanksandtemples + "/" + scene 53 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 54 | for scene in deep_blending_scenes: 55 | source = args.deepblending + "/" + scene 56 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 57 | 58 | if not args.skip_rendering: 59 | all_sources = [] 60 | for scene in mipnerf360_outdoor_scenes: 61 | all_sources.append(args.mipnerf360 + "/" + scene) 62 | for scene in mipnerf360_indoor_scenes: 63 | all_sources.append(args.mipnerf360 + "/" + scene) 64 | for scene in tanks_and_temples_scenes: 65 | all_sources.append(args.tanksandtemples + "/" + scene) 66 | for scene in deep_blending_scenes: 67 | all_sources.append(args.deepblending + "/" + scene) 68 | 69 | common_args = " --quiet --eval --skip_train" 70 | for scene, source in zip(all_scenes, all_sources): 71 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 72 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 73 | 74 | if not args.skip_metrics: 75 | scenes_string = "" 76 | for scene in all_scenes: 77 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 78 | 79 | os.system("python metrics.py -m " + scenes_string) 80 | -------------------------------------------------------------------------------- /render.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria, Google 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | import math 16 | from tqdm import tqdm 17 | from os import makedirs 18 | from gaussian_renderer import render 19 | from gaussian_renderer.ever import splinerender 20 | import torchvision 21 | from utils.general_utils import safe_state 22 | from argparse import ArgumentParser 23 | from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams 24 | from gaussian_renderer import GaussianModel 25 | from scene.dataset_readers import ProjectionType 26 | 27 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 28 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 29 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 30 | 31 | makedirs(render_path, exist_ok=True) 32 | makedirs(gts_path, exist_ok=True) 33 | 34 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 35 | # if idx != 424: 36 | # continue 37 | N = 1 38 | frendering = None 39 | camera_inds = {view.uid: i for i, view in enumerate(views)} 40 | for i in range(N): 41 | camera_ind = camera_inds[view.uid] 42 | #view.glo_vector = gaussians.glo[camera_ind] 43 | if gaussians.glo is not None: 44 | view.glo_vector = torch.cat( 45 | [gaussians.glo[camera_ind], torch.tensor([ 46 | math.log( 47 | view.iso * view.exposure / 1000), 48 | ], device=gaussians.glo.device) 49 | ] 50 | ) 51 | # view.model=ProjectionType.PERSPECTIVE 52 | rendering = splinerender(view, gaussians, pipeline, background, random=False)["render"] 53 | if frendering is None: 54 | frendering = rendering / N 55 | else: 56 | frendering += rendering / N 57 | gt = view.original_image[0:3, :, :] 58 | torchvision.utils.save_image(frendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 59 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 60 | 61 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, checkpoint, opt): 62 | with torch.no_grad(): 63 | gaussians = GaussianModel(dataset.sh_degree, dataset.use_neural_network, dataset.max_opacity, dataset.tmin) 64 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 65 | if checkpoint: 66 | (model_params, first_iter) = torch.load(checkpoint) 67 | gaussians.restore(model_params, opt) 68 | 69 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 70 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 71 | 72 | if not skip_train: 73 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 74 | 75 | if not skip_test: 76 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 77 | 78 | if __name__ == "__main__": 79 | # Set up command line argument parser 80 | parser = ArgumentParser(description="Testing script parameters") 81 | model = ModelParams(parser, sentinel=True) 82 | op = OptimizationParams(parser) 83 | pipeline = PipelineParams(parser) 
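# The flags added below select what gets rendered: --iteration picks which saved
# point cloud to load (-1 means the latest iteration found on disk), --skip_train
# and --skip_test restrict rendering to one camera set, and --checkpoint optionally
# restores a full training checkpoint before rendering. Illustrative invocation
# (the model path here is a placeholder):
#   python render.py -m ./eval/garden --iteration 30000 --skip_train --quiet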
84 | parser.add_argument("--iteration", default=-1, type=int) 85 | parser.add_argument("--skip_train", action="store_true") 86 | parser.add_argument("--skip_test", action="store_true") 87 | parser.add_argument("--quiet", action="store_true") 88 | parser.add_argument("--checkpoint", default=None) 89 | args = get_combined_args(parser) 90 | print("Rendering " + args.model_path) 91 | 92 | # Initialize system state (RNG) 93 | safe_state(args.quiet) 94 | args.checkpoint = args.checkpoint if hasattr(args, "checkpoint") else None 95 | 96 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.checkpoint, op.extract(args)) 97 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | six 3 | requests[socks] 4 | tqdm 5 | git+https://github.com/facebookresearch/pytorch3d.git 6 | kornia 7 | slangtorch 8 | slangpy 9 | icecream 10 | plyfile 11 | opencv-python 12 | -------------------------------------------------------------------------------- /resize_images.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--resize", action="store_true") 24 | parser.add_argument("--magick_executable", default="", type=str) 25 | args = parser.parse_args() 26 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 27 | 28 | print("Copying and resizing...") 29 | 30 | # Resize images. 31 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 32 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 33 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 34 | # Get the list of files in the source directory 35 | files = os.listdir(args.source_path + "/images") 36 | # Copy each file from the source directory to the destination directory 37 | for file in files: 38 | source_file = os.path.join(args.source_path, "images", file) 39 | 40 | destination_file = os.path.join(args.source_path, "images_2", file) 41 | shutil.copy2(source_file, destination_file) 42 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 43 | if exit_code != 0: 44 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 45 | exit(exit_code) 46 | 47 | destination_file = os.path.join(args.source_path, "images_4", file) 48 | shutil.copy2(source_file, destination_file) 49 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 50 | if exit_code != 0: 51 | logging.error(f"25% resize failed with code {exit_code}. 
Exiting.") 52 | exit(exit_code) 53 | 54 | destination_file = os.path.join(args.source_path, "images_8", file) 55 | shutil.copy2(source_file, destination_file) 56 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 57 | if exit_code != 0: 58 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 59 | exit(exit_code) 60 | print("Done.") 61 | 62 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for set in bicycle flowers garden stump treehill room counter kitchen bonsai; do 3 | # for set in bicycle flowers garden stump treehill; do 4 | # for set in kitchen bonsai; do 5 | # PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 python train.py -s /data/nerf_synthetic/$set --densify_grad_threshold=3e-7 --convert_SHs_python 6 | PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 python train.py -s /data/nerf_synthetic/$set --densify_grad_threshold=3e-7 --eval 7 | # PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 python train.py -s /data/nerf_synthetic/$set --densify_grad_threshold=3e-7 --use_neural_network --data_device cpu --eval --feature_rest_lr 0.0025 8 | done 9 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import os 14 | import random 15 | import copy 16 | import pickle 17 | import json 18 | import numpy as np 19 | from utils.system_utils import searchForMaxIteration 20 | from scene.dataset_readers import sceneLoadTypeCallbacks 21 | from scene.gaussian_model import GaussianModel 22 | from arguments import ModelParams 23 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 24 | from icecream import ic 25 | from utils import cam_util 26 | 27 | def transform_cameras_pca(cameras): 28 | if len(cameras) == 0: 29 | return cameras, np.eye(4) 30 | poses = np.stack([ 31 | np.linalg.inv(view.world_view_transform.T.cpu().numpy())[:3] 32 | for view in cameras], axis=0) 33 | new_poses, transform = cam_util.transform_poses_pca(poses) 34 | for i, cam in enumerate(cameras): 35 | T = np.eye(4) 36 | T[:3] = new_poses[i][:3] 37 | T = torch.linalg.inv(torch.tensor(T).float()).to(cam.world_view_transform.device) 38 | T[:3, 0] = T[:3, 0]*torch.linalg.det(T[:3, :3]) 39 | cameras[i] = set_pose(cam, T) 40 | return cameras, transform 41 | 42 | def set_pose(camera, T): 43 | # camera.world_view_transform = T.T 44 | # camera.full_proj_transform = ( 45 | # camera.world_view_transform.unsqueeze(0).bmm( 46 | # camera.projection_matrix.unsqueeze(0))).squeeze(0) 47 | # camera.camera_center = camera.world_view_transform.inverse()[3, :3] 48 | camera.R = T[:3, :3].T.numpy() 49 | camera.T = T[:3, 3].numpy() 50 | camera.update() 51 | return camera 52 | 53 | class Scene: 54 | 55 | gaussians : GaussianModel 56 | 57 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 58 | """b 59 | :param path: Path to colmap scene main folder. 
60 | """ 61 | self.model_path = args.model_path 62 | self.loaded_iter = None 63 | self.gaussians = gaussians 64 | 65 | if load_iteration: 66 | if load_iteration == -1: 67 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 68 | else: 69 | self.loaded_iter = load_iteration 70 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 71 | 72 | self.train_cameras = {} 73 | self.test_cameras = {} 74 | 75 | if os.path.exists(os.path.join(args.source_path, "sparse")): 76 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 77 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 78 | print("Found transforms_train.json file, assuming Blender data set!") 79 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 80 | else: 81 | assert False, "Could not recognize scene type!" 82 | 83 | if not self.loaded_iter: 84 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 85 | dest_file.write(src_file.read()) 86 | json_cams = [] 87 | camlist = [] 88 | if scene_info.test_cameras: 89 | camlist.extend(scene_info.test_cameras) 90 | if scene_info.train_cameras: 91 | camlist.extend(scene_info.train_cameras) 92 | for id, cam in enumerate(camlist): 93 | json_cams.append(camera_to_JSON(id, cam)) 94 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 95 | json.dump(json_cams, file) 96 | 97 | if shuffle: 98 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 99 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 100 | 101 | self.cameras_extent = scene_info.nerf_normalization["radius"] 102 | 103 | # if not args.render_spline: 104 | for resolution_scale in resolution_scales: 105 | print("Loading Training Cameras") 106 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 107 | print(f"Loaded Train Cameras: {len(self.train_cameras[resolution_scale])}") 108 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 109 | print(f"Loaded Test Cameras: {len(self.test_cameras[resolution_scale])}") 110 | 111 | 112 | if not args.render_spline: 113 | pass 114 | else: 115 | 116 | # for resolution_scale in resolution_scales: 117 | # self.test_cameras[resolution_scale], _ = transform_cameras_pca(self.test_cameras[resolution_scale]) 118 | 119 | for resolution_scale in resolution_scales: 120 | test_cams = self.train_cameras[resolution_scale] 121 | flat_cameras, transform = transform_cameras_pca(test_cams) 122 | wT = np.eye(4) 123 | wT[:3] = transform[:3] 124 | wT = torch.tensor(wT).float() 125 | wT = torch.linalg.inv(wT) 126 | poses = np.stack([ 127 | np.linalg.inv(view.world_view_transform.T.cpu().numpy())[:3] 128 | for view in flat_cameras], axis=0) 129 | eposes = cam_util.generate_ellipse_path(poses, n_frames = args.num_spline_frames)#4*480) 130 | # eposes = poses 131 | refcam = self.train_cameras[resolution_scale][0] 132 | cameras = test_cams 133 | cameras = [] 134 | for i in range(eposes.shape[0]): 135 | # for i in [0]: 136 | camera = copy.copy(refcam) 137 | T = np.eye(4) 138 | T[:3] = eposes[i][:3] 139 | T = torch.tensor(T).float() 140 | # T = T @ wT 141 | T = wT @ T 142 | T = T @ torch.diag(torch.tensor([-1.0, 1.0, -1.0, 1.0])) 143 | T = torch.linalg.inv(T).to(args.data_device) 144 | # T[:3, 0] = T[:3, 
0]*torch.linalg.det(T[:3, :3]) 145 | # U, S, Vt = torch.linalg.svd(T[:3, :3]) 146 | # T[:3, :3] = U @ Vt 147 | camera = set_pose(camera, T) 148 | camera.uid = i 149 | cameras.append(camera) 150 | # self.train_cameras[resolution_scale] = cameras 151 | self.test_cameras[resolution_scale] = cameras 152 | # print("Rendering spline") 153 | # with open(os.path.join(args.source_path, "render_poses.pkl"), "rb") as f: 154 | # render_path = pickle.load(f) 155 | # for resolution_scale in resolution_scales: 156 | # self.train_cameras[resolution_scale] = render_path 157 | # self.test_cameras[resolution_scale] = render_path 158 | 159 | if self.loaded_iter: 160 | # self.gaussians.load_th(os.path.join(self.model_path, 161 | # "point_cloud", 162 | # "iteration_" + str(self.loaded_iter), 163 | # "point_cloud.th")) 164 | self.gaussians.load_ply(os.path.join(self.model_path, 165 | "point_cloud", 166 | "iteration_" + str(self.loaded_iter), 167 | "point_cloud.ply")) 168 | else: 169 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, args.num_additional_pts, args.additional_size_multi) 170 | 171 | def save(self, iteration): 172 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 173 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 174 | 175 | def getTrainCameras(self, scale=1.0): 176 | return self.train_cameras[scale] 177 | 178 | def getTestCameras(self, scale=1.0): 179 | return self.test_cameras[scale] 180 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | from scene.dataset_readers import ProjectionType 17 | from icecream import ic 18 | 19 | class Camera(nn.Module): 20 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 21 | image_name, uid, 22 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", 23 | model=ProjectionType.PERSPECTIVE, distortion_params=None, 24 | exposure=1, iso=100, aperature=1): 25 | super(Camera, self).__init__() 26 | 27 | self.uid = uid 28 | self.colmap_id = colmap_id 29 | self.R = R 30 | self.T = T 31 | self.FoVx = FoVx 32 | self.FoVy = FoVy 33 | self.image_name = image_name 34 | self.model = model 35 | self.distortion_params = distortion_params 36 | self.glo_vector = None 37 | self.exposure = exposure 38 | self.iso = iso 39 | self.aperature = aperature 40 | 41 | try: 42 | self.data_device = torch.device(data_device) 43 | except Exception as e: 44 | print(e) 45 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 46 | self.data_device = torch.device("cuda") 47 | 48 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 49 | self.image_width = self.original_image.shape[2] 50 | self.image_height = self.original_image.shape[1] 51 | 52 | if gt_alpha_mask is not None: 53 | self.original_image *= gt_alpha_mask.to(self.data_device) 54 | self.alpha_mask = gt_alpha_mask 55 | else: 56 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 57 | 58 | self.zfar = 100.0 59 | self.znear = 0.01 60 | 61 | self.trans = trans 62 | self.scale = scale 63 | self.update() 64 | 65 | def update(self): 66 | self.world_view_transform = torch.tensor(getWorld2View2(self.R, self.T, self.trans, self.scale)).transpose(0, 1).to(self.data_device) 67 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(self.data_device) 68 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 69 | self.camera_center = self.world_view_transform.inverse()[3, :3] 70 | 71 | class MiniCam: 72 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 73 | self.image_width = width 74 | self.image_height = height 75 | self.FoVy = fovy 76 | self.FoVx = fovx 77 | self.znear = znear 78 | self.zfar = zfar 79 | self.world_view_transform = world_view_transform 80 | self.full_proj_transform = full_proj_transform 81 | view_inv = torch.inverse(self.world_view_transform) 82 | self.camera_center = view_inv[3][:3] 83 | 84 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
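    Example (hypothetical call): read_next_bytes(fid, 24, "ddd") reads 24 bytes and
    unpacks three little-endian doubles, equivalent to struct.unpack("<ddd", fid.read(24)).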
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | # assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/contractions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import pytorch3d.transforms 16 | import torch 17 | from icecream import ic 18 | 19 | 20 | def contract_points(x): 21 | mag = torch.linalg.norm(x, dim=-1)[..., None] 22 | return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag)) 23 | 24 | 25 | def inv_contract_points(z): 26 | z_mag_sq = torch.sum(z**2, dim=-1, keepdims=True) 27 | z_mag_sq = torch.maximum(torch.ones_like(z_mag_sq), z_mag_sq) 28 | inv_scale = 2 * torch.sqrt(z_mag_sq) - z_mag_sq 29 | x = z / inv_scale.clip(min=1e-4) 30 | return x 31 | 32 | 33 | def track_gaussians(fn, means, covs, densities): 34 | jc_means = torch.vmap(torch.func.jacrev(fn))(means.view(-1, means.shape[-1])) 35 | jc_means = jc_means.view(list(means.shape) + [means.shape[-1]]) 36 | 37 | # Only update covariances on positions outside the unit sphere 38 | mag = means.norm(dim=-1) 39 | mask = mag >= 1 40 | covs = covs.clone() 41 | covs[mask] = jc_means[mask] @ covs[mask] @ torch.transpose(jc_means[mask], -2, -1) 42 | 43 | # densities[mask] = densities[mask] * torch.linalg.det(jc_means)[mask].abs() 44 | 45 | return fn(means), covs, densities 46 | 47 | 48 | def contract_gaussians(means, covs, densities): 49 | return track_gaussians(contract_points, means, covs, densities) 50 | 51 | 52 | def inv_contract_gaussians(means, covs, densities): 53 | return track_gaussians(inv_contract_points, means, covs, densities) 54 | 55 | 56 | def to_cov(scale, quat): 57 | R = pytorch3d.transforms.quaternion_to_matrix(quat) 58 | S2 = torch.zeros_like(R) 59 | S2[:, 0, 0] = scale[:, 0] ** 2 60 | S2[:, 1, 1] = scale[:, 1] ** 2 61 | S2[:, 2, 2] = scale[:, 2] ** 2 62 | return torch.bmm(torch.bmm(R.permute(0, 2, 1), S2), R) 63 | 64 | 65 | def from_covs(Ms): 66 | eig = torch.linalg.eig(Ms) 67 | scales2 = eig.eigenvalues.real.clip(min=1e-10).sqrt() 68 | R2 = eig.eigenvectors.real.permute(0, 2, 1) 69 | R2 = R2 * torch.linalg.det(R2).reshape(-1, 1, 1) 70 | q2 = pytorch3d.transforms.matrix_to_quaternion(R2) 71 | return scales2, q2 72 | 73 | 74 | def inv_contract_gaussians_decomposed(means, scales, quats, densities): 75 | covs = to_cov(scales, quats) 76 | new_means, new_covs, new_densities = 
inv_contract_gaussians(means, covs, densities) 77 | new_scales, new_quats = from_covs(new_covs) 78 | return new_means, new_scales, new_quats, new_densities 79 | 80 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple, Optional 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | import enum 26 | import json 27 | 28 | class ProjectionType(enum.Enum): 29 | """Camera projection type (perspective pinhole, fisheye, or 360 pano).""" 30 | 31 | PERSPECTIVE = 'perspective' 32 | FISHEYE = 'fisheye' 33 | PANORAMIC = 'pano' 34 | 35 | class CameraInfo(NamedTuple): 36 | uid: int 37 | R: np.array 38 | T: np.array 39 | FovY: np.array 40 | FovX: np.array 41 | image: np.array 42 | image_path: str 43 | image_name: str 44 | width: int 45 | height: int 46 | model: ProjectionType 47 | distortion_params: Optional[dict] 48 | exposure: float 49 | iso: float 50 | aperature: float 51 | 52 | class SceneInfo(NamedTuple): 53 | point_cloud: BasicPointCloud 54 | train_cameras: list 55 | test_cameras: list 56 | nerf_normalization: dict 57 | ply_path: str 58 | 59 | def getNerfppNorm(cam_info): 60 | def get_center_and_diag(cam_centers): 61 | cam_centers = np.hstack(cam_centers) 62 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 63 | center = avg_cam_center 64 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 65 | diagonal = np.max(dist) 66 | return center.flatten(), diagonal 67 | 68 | cam_centers = [] 69 | 70 | for cam in cam_info: 71 | W2C = getWorld2View2(cam.R, cam.T) 72 | C2W = np.linalg.inv(W2C) 73 | cam_centers.append(C2W[:3, 3:4]) 74 | 75 | center, diagonal = get_center_and_diag(cam_centers) 76 | radius = diagonal * 1.1 77 | 78 | translate = -center 79 | 80 | return {"translate": translate, "radius": radius} 81 | 82 | def convert_to_float(frac_str): 83 | try: 84 | return float(frac_str) 85 | except ValueError: 86 | num, denom = frac_str.split('/') 87 | try: 88 | leading, num = num.split(' ') 89 | whole = float(leading) 90 | except ValueError: 91 | whole = 0 92 | frac = float(num) / float(denom) 93 | return whole - frac if whole < 0 else whole + frac 94 | 95 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, metadata_path): 96 | if os.path.isfile(metadata_path): 97 | with open(metadata_path, "r") as f: 98 | print("Loading metadata") 99 | metadata = json.load(f) 100 | else: 101 | metadata = None 102 | cam_infos = [] 103 | for idx, key in enumerate(cam_extrinsics): 104 | sys.stdout.write('\r') 105 | # the exact output you're looking for: 106 | sys.stdout.write("Reading camera {}/{}".format(idx+1, 
len(cam_extrinsics))) 107 | sys.stdout.flush() 108 | 109 | extr = cam_extrinsics[key] 110 | intr = cam_intrinsics[extr.camera_id] 111 | height = intr.height 112 | width = intr.width 113 | 114 | uid = intr.id 115 | R = np.transpose(qvec2rotmat(extr.qvec)) 116 | T = np.array(extr.tvec) 117 | 118 | model = ProjectionType.PERSPECTIVE 119 | distortion_params = None 120 | if intr.model=="SIMPLE_PINHOLE": 121 | focal_length_x = intr.params[0] 122 | FovY = focal2fov(focal_length_x, height) 123 | FovX = focal2fov(focal_length_x, width) 124 | elif intr.model=="PINHOLE": 125 | focal_length_x = intr.params[0] 126 | focal_length_y = intr.params[1] 127 | FovY = focal2fov(focal_length_y, height) 128 | FovX = focal2fov(focal_length_x, width) 129 | elif intr.model=="OPENCV_FISHEYE": 130 | distortion_params = None 131 | fx, fy, cx, cy = intr.params[:4] 132 | k1, k2, k3, k4 = intr.params[4:] 133 | distortion_params = {} 134 | distortion_params['k1'] = k1 135 | distortion_params['k2'] = k2 136 | distortion_params['k3'] = k3 137 | distortion_params['k4'] = k4 138 | model = ProjectionType.FISHEYE 139 | focal_length_x = intr.params[0] 140 | focal_length_y = intr.params[1] 141 | FovY = focal2fov(focal_length_y, height) 142 | FovX = focal2fov(focal_length_x, width) 143 | else: 144 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 145 | 146 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 147 | image_name = os.path.basename(image_path).split(".")[0] 148 | image_path = image_path.replace(" ", "_") 149 | # temp = Image.open(image_path) 150 | # image = temp.copy() 151 | # temp.close() 152 | if metadata is not None: 153 | data = metadata[os.path.basename(image_path)] 154 | exposure = convert_to_float(data['exposure']) 155 | iso = float(data['iso']) 156 | aperature = float(data['aperature']) 157 | else: 158 | exposure = 1 159 | iso = 100 160 | aperature = 1 161 | image = None 162 | 163 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 164 | image_path=image_path, image_name=image_name, width=width, 165 | height=height, model=model, 166 | distortion_params=distortion_params, exposure=exposure, 167 | aperature=aperature, iso=iso) 168 | cam_infos.append(cam_info) 169 | sys.stdout.write('\n') 170 | return cam_infos 171 | 172 | def fetchPly(path): 173 | plydata = PlyData.read(path) 174 | vertices = plydata['vertex'] 175 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 176 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 177 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 178 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 179 | 180 | def storePly(path, xyz, rgb): 181 | # Define the dtype for the structured array 182 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 183 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 184 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 185 | 186 | normals = np.zeros_like(xyz) 187 | 188 | elements = np.empty(xyz.shape[0], dtype=dtype) 189 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 190 | elements[:] = list(map(tuple, attributes)) 191 | 192 | # Create the PlyData object and write to file 193 | vertex_element = PlyElement.describe(elements, 'vertex') 194 | ply_data = PlyData([vertex_element]) 195 | ply_data.write(path) 196 | 197 | def readColmapSceneInfo(path, images, eval, llffhold=8): 198 | try: 199 | cameras_extrinsic_file = os.path.join(path, 
"sparse/0", "images.bin") 200 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 201 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 202 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 203 | except: 204 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 205 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 206 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 207 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 208 | 209 | reading_dir = "images_4" if images == None else images 210 | cam_infos_unsorted = readColmapCameras( 211 | cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, 212 | images_folder=os.path.join(path, reading_dir), 213 | metadata_path=os.path.join(path, "metadata.json")) 214 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 215 | 216 | if eval: 217 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 218 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 219 | else: 220 | train_cam_infos = cam_infos 221 | test_cam_infos = [] 222 | 223 | nerf_normalization = getNerfppNorm(train_cam_infos) 224 | 225 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 226 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 227 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 228 | if not os.path.exists(ply_path): 229 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 230 | try: 231 | xyz, rgb, _ = read_points3D_binary(bin_path) 232 | except: 233 | xyz, rgb, _ = read_points3D_text(txt_path) 234 | storePly(ply_path, xyz, rgb) 235 | try: 236 | pcd = fetchPly(ply_path) 237 | except: 238 | pcd = None 239 | 240 | scene_info = SceneInfo(point_cloud=pcd, 241 | train_cameras=train_cam_infos, 242 | test_cameras=test_cam_infos, 243 | nerf_normalization=nerf_normalization, 244 | ply_path=ply_path) 245 | return scene_info 246 | 247 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 248 | cam_infos = [] 249 | 250 | with open(os.path.join(path, transformsfile)) as json_file: 251 | contents = json.load(json_file) 252 | fovx = contents["camera_angle_x"] 253 | 254 | frames = contents["frames"] 255 | for idx, frame in enumerate(frames): 256 | cam_name = os.path.join(path, frame["file_path"] + extension) 257 | 258 | # NeRF 'transform_matrix' is a camera-to-world transform 259 | c2w = np.array(frame["transform_matrix"]) 260 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 261 | c2w[:3, 1:3] *= -1 262 | 263 | # get the world-to-camera transform and set R, T 264 | w2c = np.linalg.inv(c2w) 265 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 266 | T = w2c[:3, 3] 267 | 268 | image_path = os.path.join(path, cam_name) 269 | image_name = Path(cam_name).stem 270 | image = Image.open(image_path) 271 | 272 | im_data = np.array(image.convert("RGBA")) 273 | 274 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 275 | 276 | norm_data = im_data / 255.0 277 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 278 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 279 | 280 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 281 | FovY = fovy 282 | FovX = fovx 283 | exposure = 1 284 | 285 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, 
FovX=FovX, image=image,
286 |                         image_path=image_path, image_name=image_name, width=image.size[0],
287 |                         height=image.size[1], model=ProjectionType.PERSPECTIVE,
288 |                         distortion_params=None, exposure=exposure, aperature=1, iso=100))  # transforms.json has no EXIF metadata, so use the same defaults as the COLMAP loader
289 | 
290 |     return cam_infos
291 | 
292 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
293 |     print("Reading Training Transforms")
294 |     train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
295 |     print("Reading Test Transforms")
296 |     test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
297 | 
298 |     if not eval:
299 |         train_cam_infos.extend(test_cam_infos)
300 |         test_cam_infos = []
301 | 
302 |     nerf_normalization = getNerfppNorm(train_cam_infos)
303 | 
304 |     ply_path = os.path.join(path, "points3d.ply")
305 |     if not os.path.exists(ply_path):
306 |         # Since this data set has no colmap data, we start with random points
307 |         num_pts = 100_000
308 |         print(f"Generating random point cloud ({num_pts})...")
309 | 
310 |         # We create random points inside the bounds of the synthetic Blender scenes
311 |         xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
312 |         shs = np.random.random((num_pts, 3)) / 255.0
313 |         pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
314 | 
315 |         storePly(ply_path, xyz, SH2RGB(shs) * 255)
316 |     try:
317 |         pcd = fetchPly(ply_path)
318 |     except:
319 |         pcd = None
320 | 
321 |     scene_info = SceneInfo(point_cloud=pcd,
322 |                            train_cameras=train_cam_infos,
323 |                            test_cameras=test_cam_infos,
324 |                            nerf_normalization=nerf_normalization,
325 |                            ply_path=ply_path)
326 |     return scene_info
327 | 
328 | sceneLoadTypeCallbacks = {
329 |     "Colmap": readColmapSceneInfo,
330 |     "Blender" : readNerfSyntheticInfo
331 | }
332 | 
--------------------------------------------------------------------------------
/scene/sphere_init.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from scene import contractions
16 | import torch
17 | import numpy as np
18 | from icecream import ic
19 | 
20 | def l2_normalize_th(x, eps=torch.finfo(torch.float32).eps):
21 |     """Normalize x to unit length along last axis."""
22 |     return x / torch.sqrt(
23 |         torch.clip(torch.sum(x**2, dim=-1, keepdim=True), eps, None)
24 |     )
25 | 
26 | def sphere_init(
27 |     center,
28 |     N,
29 |     device,
30 |     min_opacity=0.2,
31 |     opacity_var=0.1,
32 |     scale_multi=3,
33 |     a=20,
34 |     radius=1,
35 |     sh_degree=0,
36 |     scene_scale=1,
37 |     **kwargs
38 | ):
39 |     def contraction(x):
40 |         return 1 / (1 - x + 1 / a) - 1 / (1 + 1 / a)
41 | 
42 |     distance = torch.rand((N, 1), device=device) ** (1 / 3)
43 |     # c = torch.tensor(1.5, device=device)
44 |     c = 2
45 | 
46 |     # here is the issue.
We want a nice constant density field no matter how many d8s there are 47 | 48 | direction = l2_normalize_th(torch.randn((N, 3), dtype=torch.float32, device=device)) 49 | means = c * distance * direction 50 | 51 | area_per_sphere = 4 / 3 * np.pi / N * scale_multi 52 | side_len = (area_per_sphere / (3 / 4) / np.pi) ** (1 / 3) 53 | # scales = side_len + side_len/2*(2*torch.rand(N, 3, device=device)-1) 54 | scales = side_len * torch.ones((N, 3), device=device) 55 | scales = c * scales * 0.4 56 | 57 | quats = l2_normalize_th( 58 | 2 * torch.tensor(np.random.rand(N, 4), dtype=torch.float32, device=device) - 1 59 | ) 60 | 61 | length = ((scales.detach()).exp() ** 2).sum(dim=-1).sqrt() 62 | length = 0.1 63 | desired_opacity = 0.2 + 0.1 * torch.rand((N), device=device) 64 | calc_density = -torch.log(1 - desired_opacity) / length 65 | densities = calc_density / 2 66 | print("Init density max: ", densities.max()) 67 | 68 | means, scales, quats, densities = contractions.inv_contract_gaussians_decomposed( 69 | means, scales, quats, densities 70 | ) 71 | 72 | # length = (scale_activation(scales.detach()) ** 2).sum(dim=-1).sqrt() 73 | # length = 0.1 74 | # desired_opacity = 0.2 + 0.1 * torch.rand((N), device=device) 75 | # calc_density = -torch.log(1 - desired_opacity) / length 76 | # densities = calc_density / 10 77 | 78 | feats = torch.zeros( 79 | (N, (sh_degree + 1) ** 2, 3), dtype=torch.float32, device=device 80 | ) 81 | feats[:, 0:1, :] = torch.tensor( 82 | np.random.rand(N, 1, 3) * 0.3 + 0.3, dtype=torch.float32, device=device 83 | ) 84 | # feats = torch.rand((N, (sh_degree+1)**2, 3), dtype=torch.float32, device=device) * 0.3 + 0.3 85 | return ( 86 | means * scene_scale * radius + center.reshape(1, 3), 87 | scales * scene_scale * radius, 88 | quats, 89 | densities / scene_scale, 90 | feats, 91 | ) 92 | 93 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/cam_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
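#
# Camera-path helpers: `generate_ellipse_path` builds an elliptical fly-through
# around the focus point of a set of training cameras, and `transform_poses_pca`
# recenters and axis-aligns camera-to-world poses. A minimal, hypothetical usage
# sketch (assuming `poses` is an (N, 3, 4) array of camera-to-world matrices):
#
#     render_poses = generate_ellipse_path(poses, n_frames=120, z_variation=0.1)
#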
15 | import numpy as np 16 | from utils import stepfun 17 | from icecream import ic 18 | 19 | def rotation_about_axis(degrees, axis=0): 20 | """Creates rotation matrix about one of the coordinate axes.""" 21 | radians = degrees / 180.0 * np.pi 22 | rot2x2 = np.array( 23 | [[np.cos(radians), -np.sin(radians)], [np.sin(radians), np.cos(radians)]] 24 | ) 25 | r = np.eye(3) 26 | r[1:3, 1:3] = rot2x2 27 | r = np.roll(np.roll(r, axis, axis=0), axis, axis=1) 28 | p = np.eye(4) 29 | p[:3, :3] = r 30 | return p 31 | 32 | 33 | def normalize(x): 34 | """Normalization helper function.""" 35 | return x / np.linalg.norm(x) 36 | 37 | 38 | def focus_point_fn(poses, xnp = np): 39 | """Calculate nearest point to all focal axes in poses.""" 40 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 41 | m = xnp.eye(3) - directions * xnp.transpose(directions, [0, 2, 1]) 42 | mt_m = xnp.transpose(m, [0, 2, 1]) @ m 43 | focus_pt = xnp.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 44 | return focus_pt 45 | 46 | def viewmatrix( 47 | lookdir, 48 | up, 49 | position, 50 | lock_up = False, 51 | ): 52 | """Construct lookat view matrix.""" 53 | orthogonal_dir = lambda a, b: normalize(np.cross(a, b)) 54 | vecs = [None, normalize(up), normalize(lookdir)] 55 | # x-axis is always the normalized cross product of `lookdir` and `up`. 56 | vecs[0] = orthogonal_dir(vecs[1], vecs[2]) 57 | # Default is to lock `lookdir` vector, if lock_up is True lock `up` instead. 58 | ax = 2 if lock_up else 1 59 | # Set the not-locked axis to be orthogonal to the other two. 60 | vecs[ax] = orthogonal_dir(vecs[(ax + 1) % 3], vecs[(ax + 2) % 3]) 61 | m = np.stack(vecs + [position], axis=1) 62 | return m 63 | 64 | 65 | def generate_ellipse_path( 66 | poses, 67 | n_frames = 120, 68 | const_speed = True, 69 | z_variation = 0.0, 70 | z_phase = 0.0, 71 | rad_mult_min = 1.0, 72 | rad_mult_max = 1.0, 73 | render_rotate_xaxis = 0.0, 74 | render_rotate_yaxis = 0.0, 75 | use_avg_z_height = True, 76 | z_height_percentile = None, 77 | lock_up = False, 78 | ): 79 | """Generate an elliptical render path based on the given poses.""" 80 | # Calculate the focal point for the path (cameras point toward this). 81 | center = focus_point_fn(poses) 82 | # Default path height sits at z=0 (in middle of zero-mean capture pattern). 83 | xy_offset = center[:2] 84 | 85 | # Calculate lengths for ellipse axes based on input camera positions. 86 | xy_radii = np.percentile(np.abs(poses[:, :2, 3] - xy_offset), 90, axis=0) 87 | # Use ellipse that is symmetric about the focal point in xy. 88 | xy_low = xy_offset - xy_radii 89 | xy_high = xy_offset + xy_radii 90 | 91 | # Optional height variation, need not be symmetric. 92 | z_min = np.percentile((poses[:, 2, 3]), 10, axis=0) 93 | z_max = np.percentile((poses[:, 2, 3]), 90, axis=0) 94 | # ic(z_min, z_max) 95 | if use_avg_z_height or z_height_percentile is not None: 96 | # Center the path vertically around the average camera height, good for 97 | # datasets recentered by transform_poses_focus function. 98 | if z_height_percentile is None: 99 | z_init = poses[:, 2, 3].mean(axis=0) 100 | else: 101 | z_init = np.percentile(poses[:, 2, 3], z_height_percentile, axis=0) 102 | else: 103 | # Center the path at zero, good for datasets recentered by 104 | # transform_poses_pca function. 
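    # (Here z_init = 0 keeps the path at the capture's zero height; the
    # z_variation blend just below then moves the bounds toward the 10th/90th
    # percentile camera heights.)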
105 | z_init = 0 106 | z_low = z_init + z_variation * (z_min - z_init) 107 | z_high = z_init + z_variation * (z_max - z_init) 108 | 109 | xyz_low = np.array([*xy_low, z_low]) 110 | xyz_high = np.array([*xy_high, z_high]) 111 | 112 | def get_positions(theta): 113 | # Interpolate between bounds with trig functions to get ellipse in x-y. 114 | # Optionally also interpolate in z to change camera height along path. 115 | t_x = np.cos(theta) * 0.5 + 0.5 116 | t_y = np.sin(theta) * 0.5 + 0.5 117 | t_z = np.cos(theta + 2 * np.pi * z_phase) * 0.5 + 0.5 118 | t_xyz = np.stack([t_x, t_y, t_z], axis=-1) 119 | positions = xyz_low + t_xyz * (xyz_high - xyz_low) 120 | # Interpolate between min and max radius multipliers so the camera zooms in 121 | # and out of the scene center. 122 | t = np.sin(theta) * 0.5 + 0.5 123 | rad_mult = rad_mult_min + (rad_mult_max - rad_mult_min) * t 124 | positions = center + (positions - center) * rad_mult[:, None] 125 | return positions 126 | 127 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) 128 | positions = get_positions(theta) 129 | 130 | if const_speed: 131 | # Resample theta angles so that the velocity is closer to constant. 132 | lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 133 | theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) 134 | positions = get_positions(theta) 135 | 136 | # Throw away duplicated last position. 137 | positions = positions[:-1] 138 | 139 | # Set path's up vector to axis closest to average of input pose up vectors. 140 | avg_up = poses[:, :3, 1].mean(0) 141 | avg_up = avg_up / np.linalg.norm(avg_up) 142 | # ic(avg_up) 143 | ind_up = np.argmax(np.abs(avg_up)) 144 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 145 | 146 | # ic(positions, center) 147 | poses = np.stack([viewmatrix(p - center, up, p, lock_up) for p in positions]) 148 | 149 | poses = poses @ rotation_about_axis(-render_rotate_yaxis, axis=1) 150 | poses = poses @ rotation_about_axis(render_rotate_xaxis, axis=0) 151 | return poses 152 | 153 | def pad_poses(p): 154 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" 155 | bottom = np.broadcast_to([0, 0, 0, 1.0], p[Ellipsis, :1, :4].shape) 156 | return np.concatenate([p[Ellipsis, :3, :4], bottom], axis=-2) 157 | 158 | 159 | def unpad_poses(p): 160 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" 161 | return p[Ellipsis, :3, :4] 162 | 163 | 164 | def transform_poses_pca(poses): 165 | """Transforms poses so principal components lie on XYZ axes. 166 | 167 | Args: 168 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 169 | 170 | Returns: 171 | A tuple (poses, transform), with the transformed poses and the applied 172 | camera_to_world transforms. 173 | """ 174 | t = poses[:, :3, 3] 175 | t_mean = t.mean(axis=0) 176 | t = t - t_mean 177 | 178 | eigval, eigvec = np.linalg.eig(t.T @ t) 179 | # Sort eigenvectors in order of largest to smallest eigenvalue. 
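  # Because t.T @ t is symmetric, its eigenvectors form an orthonormal basis;
  # using them as the rows of `rot` aligns the dominant spread of the camera
  # centers with the world X/Y/Z axes.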
180 | inds = np.argsort(eigval)[::-1] 181 | eigvec = eigvec[:, inds] 182 | rot = eigvec.T 183 | if np.linalg.det(rot) < 0: 184 | rot = np.diag(np.array([1, 1, -1])) @ rot 185 | 186 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 187 | poses_recentered = unpad_poses(transform @ pad_poses(poses)) 188 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) 189 | 190 | # Flip coordinate system if z component of y-axis is negative 191 | if poses_recentered.mean(axis=0)[2, 1] < 0: 192 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 193 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 194 | 195 | # # Just make sure it's it in the [-1, 1]^3 cube 196 | # scale_factor = 1.0 / np.max(np.abs(poses_recentered[:, :3, 3])) 197 | # poses_recentered[:, :3, 3] *= scale_factor 198 | # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform 199 | 200 | return poses_recentered, transform 201 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | from PIL import Image 17 | 18 | WARNED = False 19 | 20 | def loadCam(args, id, cam_info, resolution_scale): 21 | image = Image.open(cam_info.image_path) 22 | orig_w, orig_h = image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 26 | else: # should be a type that converts to float 27 | if args.resolution == -1: 28 | if orig_w > 1600: 29 | global WARNED 30 | if not WARNED: 31 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 32 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 33 | WARNED = True 34 | global_down = orig_w / 1600 35 | else: 36 | global_down = 1 37 | else: 38 | global_down = orig_w / args.resolution 39 | 40 | scale = float(global_down) * float(resolution_scale) 41 | resolution = (int(orig_w / scale), int(orig_h / scale)) 42 | 43 | resized_image_rgb = PILtoTorch(image, resolution) 44 | 45 | gt_image = resized_image_rgb[:3, ...] 46 | loaded_mask = None 47 | 48 | if resized_image_rgb.shape[0] == 4: 49 | loaded_mask = resized_image_rgb[3:4, ...] 
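        # A fourth channel, if present, is kept as a separate alpha mask rather
        # than being baked into the RGB ground truth.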
50 | 51 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 52 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 53 | image=gt_image, gt_alpha_mask=loaded_mask, 54 | image_name=cam_info.image_name, uid=id, data_device=args.data_device, 55 | model=cam_info.model, distortion_params=cam_info.distortion_params, 56 | exposure=cam_info.exposure, aperature=cam_info.aperature, iso=cam_info.iso) 57 | 58 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 59 | camera_list = [] 60 | 61 | for id, c in enumerate(cam_infos): 62 | camera_list.append(loadCam(args, id, c, resolution_scale)) 63 | 64 | return camera_list 65 | 66 | def camera_to_JSON(id, camera : Camera): 67 | Rt = np.zeros((4, 4)) 68 | Rt[:3, :3] = camera.R.transpose() 69 | Rt[:3, 3] = camera.T 70 | Rt[3, 3] = 1.0 71 | 72 | W2C = np.linalg.inv(Rt) 73 | pos = W2C[:3, 3] 74 | rot = W2C[:3, :3] 75 | serializable_array_2d = [x.tolist() for x in rot] 76 | camera_entry = { 77 | 'id' : id, 78 | 'img_name' : camera.image_name, 79 | 'width' : camera.width, 80 | 'height' : camera.height, 81 | 'position': pos.tolist(), 82 | 'rotation': serializable_array_2d, 83 | 'fy' : fov2focal(camera.FovY, camera.height), 84 | 'fx' : fov2focal(camera.FovX, camera.width) 85 | } 86 | return camera_entry 87 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | import cv2 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x/(1-x)) 21 | 22 | def PILtoTorch(pil_image, resolution): 23 | # resized_image_PIL = cv2.resize(pil_image, resolution) 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | def get_expon_lr_func( 32 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 33 | ): 34 | """ 35 | Copied from Plenoxels 36 | 37 | Continuous learning rate decay function. Adapted from JaxNeRF 38 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 39 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 40 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 41 | function of lr_delay_mult, such that the initial learning rate is 42 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 43 | to the normal learning rate when steps>lr_delay_steps. 44 | :param conf: config subtree 'lr' or similar 45 | :param max_steps: int, the number of steps during optimization. 46 | :return HoF which takes step as input 47 | """ 48 | 49 | def helper(step): 50 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 51 | # Disable this parameter 52 | return 0.0 53 | if lr_delay_steps > 0: 54 | # A kind of reverse cosine decay. 
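            # delay_rate ramps smoothly from lr_delay_mult at step 0 up to 1.0 at
            # lr_delay_steps, and multiplies the log-linear interpolation between
            # lr_init and lr_final computed below.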
55 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 56 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 57 | ) 58 | else: 59 | delay_rate = 1.0 60 | t = np.clip(step / max_steps, 0, 1) 61 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 62 | return delay_rate * log_lerp 63 | 64 | return helper 65 | 66 | def strip_lowerdiag(L): 67 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 68 | 69 | uncertainty[:, 0] = L[:, 0, 0] 70 | uncertainty[:, 1] = L[:, 0, 1] 71 | uncertainty[:, 2] = L[:, 0, 2] 72 | uncertainty[:, 3] = L[:, 1, 1] 73 | uncertainty[:, 4] = L[:, 1, 2] 74 | uncertainty[:, 5] = L[:, 2, 2] 75 | return uncertainty 76 | 77 | def strip_symmetric(sym): 78 | return strip_lowerdiag(sym) 79 | 80 | def build_rotation(r): 81 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 82 | 83 | q = r / norm[:, None] 84 | 85 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 86 | 87 | r = q[:, 0] 88 | x = q[:, 1] 89 | y = q[:, 2] 90 | z = q[:, 3] 91 | 92 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 93 | R[:, 0, 1] = 2 * (x*y - r*z) 94 | R[:, 0, 2] = 2 * (x*z + r*y) 95 | R[:, 1, 0] = 2 * (x*y + r*z) 96 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 97 | R[:, 1, 2] = 2 * (y*z - r*x) 98 | R[:, 2, 0] = 2 * (x*z - r*y) 99 | R[:, 2, 1] = 2 * (y*z + r*x) 100 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 101 | return R 102 | 103 | def build_scaling_rotation(s, r): 104 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 105 | R = build_rotation(r) 106 | 107 | L[:,0,0] = s[:,0] 108 | L[:,1,1] = s[:,1] 109 | L[:,2,2] = s[:,2] 110 | 111 | L = R @ L 112 | return L 113 | 114 | def safe_state(silent): 115 | old_f = sys.stdout 116 | class F: 117 | def __init__(self, silent): 118 | self.silent = silent 119 | 120 | def write(self, x): 121 | if not self.silent: 122 | if x.endswith("\n"): 123 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 124 | else: 125 | old_f.write(x) 126 | 127 | def flush(self): 128 | old_f.flush() 129 | 130 | sys.stdout = F(silent) 131 | 132 | random.seed(0) 133 | np.random.seed(0) 134 | torch.manual_seed(0) 135 | torch.cuda.set_device(torch.device("cuda:0")) 136 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | from utils import safe_math 17 | from icecream import ic 18 | 19 | class BasicPointCloud(NamedTuple): 20 | points : np.array 21 | colors : np.array 22 | normals : np.array 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | def getWorld2View(R, t): 34 | Rt = np.zeros((4, 4)) 35 | Rt[:3, :3] = R.transpose() 36 | Rt[:3, 3] = t 37 | Rt[3, 3] = 1.0 38 | return np.float32(Rt) 39 | 40 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 41 | Rt = np.zeros((4, 4)) 42 | Rt[:3, :3] = R.transpose() 43 | Rt[:3, 3] = t 44 | Rt[3, 3] = 1.0 45 | 46 | C2W = np.linalg.inv(Rt) 47 | cam_center = C2W[:3, 3] 48 | cam_center = (cam_center + translate) * scale 49 | C2W[:3, 3] = cam_center 50 | Rt = np.linalg.inv(C2W) 51 | return np.float32(Rt) 52 | 53 | def getProjectionMatrix(znear, zfar, fovX, fovY): 54 | tanHalfFovY = math.tan((fovY / 2)) 55 | tanHalfFovX = math.tan((fovX / 2)) 56 | 57 | top = tanHalfFovY * znear 58 | bottom = -top 59 | right = tanHalfFovX * znear 60 | left = -right 61 | 62 | P = torch.zeros(4, 4) 63 | 64 | z_sign = 1.0 65 | 66 | P[0, 0] = 2.0 * znear / (right - left) 67 | P[1, 1] = 2.0 * znear / (top - bottom) 68 | P[0, 2] = (right + left) / (right - left) 69 | P[1, 2] = (top + bottom) / (top - bottom) 70 | P[3, 2] = z_sign 71 | P[2, 2] = z_sign * zfar / (zfar - znear) 72 | P[2, 3] = -(zfar * znear) / (zfar - znear) 73 | return P 74 | 75 | def fov2focal(fov, pixels): 76 | return pixels / (2 * math.tan(fov / 2)) 77 | 78 | def focal2fov(focal, pixels): 79 | return 2*math.atan(pixels/(2*focal)) 80 | 81 | 82 | def project_points( 83 | points: torch.Tensor, transf_matrix: torch.Tensor 84 | ) -> torch.Tensor: 85 | """Projects points to NDC with a given P@V matrix.""" 86 | 87 | n_points, _ = points.shape 88 | ones = torch.ones(n_points, 1, dtype=points.dtype, device=points.device) 89 | points_hom = torch.cat([points, ones], dim=1) 90 | points_out = (transf_matrix @ points_hom[..., None]).squeeze(-1) 91 | 92 | denom = points_out[..., 3:] 93 | return safe_math.safe_div(points_out[..., :3], denom).squeeze(dim=0) 94 | 95 | 96 | def visible_depth_from_camspace( 97 | cameraspace_points: torch.Tensor, 98 | ) -> torch.Tensor: 99 | """Returns a bool tensor that indicates if a point is visible.""" 100 | z = cameraspace_points[:, 2] 101 | return z > 0.2 102 | 103 | 104 | def in_screen_from_ndc(ndc_points: torch.Tensor) -> torch.Tensor: 105 | """Returns a bool tensor that indicates if a point is in screen.""" 106 | x, y = ndc_points[:, 0], ndc_points[:, 1] 107 | return torch.logical_and( 108 | torch.logical_and(x > -1.3, x < 1.3), torch.logical_and(y > -1.3, y < 1.3) 109 | ) 110 | 111 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /utils/safe_math.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import torch 16 | from torch.autograd import Function 17 | from icecream import ic 18 | #@title Math util 19 | tiny_val = torch.finfo(torch.float32).tiny 20 | min_val = torch.finfo(torch.float32).min 21 | max_val = torch.finfo(torch.float32).max 22 | 23 | def remove_zero(x): 24 | """Shifts `x` away from 0.""" 25 | return torch.where(torch.abs(x) < tiny_val, tiny_val, x) 26 | 27 | class SafeDiv(Function): 28 | @staticmethod 29 | def forward(n, d): 30 | r = torch.clip(n / remove_zero(d), min_val, max_val) 31 | return torch.where(torch.abs(d) < tiny_val, 0, r) 32 | 33 | @staticmethod 34 | def setup_context(ctx, inputs, outputs): 35 | n, d = inputs 36 | ctx.save_for_backward(n, d, outputs) 37 | 38 | @staticmethod 39 | def backward(ctx, g): 40 | n, d, r = ctx.saved_tensors 41 | dn = torch.clip(g / remove_zero(d), min_val, max_val) 42 | dd = torch.clip(-g * r / remove_zero(d), min_val, max_val) 43 | return dn, dd 44 | 45 | class SafeSqrt(Function): 46 | @staticmethod 47 | def forward(tensor): 48 | mask = torch.abs(tensor) < tiny_val 49 | val = tensor.sqrt() 50 | return torch.where(mask, 0, val) 51 | 52 | @staticmethod 53 | def setup_context(ctx, inputs, outputs): 54 | tensor, = inputs 55 | ctx.save_for_backward(tensor, outputs) 56 | 57 | @staticmethod 58 | def backward(ctx, grad_output): 59 | tensor, val = ctx.saved_tensors 60 | mask = torch.abs(tensor) < tiny_val 61 | rcp = safe_div(1, val) 62 | return torch.where( 63 | mask, max_val * grad_output, 0.5 * rcp * grad_output) 64 | 65 | safe_sqrt = SafeSqrt.apply 66 | safe_div = SafeDiv.apply 67 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def RGB2SH(rgb): 58 | return (rgb - 0.5) / C0 59 | 60 | def SH2RGB(sh): 61 | return sh * C0 + 0.5 62 | 63 | @torch.jit.script 64 | def eval_sh(deg: int, sh, dirs): 65 | """ 66 | Evaluate spherical harmonics at unit directions 67 | using hardcoded SH polynomials. 68 | Works with torch/np/jnp. 69 | ... Can be 0 or more batch dimensions. 70 | Args: 71 | deg: int SH deg. Currently, 0-3 supported 72 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 73 | dirs: jnp.ndarray unit directions [..., 3] 74 | Returns: 75 | [..., C] 76 | """ 77 | assert deg <= 4 and deg >= 0 78 | coeff = (deg + 1) ** 2 79 | assert sh.shape[-1] >= coeff 80 | 81 | result = 0.28209479177387814 * sh[..., 0] 82 | if deg > 0: 83 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 84 | result = (result - 85 | 0.4886025119029199 * y * sh[..., 1] + 86 | 0.4886025119029199 * z * sh[..., 2] - 87 | 0.4886025119029199 * x * sh[..., 3]) 88 | 89 | if deg > 1: 90 | xx, yy, zz = x * x, y * y, z * z 91 | xy, yz, xz = x * y, y * z, x * z 92 | result = (result + 93 | 1.0925484305920792 * xy * sh[..., 4] + 94 | -1.0925484305920792 * yz * sh[..., 5] + 95 | 0.31539156525252005 * (2.0 * zz - xx - yy) * sh[..., 6] + 96 | -1.0925484305920792 * xz * sh[..., 7] + 97 | 0.5462742152960396 * (xx - yy) * sh[..., 8]) 98 | 99 | if deg > 2: 100 | result = (result + 101 | -0.5900435899266435 * y * (3 * xx - yy) * sh[..., 9] + 102 | 2.890611442640554 * xy * z * sh[..., 10] + 103 | -0.4570457994644658 * y * (4 * zz - xx - yy)* sh[..., 11] + 104 | 0.3731763325901154 * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 105 | -0.4570457994644658 * x * (4 * zz - xx - yy) * sh[..., 13] + 106 | 1.445305721320277 * z * (xx - yy) * sh[..., 14] + 107 | -0.5900435899266435 * x * (xx - 3 * yy) * sh[..., 15]) 108 | 109 | if deg > 3: 110 | result = (result + 2.5033429417967046 * xy * (xx - yy) * sh[..., 16] + 111 | -1.7701307697799304 * yz * (3 * xx - yy) * sh[..., 17] + 112 | 0.9461746957575601 * xy * (7 * zz - 1) * sh[..., 18] + 113 | -0.6690465435572892 * yz * (7 * zz - 3) * sh[..., 19] + 114 | 0.10578554691520431 * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 115 | -0.6690465435572892 * xz * (7 * zz - 3) * sh[..., 21] + 116 | 0.47308734787878004 * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 117 | -1.7701307697799304 * xz * (xx - 3 * 
yy) * sh[..., 23] + 118 | 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 119 | return result 120 | -------------------------------------------------------------------------------- /utils/stepfun.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tools for manipulating step functions (piecewise-constant 1D functions). 17 | 18 | We have a shared naming and dimension convention for these functions. 19 | All input/output step functions are assumed to be aligned along the last axis. 20 | `t` always indicates the x coordinates of the *endpoints* of a step function. 21 | `y` indicates unconstrained values for the *bins* of a step function 22 | `w` indicates bin weights that sum to <= 1. `p` indicates non-negative bin 23 | values that *integrate* to <= 1. 24 | """ 25 | 26 | from utils import math 27 | import jax 28 | jax.config.update('jax_platform_name', 'cpu') 29 | import jax.numpy as jnp 30 | import numpy as np 31 | 32 | 33 | def device_is_tpu(): 34 | return False 35 | 36 | 37 | def assert_valid_stepfun(t, y): 38 | """Assert that step function (t, y) has a valid shape.""" 39 | if t.shape[-1] != y.shape[-1] + 1: 40 | raise ValueError( 41 | f'Invalid shapes ({t.shape}, {y.shape}) for a step function.' 42 | ) 43 | 44 | def query(tq, t, y, left=None, right=None): 45 | """Query step function (t, y) at locations tq. Edges repeat by default.""" 46 | assert_valid_stepfun(t, y) 47 | # Query the step function to recover the interval value. 48 | (i0, i1), ((yq, _),) = math.sorted_lookup(tq, t, (y,), device_is_tpu()) 49 | # Apply boundary conditions. 50 | left = y[Ellipsis, :1] if left is None else left 51 | right = y[Ellipsis, -1:] if right is None else right 52 | yq = math.select([(i1 == 0, left), (i0 == y.shape[-1], right)], yq) 53 | return yq 54 | 55 | 56 | def weight_to_pdf(t, w): 57 | """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.""" 58 | assert_valid_stepfun(t, w) 59 | td = jnp.diff(t) 60 | return jnp.where(td < np.finfo(np.float32).tiny, 0, math.safe_div(w, td)) 61 | 62 | 63 | def pdf_to_weight(t, p): 64 | """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.""" 65 | assert_valid_stepfun(t, p) 66 | return p * jnp.diff(t) 67 | 68 | 69 | def integrate_weights(w): 70 | """Compute the cumulative sum of w, assuming all weight vectors sum to 1. 71 | 72 | The output's size on the last dimension is one greater than that of the input, 73 | because we're computing the integral corresponding to the endpoints of a step 74 | function, not the integral of the interior/bin values. 75 | 76 | Args: 77 | w: Tensor, which will be integrated along the last axis. This is assumed to 78 | sum to 1 along the last axis, and this function will (silently) break if 79 | that is not the case. 
80 | 81 | Returns: 82 | cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 83 | """ 84 | cw = jnp.minimum(1, jnp.cumsum(w[Ellipsis, :-1], axis=-1)) 85 | shape = cw.shape[:-1] + (1,) 86 | # Ensure that the CDF starts with exactly 0 and ends with exactly 1. 87 | cw0 = jnp.concatenate([jnp.zeros(shape), cw, jnp.ones(shape)], axis=-1) 88 | return cw0 89 | 90 | 91 | def invert_cdf(u, t, w_logits): 92 | """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" 93 | assert_valid_stepfun(t, w_logits) 94 | # Compute the PDF and CDF for each weight vector. 95 | w = jax.nn.softmax(w_logits, axis=-1) 96 | cw = integrate_weights(w) 97 | # Interpolate into the inverse CDF. 98 | t_new = math.sorted_interp(u, cw, t, device_is_tpu()) 99 | return t_new 100 | 101 | 102 | def sample( 103 | rng, 104 | t, 105 | w_logits, 106 | num_samples, 107 | single_jitter=False, 108 | deterministic_center=False, 109 | eps=jnp.finfo(jnp.float32).eps, 110 | ): 111 | """Piecewise-Constant PDF sampling from a step function. 112 | 113 | Args: 114 | rng: random number generator (or None for `linspace` sampling). 115 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 116 | w_logits: [..., num_bins], logits corresponding to bin weights 117 | num_samples: int, the number of samples. 118 | single_jitter: bool, if True, jitter every sample along each ray by the same 119 | amount in the inverse CDF. Otherwise, jitter each sample independently. 120 | deterministic_center: bool, if False, when `rng` is None return samples that 121 | linspace the entire PDF. If True, skip the front and back of the linspace 122 | so that the centers of each PDF interval are returned. 123 | eps: float, something like numerical epsilon. 124 | 125 | Returns: 126 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 127 | """ 128 | assert_valid_stepfun(t, w_logits) 129 | 130 | # Draw uniform samples. 131 | if rng is None: 132 | # Match the behavior of jax.random.uniform() by spanning [0, 1-eps]. 133 | if deterministic_center: 134 | pad = 1 / (2 * num_samples) 135 | u = jnp.linspace(pad, 1.0 - pad - eps, num_samples) 136 | else: 137 | u = jnp.linspace(0, 1.0 - eps, num_samples) 138 | u = jnp.broadcast_to(u, t.shape[:-1] + (num_samples,)) 139 | else: 140 | # `u` is in [0, 1) --- it can be zero, but it can never be 1. 141 | u_max = eps + (1 - eps) / num_samples 142 | max_jitter = (1 - u_max) / (num_samples - 1) - eps 143 | d = 1 if single_jitter else num_samples 144 | u = jnp.linspace(0, 1 - u_max, num_samples) + jax.random.uniform( 145 | rng, t.shape[:-1] + (d,), maxval=max_jitter 146 | ) 147 | 148 | return invert_cdf(u, t, w_logits) 149 | 150 | 151 | def sample_intervals( 152 | rng, 153 | t, 154 | w_logits, 155 | num_samples, 156 | single_jitter=False, 157 | domain=(-jnp.inf, jnp.inf), 158 | ): 159 | """Sample *intervals* (rather than points) from a step function. 160 | 161 | Args: 162 | rng: random number generator (or None for `linspace` sampling). 163 | t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) 164 | w_logits: [..., num_bins], logits corresponding to bin weights 165 | num_samples: int, the number of intervals to sample. 166 | single_jitter: bool, if True, jitter every sample along each ray by the same 167 | amount in the inverse CDF. Otherwise, jitter each sample independently. 168 | domain: (minval, maxval), the range of valid values for `t`. 169 | 170 | Returns: 171 | t_samples: jnp.ndarray(float32), [batch_size, num_samples]. 
172 | """ 173 | assert_valid_stepfun(t, w_logits) 174 | if num_samples <= 1: 175 | raise ValueError(f'num_samples must be > 1, is {num_samples}.') 176 | 177 | # Sample a set of points from the step function. 178 | centers = sample( 179 | rng, t, w_logits, num_samples, single_jitter, deterministic_center=True 180 | ) 181 | 182 | # The intervals we return will span the midpoints of each adjacent sample. 183 | mid = (centers[Ellipsis, 1:] + centers[Ellipsis, :-1]) / 2 184 | 185 | # Each first/last fencepost is the reflection of the first/last midpoint 186 | # around the first/last sampled center. 187 | first = 2 * centers[Ellipsis, :1] - mid[Ellipsis, :1] 188 | last = 2 * centers[Ellipsis, -1:] - mid[Ellipsis, -1:] 189 | samples = jnp.concatenate([first, mid, last], axis=-1) 190 | 191 | # We clamp to the limits of the input domain, provided by the caller. 192 | samples = jnp.clip(samples, *domain) 193 | return samples 194 | 195 | 196 | def lossfun_distortion(t, w): 197 | """Compute iint w[i] w[j] |t[i] - t[j]| di dj.""" 198 | assert_valid_stepfun(t, w) 199 | 200 | # The loss incurred between all pairs of intervals. 201 | ut = (t[Ellipsis, 1:] + t[Ellipsis, :-1]) / 2 202 | dut = jnp.abs(ut[Ellipsis, :, None] - ut[Ellipsis, None, :]) 203 | loss_inter = jnp.sum(w * jnp.sum(w[Ellipsis, None, :] * dut, axis=-1), axis=-1) 204 | 205 | # The loss incurred within each individual interval with itself. 206 | loss_intra = jnp.sum(w**2 * jnp.diff(t), axis=-1) / 3 207 | 208 | return loss_inter + loss_intra 209 | 210 | 211 | def weighted_percentile(t, w, ps): 212 | """Compute the weighted percentiles of a step function. w's must sum to 1.""" 213 | assert_valid_stepfun(t, w) 214 | cw = integrate_weights(w) 215 | # We want to interpolate into the integrated weights according to `ps`. 216 | wprctile = jnp.vectorize(jnp.interp, signature='(n),(m),(m)->(n)')( 217 | jnp.array(ps) / 100, cw, t 218 | ) 219 | return wprctile 220 | 221 | 222 | def resample(t, tp, vp, use_avg=False): 223 | """Resample a step function defined by (tp, vp) into intervals t. 224 | 225 | Notation roughly matches jnp.interp. Resamples by summation by default. 226 | 227 | Args: 228 | t: tensor with shape (..., n+1), the endpoints to resample into. 229 | tp: tensor with shape (..., m+1), the endpoints of the step function being 230 | resampled. 231 | vp: tensor with shape (..., m), the values of the step function being 232 | resampled. 233 | use_avg: bool, if False, return the sum of the step function for each 234 | interval in `t`. If True, return the average, weighted by the width of 235 | each interval in `t`. 236 | 237 | Returns: 238 | v: tensor with shape (..., n), the values of the resampled step function. 
239 | """ 240 | assert_valid_stepfun(tp, vp) 241 | if use_avg: 242 | wp = jnp.diff(tp) 243 | v_numer = resample(t, tp, vp * wp, use_avg=False) 244 | v_denom = resample(t, tp, wp, use_avg=False) 245 | v = math.safe_div(v_numer, v_denom) 246 | return v 247 | 248 | acc = jnp.cumsum(vp, axis=-1) 249 | acc0 = jnp.concatenate([jnp.zeros(acc.shape[:-1] + (1,)), acc], axis=-1) 250 | acc0_resampled = jnp.vectorize(jnp.interp, signature='(n),(m),(m)->(n)')( 251 | t, tp, acc0 252 | ) 253 | v = jnp.diff(acc0_resampled, axis=-1) 254 | return v 255 | 256 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------