├── .clang-format ├── .clangd_template ├── .github └── workflows │ ├── building.yml │ ├── core_tests.yml │ └── cuda │ ├── Linux-env.sh │ ├── Linux.sh │ ├── Windows-env.sh │ └── Windows.sh ├── .gitignore ├── .gitmodules ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets └── test_garden.npz ├── docs └── _static │ └── imgs │ └── front-fig-stacked.jpg ├── examples ├── lidar_rendering.ipynb └── rolling_shutter.ipynb ├── gsplat ├── __init__.py ├── _helper.py ├── compression │ ├── __init__.py │ ├── png_compression.py │ └── sort.py ├── cuda │ ├── __init__.py │ ├── _backend.py │ ├── _torch_impl.py │ ├── _wrapper.py │ └── csrc │ │ ├── bindings.h │ │ ├── ext.cpp │ │ ├── helpers.cuh │ │ ├── projection.cu │ │ ├── rasterization.cu │ │ ├── relocation.cu │ │ ├── sh.cu │ │ └── utils.cuh ├── cuda_legacy │ ├── __init__.py │ ├── _backend.py │ ├── _torch_impl.py │ ├── _wrapper.py │ └── csrc │ │ ├── CMakeLists.txt │ │ ├── backward.cu │ │ ├── backward.cuh │ │ ├── bindings.cu │ │ ├── bindings.h │ │ ├── config.h │ │ ├── ext.cpp │ │ ├── forward.cu │ │ ├── forward.cuh │ │ ├── helpers.cuh │ │ └── sh.cuh ├── distributed.py ├── profile.py ├── relocation.py ├── rendering.py ├── strategy │ ├── __init__.py │ ├── base.py │ ├── default.py │ ├── mcmc.py │ └── ops.py ├── utils.py └── version.py ├── setup.py └── tests ├── _test_distributed.py ├── test_basic.py ├── test_compression.py ├── test_rasterization.py └── test_strategy.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | AlignAfterOpenBracket: BlockIndent 3 | BinPackArguments: false 4 | BinPackParameters: false 5 | IndentWidth: 4 6 | -------------------------------------------------------------------------------- /.clangd_template: -------------------------------------------------------------------------------- 1 | # To set up the clangd config. 2 | # 1. Activate an environment with cuda installed. (probably via conda, skip if using system CUDA) 3 | # 2. 
Run: 4 | # echo "# Autogenerated, see .clangd_template\!" > .clangd && sed -e "/^#/d" -e "s|YOUR_CUDA_PATH|$(dirname $(dirname $(which nvcc)))|" .clangd_template >> .clangd 5 | CompileFlags: 6 | Add: 7 | - -Xclang 8 | - -fcuda-allow-variadic-functions 9 | - --cuda-path=YOUR_CUDA_PATH 10 | Remove: 11 | - --diag_suppress=* 12 | - --generate-code=* 13 | - -gencode=* 14 | - -forward-unknown-to-host-compiler 15 | - -Xcompiler 16 | - -Xcudafe 17 | - --use_fast_math 18 | - --options-file 19 | - --compiler-options 20 | - --expt-relaxed-constexpr 21 | -------------------------------------------------------------------------------- /.github/workflows/building.yml: -------------------------------------------------------------------------------- 1 | name: Building Wheels 2 | 3 | on: [workflow_dispatch] 4 | 5 | jobs: 6 | 7 | wheel: 8 | runs-on: ${{ matrix.os }} 9 | environment: production 10 | 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | # os: [windows-2019] 15 | # python-version: ['3.7'] 16 | # torch-version: [1.11.0] 17 | # cuda-version: ['cu113'] 18 | os: [ubuntu-20.04, windows-2019] 19 | # support version based on: https://download.pytorch.org/whl/torch/ 20 | python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] 21 | torch-version: [1.11.0, 1.12.0, 1.13.0, 2.0.0] 22 | cuda-version: ['cu113', 'cu115', 'cu116', 'cu117', 'cu118'] 23 | exclude: 24 | - torch-version: 1.11.0 25 | python-version: '3.11' 26 | - torch-version: 1.11.0 27 | cuda-version: 'cu116' 28 | - torch-version: 1.11.0 29 | cuda-version: 'cu117' 30 | - torch-version: 1.11.0 31 | cuda-version: 'cu118' 32 | 33 | - torch-version: 1.12.0 34 | python-version: '3.11' 35 | - torch-version: 1.12.0 36 | cuda-version: 'cu115' 37 | - torch-version: 1.12.0 38 | cuda-version: 'cu117' 39 | - torch-version: 1.12.0 40 | cuda-version: 'cu118' 41 | 42 | - torch-version: 1.13.0 43 | cuda-version: 'cu102' 44 | - torch-version: 1.13.0 45 | cuda-version: 'cu113' 46 | - torch-version: 1.13.0 47 | cuda-version: 'cu115' 48 | - 
torch-version: 1.13.0 49 | cuda-version: 'cu118' 50 | 51 | - torch-version: 2.0.0 52 | python-version: '3.7' 53 | - torch-version: 2.0.0 54 | cuda-version: 'cu102' 55 | - torch-version: 2.0.0 56 | cuda-version: 'cu113' 57 | - torch-version: 2.0.0 58 | cuda-version: 'cu115' 59 | - torch-version: 2.0.0 60 | cuda-version: 'cu116' 61 | 62 | - os: windows-2019 63 | cuda-version: 'cu102' 64 | - os: windows-2019 65 | torch-version: 1.13.0 66 | python-version: '3.11' 67 | 68 | # - os: windows-2019 69 | # torch-version: 1.13.0 70 | # cuda-version: 'cu117' 71 | # python-version: '3.9' 72 | 73 | 74 | 75 | steps: 76 | - uses: actions/checkout@v3 77 | with: 78 | submodules: 'recursive' 79 | 80 | - name: Set up Python ${{ matrix.python-version }} 81 | uses: actions/setup-python@v4 82 | with: 83 | python-version: ${{ matrix.python-version }} 84 | 85 | - name: Upgrade pip 86 | run: | 87 | pip install --upgrade setuptools 88 | pip install ninja 89 | 90 | - name: Free up disk space 91 | if: ${{ runner.os == 'Linux' }} 92 | run: | 93 | sudo rm -rf /usr/share/dotnet 94 | 95 | - name: Install CUDA ${{ matrix.cuda-version }} 96 | if: ${{ matrix.cuda-version != 'cpu' }} 97 | run: | 98 | bash .github/workflows/cuda/${{ runner.os }}.sh ${{ matrix.cuda-version }} 99 | 100 | - name: Install PyTorch ${{ matrix.torch-version }}+${{ matrix.cuda-version }} 101 | run: | 102 | pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ matrix.cuda-version }} 103 | python -c "import torch; print('PyTorch:', torch.__version__)" 104 | python -c "import torch; print('CUDA:', torch.version.cuda)" 105 | python -c "import torch; print('CUDA Available:', torch.cuda.is_available())" 106 | 107 | - name: Patch PyTorch static constexpr on Windows 108 | if: ${{ runner.os == 'Windows' }} 109 | run: | 110 | Torch_DIR=`python -c 'import os; import torch; print(os.path.dirname(torch.__file__))'` 111 | sed -i '31,38c\ 112 | TORCH_API void lazy_init_num_threads();' 
${Torch_DIR}/include/ATen/Parallel.h 113 | shell: bash 114 | 115 | - name: Set version 116 | if: ${{ runner.os != 'macOS' }} 117 | run: | 118 | VERSION=`sed -n 's/^__version__ = "\(.*\)"/\1/p' gsplat/version.py` 119 | TORCH_VERSION=`echo "pt${{ matrix.torch-version }}" | sed "s/..$//" | sed "s/\.//g"` 120 | CUDA_VERSION=`echo ${{ matrix.cuda-version }}` 121 | echo "New version name: $VERSION+$TORCH_VERSION$CUDA_VERSION" 122 | sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" gsplat/version.py 123 | shell: 124 | bash 125 | 126 | - name: Install main package for CPU 127 | if: ${{ matrix.cuda-version == 'cpu' }} 128 | run: | 129 | BUILD_NO_CUDA=1 pip install . 130 | 131 | - name: Build wheel 132 | run: | 133 | pip install wheel 134 | source .github/workflows/cuda/${{ runner.os }}-env.sh ${{ matrix.cuda-version }} 135 | python setup.py bdist_wheel --dist-dir=dist 136 | shell: bash # `source` does not exist in windows powershell 137 | 138 | - name: Test wheel 139 | run: | 140 | cd dist 141 | ls -lah 142 | pip install *.whl 143 | python -c "import gsplat; print('gsplat:', gsplat.__version__)" 144 | cd .. 145 | shell: bash # `ls -lah` does not exist in windows powershell 146 | -------------------------------------------------------------------------------- /.github/workflows/core_tests.yml: -------------------------------------------------------------------------------- 1 | name: Core Tests. 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | with: 19 | submodules: 'recursive' 20 | 21 | - name: Set up Python 3.8.12 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.8.12" 25 | - name: Install dependencies 26 | run: | 27 | pip install pytest 28 | pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cpu 29 | BUILD_NO_CUDA=1 pip install . 
30 | - name: Run Tests. 31 | run: pytest tests/ 32 | -------------------------------------------------------------------------------- /.github/workflows/cuda/Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Took from https://github.com/pyg-team/pyg-lib/ 4 | 5 | case ${1} in 6 | cu118) 7 | export CUDA_HOME=/usr/local/cuda-11.8 8 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 9 | export PATH=${CUDA_HOME}/bin:${PATH} 10 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 11 | ;; 12 | cu117) 13 | export CUDA_HOME=/usr/local/cuda-11.7 14 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 15 | export PATH=${CUDA_HOME}/bin:${PATH} 16 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 17 | ;; 18 | cu116) 19 | export CUDA_HOME=/usr/local/cuda-11.6 20 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 21 | export PATH=${CUDA_HOME}/bin:${PATH} 22 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 23 | ;; 24 | cu115) 25 | export CUDA_HOME=/usr/local/cuda-11.5 26 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 27 | export PATH=${CUDA_HOME}/bin:${PATH} 28 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 29 | ;; 30 | cu113) 31 | export CUDA_HOME=/usr/local/cuda-11.3 32 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 33 | export PATH=${CUDA_HOME}/bin:${PATH} 34 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 35 | ;; 36 | cu102) 37 | export CUDA_HOME=/usr/local/cuda-10.2 38 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 39 | export PATH=${CUDA_HOME}/bin:${PATH} 40 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" 41 | ;; 42 | *) 43 | ;; 44 | esac -------------------------------------------------------------------------------- /.github/workflows/cuda/Linux.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | # Took from https://github.com/pyg-team/pyg-lib/ 4 | 5 | OS=ubuntu2004 6 | 7 | case ${1} in 8 | cu118) 9 | CUDA=11.8 10 | APT_KEY=${OS}-${CUDA/./-}-local 11 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.0-520.61.05-1_amd64.deb 12 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.0/local_installers 13 | ;; 14 | cu117) 15 | CUDA=11.7 16 | APT_KEY=${OS}-${CUDA/./-}-local 17 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.1-515.65.01-1_amd64.deb 18 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.1/local_installers 19 | ;; 20 | cu116) 21 | CUDA=11.6 22 | APT_KEY=${OS}-${CUDA/./-}-local 23 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.2-510.47.03-1_amd64.deb 24 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.2/local_installers 25 | ;; 26 | cu115) 27 | CUDA=11.5 28 | APT_KEY=${OS}-${CUDA/./-}-local 29 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.2-495.29.05-1_amd64.deb 30 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.2/local_installers 31 | ;; 32 | cu113) 33 | CUDA=11.3 34 | APT_KEY=${OS}-${CUDA/./-}-local 35 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.0-465.19.01-1_amd64.deb 36 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.0/local_installers 37 | ;; 38 | cu102) 39 | CUDA=10.2 40 | APT_KEY=${CUDA/./-}-local-${CUDA}.89-440.33.01 41 | FILENAME=cuda-repo-${OS}-${APT_KEY}_1.0-1_amd64.deb 42 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}/Prod/local_installers 43 | ;; 44 | *) 45 | echo "Unrecognized CUDA_VERSION=${1}" 46 | exit 1 47 | ;; 48 | esac 49 | 50 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 51 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 52 | wget -nv ${URL}/${FILENAME} 53 | sudo dpkg -i ${FILENAME} 54 | 55 | if [ "${1}" = "cu117" ] || [ "${1}" = "cu118" ]; then 56 | sudo cp /var/cuda-repo-${APT_KEY}/cuda-*-keyring.gpg /usr/share/keyrings/ 57 | else 58 | sudo apt-key add 
/var/cuda-repo-${APT_KEY}/7fa2af80.pub 59 | fi 60 | 61 | sudo apt-get update 62 | sudo apt-get -y install cuda 63 | 64 | rm -f ${FILENAME} -------------------------------------------------------------------------------- /.github/workflows/cuda/Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Took from https://github.com/pyg-team/pyg-lib/ 4 | 5 | case ${1} in 6 | cu118) 7 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.8 8 | PATH=${CUDA_HOME}/bin:$PATH 9 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 10 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 11 | ;; 12 | cu117) 13 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.7 14 | PATH=${CUDA_HOME}/bin:$PATH 15 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 16 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 17 | ;; 18 | cu116) 19 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.6 20 | PATH=${CUDA_HOME}/bin:$PATH 21 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 22 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 23 | ;; 24 | cu115) 25 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.5 26 | PATH=${CUDA_HOME}/bin:$PATH 27 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 28 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 29 | ;; 30 | cu113) 31 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 32 | PATH=${CUDA_HOME}/bin:$PATH 33 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 34 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 35 | ;; 36 | *) 37 | ;; 38 | esac -------------------------------------------------------------------------------- /.github/workflows/cuda/Windows.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Took from https://github.com/pyg-team/pyg-lib/ 4 | 5 | # Install NVIDIA drivers, see: 6 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 7 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 8 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 9 | 10 | case ${1} in 11 | cu118) 12 | CUDA_SHORT=11.8 13 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 14 | CUDA_FILE=cuda_${CUDA_SHORT}.0_522.06_windows.exe 15 | ;; 16 | cu117) 17 | CUDA_SHORT=11.7 18 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers 19 | CUDA_FILE=cuda_${CUDA_SHORT}.1_516.94_windows.exe 20 | ;; 21 | cu116) 22 | CUDA_SHORT=11.3 23 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 24 | CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 25 | ;; 26 | cu115) 27 | CUDA_SHORT=11.3 28 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 29 | CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 30 | ;; 31 | cu113) 32 | CUDA_SHORT=11.3 33 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 34 | CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 35 | ;; 36 | *) 37 | echo "Unrecognized CUDA_VERSION=${1}" 38 | exit 1 39 | ;; 40 | esac 41 | 42 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 43 | echo "" 44 | echo "Installing from ${CUDA_FILE}..." 
45 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 46 | echo "Done!" 47 | rm -f "${CUDA_FILE}" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .clangd 2 | compile_commands.json 3 | 4 | # Visual Studio Code configs. 5 | .vscode/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | # lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | .DS_Store 113 | 114 | # Direnv config. 115 | .envrc 116 | 117 | # line_profiler 118 | *.lprof 119 | 120 | *build 121 | compile_commands.json 122 | *.dump 123 | 124 | data 125 | results -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "gsplat/cuda/csrc/third_party/glm"] 2 | path = gsplat/cuda/csrc/third_party/glm 3 | url = https://github.com/g-truc/glm.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. 
Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include gsplat/cuda/csrc * 2 | recursive-include gsplat/cuda_legacy/csrc * 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 | 6 | 7 | 8 |

9 | 10 |
11 |

SplatAD

12 |

13 | Real-Time Lidar and Camera Rendering with 3D Gaussian Splatting for Autonomous Driving 14 |

15 |
16 |
17 | 18 |
19 | 20 | 21 | tyro logo 22 | 23 |
24 | 25 | [Project page](https://research.zenseact.com/publications/splatad/) 26 | 27 |
28 | 29 | 30 | # About 31 | This is the official repository for the CVPR 2025 paper [_SplatAD: Real-Time Lidar and Camera Rendering with 3D Gaussian Splatting for Autonomous Driving_](https://arxiv.org/abs/2411.16816). The code in this repository builds upon the open-source library [gsplat](https://github.com/nerfstudio-project/gsplat), with modifications and extensions designed for autonomous driving data. 32 | 33 | While the code contians all components needed to efficiently render camera and lidar data, the SplatAD-model itself, including dataloading, decoders, etc., will be released through [neurad-studio](https://github.com/georghess/neurad-studio). 34 | 35 | **We welcome all contributions!** 36 | 37 | # Key Features 38 | - Efficient lidar rendering 39 | - Projection to spherical coordinates 40 | - Depth and feature rasterization for a non-linear grid of points 41 | - Rolling shutter compensation for camera and lidar 42 | 43 | 44 | # Installation 45 | Our code introduce no additional dependencies. We thus refer to the original documentation from gsplat for both [installation](https://github.com/nerfstudio-project/gsplat#installation) and [development setup](https://github.com/nerfstudio-project/gsplat/blob/main/docs/DEV.md). 46 | 47 | # Usage 48 | See [`rasterization`](gsplat/rendering.py#L22) and [`lidar_rasterization`]((gsplat/rendering.py#L443)) for entry points to camera and lidar rasterization. 49 | Additionally, we provide example notebooks under [examples](examples) that demonstrate lidar rendering and rolling shutter compensation. 50 | For further examples, check out the [test files](tests). 
51 | 52 | 53 | # Built On 54 | - [gsplat](https://github.com/nerfstudio-project/gsplat) - Collaboration friendly library for CUDA accelerated rasterization of Gaussians with python bindings 55 | - [3dgs-deblur](https://github.com/SpectacularAI/3dgs-deblur) - Inspiration for the rolling shutter compensation 56 | 57 | # Citation 58 | 59 | You can find our paper on [arXiv](https://arxiv.org/abs/2411.16816). 60 | 61 | If you use this code or find our paper useful, please consider citing: 62 | 63 | ```bibtex 64 | @article{hess2024splatad, 65 | title={SplatAD: Real-Time Lidar and Camera Rendering with 3D Gaussian Splatting for Autonomous Driving}, 66 | author={Hess, Georg and Lindstr{\"o}m, Carl and Fatemi, Maryam and Petersson, Christoffer and Svensson, Lennart}, 67 | journal={arXiv preprint arXiv:2411.16816}, 68 | year={2024} 69 | } 70 | ``` 71 | 72 | # Contributors 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | \+ [gsplat contributors](https://github.com/nerfstudio-project/gsplat/graphs/contributors) -------------------------------------------------------------------------------- /assets/test_garden.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlinds/splatad/5d3c9bd4b856f142c707a9c0b78161f555bacb10/assets/test_garden.npz -------------------------------------------------------------------------------- /docs/_static/imgs/front-fig-stacked.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlinds/splatad/5d3c9bd4b856f142c707a9c0b78161f555bacb10/docs/_static/imgs/front-fig-stacked.jpg -------------------------------------------------------------------------------- /gsplat/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from .cuda._torch_impl import accumulate 4 | from .cuda._wrapper import ( 5 | fully_fused_projection, 6 | isect_offset_encode, 7 | 
isect_tiles, 8 | persp_proj, 9 | quat_scale_to_covar_preci, 10 | rasterize_to_indices_in_range, 11 | rasterize_to_pixels, 12 | rasterize_to_points, 13 | spherical_harmonics, 14 | world_to_cam, 15 | map_points_to_lidar_tiles, 16 | points_mapping_offset_encode, 17 | populate_image_from_points, 18 | ) 19 | from .rendering import ( 20 | rasterization, 21 | rasterization_inria_wrapper, 22 | rasterization_legacy_wrapper, 23 | ) 24 | from .version import __version__ 25 | 26 | 27 | def rasterize_gaussians(*args, **kwargs): 28 | from .cuda_legacy._wrapper import rasterize_gaussians 29 | 30 | warnings.warn( 31 | "'rasterize_gaussians is deprecated and will be removed in a future release. " 32 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 33 | DeprecationWarning, 34 | ) 35 | return rasterize_gaussians(*args, **kwargs) 36 | 37 | 38 | def project_gaussians(*args, **kwargs): 39 | from .cuda_legacy._wrapper import project_gaussians 40 | 41 | warnings.warn( 42 | "'project_gaussians is deprecated and will be removed in a future release. " 43 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 44 | DeprecationWarning, 45 | ) 46 | return project_gaussians(*args, **kwargs) 47 | 48 | 49 | def map_gaussian_to_intersects(*args, **kwargs): 50 | from .cuda_legacy._wrapper import map_gaussian_to_intersects 51 | 52 | warnings.warn( 53 | "'map_gaussian_to_intersects is deprecated and will be removed in a future release. " 54 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 55 | DeprecationWarning, 56 | ) 57 | return map_gaussian_to_intersects(*args, **kwargs) 58 | 59 | 60 | def bin_and_sort_gaussians(*args, **kwargs): 61 | from .cuda_legacy._wrapper import bin_and_sort_gaussians 62 | 63 | warnings.warn( 64 | "'bin_and_sort_gaussians is deprecated and will be removed in a future release. 
" 65 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 66 | DeprecationWarning, 67 | ) 68 | return bin_and_sort_gaussians(*args, **kwargs) 69 | 70 | 71 | def compute_cumulative_intersects(*args, **kwargs): 72 | from .cuda_legacy._wrapper import compute_cumulative_intersects 73 | 74 | warnings.warn( 75 | "'compute_cumulative_intersects is deprecated and will be removed in a future release. " 76 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 77 | DeprecationWarning, 78 | ) 79 | return compute_cumulative_intersects(*args, **kwargs) 80 | 81 | 82 | def compute_cov2d_bounds(*args, **kwargs): 83 | from .cuda_legacy._wrapper import compute_cov2d_bounds 84 | 85 | warnings.warn( 86 | "'compute_cov2d_bounds is deprecated and will be removed in a future release. " 87 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 88 | DeprecationWarning, 89 | ) 90 | return compute_cov2d_bounds(*args, **kwargs) 91 | 92 | 93 | def get_tile_bin_edges(*args, **kwargs): 94 | from .cuda_legacy._wrapper import get_tile_bin_edges 95 | 96 | warnings.warn( 97 | "'get_tile_bin_edges is deprecated and will be removed in a future release. 
" 98 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.", 99 | DeprecationWarning, 100 | ) 101 | return get_tile_bin_edges(*args, **kwargs) 102 | 103 | 104 | all = [ 105 | "rasterization", 106 | "rasterization_legacy_wrapper", 107 | "rasterization_inria_wrapper", 108 | "spherical_harmonics", 109 | "isect_offset_encode", 110 | "isect_tiles", 111 | "isect_lidar_tiles", 112 | "map_points_to_lidar_tiles", 113 | "points_mapping_offset_encode", 114 | "populate_image_from_points", 115 | "persp_proj", 116 | "fully_fused_projection", 117 | "quat_scale_to_covar_preci", 118 | "rasterize_to_pixels", 119 | "rasterize_to_points", 120 | "world_to_cam", 121 | "accumulate", 122 | "rasterize_to_indices_in_range", 123 | "rasterize_to_indices_in_range_lidar", 124 | "__version__", 125 | # deprecated 126 | "rasterize_gaussians", 127 | "project_gaussians", 128 | "map_gaussian_to_intersects", 129 | "bin_and_sort_gaussians", 130 | "compute_cumulative_intersects", 131 | "compute_cov2d_bounds", 132 | "get_tile_bin_edges", 133 | ] 134 | -------------------------------------------------------------------------------- /gsplat/_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def load_test_data( 10 | data_path: Optional[str] = None, 11 | device="cuda", 12 | scene_crop: Tuple[float, float, float, float, float, float] = (-2, -2, -2, 2, 2, 2), 13 | scene_grid: int = 1, 14 | ): 15 | """Load the test data.""" 16 | assert scene_grid % 2 == 1, "scene_grid must be odd" 17 | 18 | if data_path is None: 19 | data_path = os.path.join(os.path.dirname(__file__), "../assets/test_garden.npz") 20 | data = np.load(data_path) 21 | height, width = data["height"].item(), data["width"].item() 22 | viewmats = torch.from_numpy(data["viewmats"]).float().to(device) 23 | Ks = 
torch.from_numpy(data["Ks"]).float().to(device) 24 | means = torch.from_numpy(data["means3d"]).float().to(device) 25 | colors = torch.from_numpy(data["colors"] / 255.0).float().to(device) 26 | C = len(viewmats) 27 | 28 | # crop 29 | aabb = torch.tensor(scene_crop, device=device) 30 | edges = aabb[3:] - aabb[:3] 31 | sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1) 32 | sel = torch.where(sel)[0] 33 | means, colors = means[sel], colors[sel] 34 | 35 | # repeat the scene into a grid (to mimic a large-scale setting) 36 | repeats = scene_grid 37 | gridx, gridy = torch.meshgrid( 38 | [ 39 | torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), 40 | torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), 41 | ], 42 | indexing="ij", 43 | ) 44 | grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(-1, 3) 45 | means = means[None, :, :] + grid[:, None, :] * edges[None, None, :] 46 | means = means.reshape(-1, 3) 47 | colors = colors.repeat(repeats**2, 1) 48 | 49 | # create gaussian attributes 50 | N = len(means) 51 | scales = torch.rand((N, 3), device=device) * 0.02 52 | quats = F.normalize(torch.randn((N, 4), device=device), dim=-1) 53 | opacities = torch.rand((N,), device=device) 54 | 55 | return means, quats, scales, opacities, colors, viewmats, Ks, width, height 56 | -------------------------------------------------------------------------------- /gsplat/compression/__init__.py: -------------------------------------------------------------------------------- 1 | from .png_compression import PngCompression 2 | -------------------------------------------------------------------------------- /gsplat/compression/png_compression.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | from typing import Any, Callable, Dict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import 
Tensor 10 | 11 | from gsplat.compression.sort import sort_splats 12 | from gsplat.utils import inverse_log_transform, log_transform 13 | 14 | 15 | @dataclass 16 | class PngCompression: 17 | """Uses quantization and sorting to compress splats into PNG files and uses 18 | K-means clustering to compress the spherical harmonic coefficents. 19 | 20 | .. warning:: 21 | This class requires the `imageio `_, 22 | `plas `_ 23 | and `torchpq `_ packages to be installed. 24 | 25 | .. warning:: 26 | This class might throw away a few lowest opacities splats if the number of 27 | splats is not a square number. 28 | 29 | .. note:: 30 | The splats parameters are expected to be pre-activation values. It expects 31 | the following fields in the splats dictionary: "means", "scales", "quats", 32 | "opacities", "sh0", "shN". More fields can be added to the dictionary, but 33 | they will only be compressed using NPZ compression. 34 | 35 | References: 36 | - `Compact 3D Scene Representation via Self-Organizing Gaussian Grids `_ 37 | - `Making Gaussian Splats more smaller `_ 38 | 39 | Args: 40 | use_sort (bool, optional): Whether to sort splats before compression. Defaults to True. 41 | verbose (bool, optional): Whether to print verbose information. Default to True. 
42 | """ 43 | 44 | use_sort: bool = True 45 | verbose: bool = True 46 | 47 | def _get_compress_fn(self, param_name: str) -> Callable: 48 | compress_fn_map = { 49 | "means": _compress_png_16bit, 50 | "scales": _compress_png, 51 | "quats": _compress_png, 52 | "opacities": _compress_png, 53 | "sh0": _compress_png, 54 | "shN": _compress_kmeans, 55 | } 56 | if param_name in compress_fn_map: 57 | return compress_fn_map[param_name] 58 | else: 59 | return _compress_npz 60 | 61 | def _get_decompress_fn(self, param_name: str) -> Callable: 62 | decompress_fn_map = { 63 | "means": _decompress_png_16bit, 64 | "scales": _decompress_png, 65 | "quats": _decompress_png, 66 | "opacities": _decompress_png, 67 | "sh0": _decompress_png, 68 | "shN": _decompress_kmeans, 69 | } 70 | if param_name in decompress_fn_map: 71 | return decompress_fn_map[param_name] 72 | else: 73 | return _decompress_npz 74 | 75 | def compress(self, compress_dir: str, splats: Dict[str, Tensor]) -> None: 76 | """Run compression 77 | 78 | Args: 79 | compress_dir (str): directory to save compressed files 80 | splats (Dict[str, Tensor]): Gaussian splats to compress 81 | """ 82 | 83 | # Param-specific preprocessing 84 | splats["means"] = log_transform(splats["means"]) 85 | splats["quats"] = F.normalize(splats["quats"], dim=-1) 86 | 87 | n_gs = len(splats["means"]) 88 | n_sidelen = int(n_gs**0.5) 89 | n_crop = n_gs - n_sidelen**2 90 | if n_crop != 0: 91 | splats = _crop_n_splats(splats, n_crop) 92 | print( 93 | f"Warning: Number of Gaussians was not square. Removed {n_crop} Gaussians." 
94 | ) 95 | 96 | if self.use_sort: 97 | splats = sort_splats(splats) 98 | 99 | meta = {} 100 | for param_name in splats.keys(): 101 | compress_fn = self._get_compress_fn(param_name) 102 | kwargs = { 103 | "n_sidelen": n_sidelen, 104 | "verbose": self.verbose, 105 | } 106 | meta[param_name] = compress_fn( 107 | compress_dir, param_name, splats[param_name], **kwargs 108 | ) 109 | 110 | with open(os.path.join(compress_dir, "meta.json"), "w") as f: 111 | json.dump(meta, f) 112 | 113 | def decompress(self, compress_dir: str) -> Dict[str, Tensor]: 114 | """Run decompression 115 | 116 | Args: 117 | compress_dir (str): directory that contains compressed files 118 | 119 | Returns: 120 | Dict[str, Tensor]: decompressed Gaussian splats 121 | """ 122 | with open(os.path.join(compress_dir, "meta.json"), "r") as f: 123 | meta = json.load(f) 124 | 125 | splats = {} 126 | for param_name, param_meta in meta.items(): 127 | decompress_fn = self._get_decompress_fn(param_name) 128 | splats[param_name] = decompress_fn(compress_dir, param_name, param_meta) 129 | 130 | # Param-specific postprocessing 131 | splats["means"] = inverse_log_transform(splats["means"]) 132 | return splats 133 | 134 | 135 | def _crop_n_splats(splats: Dict[str, Tensor], n_crop: int) -> Dict[str, Tensor]: 136 | opacities = splats["opacities"] 137 | keep_indices = torch.argsort(opacities, descending=True)[:-n_crop] 138 | for k, v in splats.items(): 139 | splats[k] = v[keep_indices] 140 | return splats 141 | 142 | 143 | def _compress_png( 144 | compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs 145 | ) -> Dict[str, Any]: 146 | """Compress parameters with 8-bit quantization and lossless PNG compression. 
147 | 148 | Args: 149 | compress_dir (str): compression directory 150 | param_name (str): parameter field name 151 | params (Tensor): parameters 152 | n_sidelen (int): image side length 153 | 154 | Returns: 155 | Dict[str, Any]: metadata 156 | """ 157 | import imageio.v2 as imageio 158 | 159 | if torch.numel == 0: 160 | meta = { 161 | "shape": list(params.shape), 162 | "dtype": str(params.dtype).split(".")[1], 163 | } 164 | return meta 165 | 166 | grid = params.reshape((n_sidelen, n_sidelen, -1)) 167 | mins = torch.amin(grid, dim=(0, 1)) 168 | maxs = torch.amax(grid, dim=(0, 1)) 169 | grid_norm = (grid - mins) / (maxs - mins) 170 | img_norm = grid_norm.detach().cpu().numpy() 171 | 172 | img = (img_norm * (2**8 - 1)).round().astype(np.uint8) 173 | img = img.squeeze() 174 | imageio.imwrite(os.path.join(compress_dir, f"{param_name}.png"), img) 175 | 176 | meta = { 177 | "shape": list(params.shape), 178 | "dtype": str(params.dtype).split(".")[1], 179 | "mins": mins.tolist(), 180 | "maxs": maxs.tolist(), 181 | } 182 | return meta 183 | 184 | 185 | def _decompress_png(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor: 186 | """Decompress parameters from PNG file. 
187 | 188 | Args: 189 | compress_dir (str): compression directory 190 | param_name (str): parameter field name 191 | meta (Dict[str, Any]): metadata 192 | 193 | Returns: 194 | Tensor: parameters 195 | """ 196 | import imageio.v2 as imageio 197 | 198 | if not np.all(meta["shape"]): 199 | params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"])) 200 | return meta 201 | 202 | img = imageio.imread(os.path.join(compress_dir, f"{param_name}.png")) 203 | img_norm = img / (2**8 - 1) 204 | 205 | grid_norm = torch.tensor(img_norm) 206 | mins = torch.tensor(meta["mins"]) 207 | maxs = torch.tensor(meta["maxs"]) 208 | grid = grid_norm * (maxs - mins) + mins 209 | 210 | params = grid.reshape(meta["shape"]) 211 | params = params.to(dtype=getattr(torch, meta["dtype"])) 212 | return params 213 | 214 | 215 | def _compress_png_16bit( 216 | compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs 217 | ) -> Dict[str, Any]: 218 | """Compress parameters with 16-bit quantization and PNG compression. 
219 | 220 | Args: 221 | compress_dir (str): compression directory 222 | param_name (str): parameter field name 223 | params (Tensor): parameters 224 | n_sidelen (int): image side length 225 | 226 | Returns: 227 | Dict[str, Any]: metadata 228 | """ 229 | import imageio.v2 as imageio 230 | 231 | if torch.numel == 0: 232 | meta = { 233 | "shape": list(params.shape), 234 | "dtype": str(params.dtype).split(".")[1], 235 | } 236 | return meta 237 | 238 | grid = params.reshape((n_sidelen, n_sidelen, -1)) 239 | mins = torch.amin(grid, dim=(0, 1)) 240 | maxs = torch.amax(grid, dim=(0, 1)) 241 | grid_norm = (grid - mins) / (maxs - mins) 242 | img_norm = grid_norm.detach().cpu().numpy() 243 | img = (img_norm * (2**16 - 1)).round().astype(np.uint16) 244 | 245 | img_l = img & 0xFF 246 | img_u = (img >> 8) & 0xFF 247 | imageio.imwrite( 248 | os.path.join(compress_dir, f"{param_name}_l.png"), img_l.astype(np.uint8) 249 | ) 250 | imageio.imwrite( 251 | os.path.join(compress_dir, f"{param_name}_u.png"), img_u.astype(np.uint8) 252 | ) 253 | 254 | meta = { 255 | "shape": list(params.shape), 256 | "dtype": str(params.dtype).split(".")[1], 257 | "mins": mins.tolist(), 258 | "maxs": maxs.tolist(), 259 | } 260 | return meta 261 | 262 | 263 | def _decompress_png_16bit( 264 | compress_dir: str, param_name: str, meta: Dict[str, Any] 265 | ) -> Tensor: 266 | """Decompress parameters from PNG files. 
267 | 268 | Args: 269 | compress_dir (str): compression directory 270 | param_name (str): parameter field name 271 | meta (Dict[str, Any]): metadata 272 | 273 | Returns: 274 | Tensor: parameters 275 | """ 276 | import imageio.v2 as imageio 277 | 278 | if not np.all(meta["shape"]): 279 | params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"])) 280 | return meta 281 | 282 | img_l = imageio.imread(os.path.join(compress_dir, f"{param_name}_l.png")) 283 | img_u = imageio.imread(os.path.join(compress_dir, f"{param_name}_u.png")) 284 | img_u = img_u.astype(np.uint16) 285 | img = (img_u << 8) + img_l 286 | 287 | img_norm = img / (2**16 - 1) 288 | grid_norm = torch.tensor(img_norm) 289 | mins = torch.tensor(meta["mins"]) 290 | maxs = torch.tensor(meta["maxs"]) 291 | grid = grid_norm * (maxs - mins) + mins 292 | 293 | params = grid.reshape(meta["shape"]) 294 | params = params.to(dtype=getattr(torch, meta["dtype"])) 295 | return params 296 | 297 | 298 | def _compress_npz( 299 | compress_dir: str, param_name: str, params: Tensor, **kwargs 300 | ) -> Dict[str, Any]: 301 | """Compress parameters with numpy's NPZ compression.""" 302 | npz_dict = {"arr": params.detach().cpu().numpy()} 303 | save_fp = os.path.join(compress_dir, f"{param_name}.npz") 304 | os.makedirs(os.path.dirname(save_fp), exist_ok=True) 305 | np.savez_compressed(save_fp, **npz_dict) 306 | meta = { 307 | "shape": params.shape, 308 | "dtype": str(params.dtype).split(".")[1], 309 | } 310 | return meta 311 | 312 | 313 | def _decompress_npz(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor: 314 | """Decompress parameters with numpy's NPZ compression.""" 315 | arr = np.load(os.path.join(compress_dir, f"{param_name}.npz"))["arr"] 316 | params = torch.tensor(arr) 317 | params = params.reshape(meta["shape"]) 318 | params = params.to(dtype=getattr(torch, meta["dtype"])) 319 | return params 320 | 321 | 322 | def _compress_kmeans( 323 | compress_dir: str, 324 | param_name: str, 325 | 
params: Tensor, 326 | n_clusters: int = 65536, 327 | quantization: int = 6, 328 | verbose: bool = True, 329 | **kwargs, 330 | ) -> Dict[str, Any]: 331 | """Run K-means clustering on parameters and save centroids and labels to a npz file. 332 | 333 | .. warning:: 334 | TorchPQ must installed to use K-means clustering. 335 | 336 | Args: 337 | compress_dir (str): compression directory 338 | param_name (str): parameter field name 339 | params (Tensor): parameters to compress 340 | n_clusters (int): number of K-means clusters 341 | quantization (int): number of bits in quantization 342 | verbose (bool, optional): Whether to print verbose information. Default to True. 343 | 344 | Returns: 345 | Dict[str, Any]: metadata 346 | """ 347 | try: 348 | from torchpq.clustering import KMeans 349 | except: 350 | raise ImportError( 351 | "Please install torchpq with 'pip install torchpq' to use K-means clustering" 352 | ) 353 | 354 | if torch.numel == 0: 355 | meta = { 356 | "shape": list(params.shape), 357 | "dtype": str(params.dtype).split(".")[1], 358 | } 359 | return meta 360 | 361 | kmeans = KMeans(n_clusters=n_clusters, distance="manhattan", verbose=verbose) 362 | x = params.reshape(params.shape[0], -1).permute(1, 0).contiguous() 363 | labels = kmeans.fit(x) 364 | labels = labels.detach().cpu().numpy() 365 | centroids = kmeans.centroids.permute(1, 0) 366 | 367 | mins = torch.min(centroids) 368 | maxs = torch.max(centroids) 369 | centroids_norm = (centroids - mins) / (maxs - mins) 370 | centroids_norm = centroids_norm.detach().cpu().numpy() 371 | centroids_quant = ( 372 | (centroids_norm * (2**quantization - 1)).round().astype(np.uint8) 373 | ) 374 | labels = labels.astype(np.uint16) 375 | 376 | npz_dict = { 377 | "centroids": centroids_quant, 378 | "labels": labels, 379 | } 380 | np.savez_compressed(os.path.join(compress_dir, f"{param_name}.npz"), **npz_dict) 381 | meta = { 382 | "shape": list(params.shape), 383 | "dtype": str(params.dtype).split(".")[1], 384 | "mins": 
mins.tolist(), 385 | "maxs": maxs.tolist(), 386 | "quantization": quantization, 387 | } 388 | return meta 389 | 390 | 391 | def _decompress_kmeans( 392 | compress_dir: str, param_name: str, meta: Dict[str, Any], **kwargs 393 | ) -> Tensor: 394 | """Decompress parameters from K-means compression. 395 | 396 | Args: 397 | compress_dir (str): compression directory 398 | param_name (str): parameter field name 399 | meta (Dict[str, Any]): metadata 400 | 401 | Returns: 402 | Tensor: parameters 403 | """ 404 | if not np.all(meta["shape"]): 405 | params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"])) 406 | return meta 407 | 408 | npz_dict = np.load(os.path.join(compress_dir, f"{param_name}.npz")) 409 | centroids_quant = npz_dict["centroids"] 410 | labels = npz_dict["labels"] 411 | 412 | centroids_norm = centroids_quant / (2 ** meta["quantization"] - 1) 413 | centroids_norm = torch.tensor(centroids_norm) 414 | mins = torch.tensor(meta["mins"]) 415 | maxs = torch.tensor(meta["maxs"]) 416 | centroids = centroids_norm * (maxs - mins) + mins 417 | 418 | params = centroids[labels] 419 | params = params.reshape(meta["shape"]) 420 | params = params.to(dtype=getattr(torch, meta["dtype"])) 421 | return params 422 | -------------------------------------------------------------------------------- /gsplat/compression/sort.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def sort_splats(splats: Dict[str, Tensor], verbose: bool = True) -> Dict[str, Tensor]: 8 | """Sort splats with Parallel Linear Assignment Sorting from the paper `Compact 3D Scene Representation via 9 | Self-Organizing Gaussian Grids `_. 10 | 11 | .. warning:: 12 | PLAS must installed to use sorting. 13 | 14 | Args: 15 | splats (Dict[str, Tensor]): splats 16 | verbose (bool, optional): Whether to print verbose information. Default to True. 
17 | 18 | Returns: 19 | Dict[str, Tensor]: sorted splats 20 | """ 21 | try: 22 | from plas import sort_with_plas 23 | except: 24 | raise ImportError( 25 | "Please install PLAS with 'pip install git+https://github.com/fraunhoferhhi/PLAS.git' to use sorting" 26 | ) 27 | 28 | n_gs = len(splats["means"]) 29 | n_sidelen = int(n_gs**0.5) 30 | assert n_sidelen**2 == n_gs, "Must be a perfect square" 31 | 32 | sort_keys = [k for k in splats if k != "shN"] 33 | params_to_sort = torch.cat([splats[k].reshape(n_gs, -1) for k in sort_keys], dim=-1) 34 | shuffled_indices = torch.randperm( 35 | params_to_sort.shape[0], device=params_to_sort.device 36 | ) 37 | params_to_sort = params_to_sort[shuffled_indices] 38 | grid = params_to_sort.reshape((n_sidelen, n_sidelen, -1)) 39 | _, sorted_indices = sort_with_plas( 40 | grid.permute(2, 0, 1), improvement_break=1e-4, verbose=verbose 41 | ) 42 | sorted_indices = sorted_indices.squeeze().flatten() 43 | sorted_indices = shuffled_indices[sorted_indices] 44 | for k, v in splats.items(): 45 | splats[k] = v[sorted_indices] 46 | return splats 47 | -------------------------------------------------------------------------------- /gsplat/cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlinds/splatad/5d3c9bd4b856f142c707a9c0b78161f555bacb10/gsplat/cuda/__init__.py -------------------------------------------------------------------------------- /gsplat/cuda/_backend.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import shutil 5 | from subprocess import DEVNULL, call 6 | 7 | from rich.console import Console 8 | from torch.utils.cpp_extension import ( 9 | _get_build_directory, 10 | _import_module_from_library, 11 | load, 12 | ) 13 | 14 | PATH = os.path.dirname(os.path.abspath(__file__)) 15 | NO_FAST_MATH = os.getenv("NO_FAST_MATH", "0") == "1" 16 | MAX_JOBS = 
os.getenv("MAX_JOBS") 17 | need_to_unset_max_jobs = False 18 | if not MAX_JOBS: 19 | need_to_unset_max_jobs = True 20 | os.environ["MAX_JOBS"] = "10" 21 | 22 | 23 | def load_extension( 24 | name, 25 | sources, 26 | extra_cflags=None, 27 | extra_cuda_cflags=None, 28 | extra_ldflags=None, 29 | extra_include_paths=None, 30 | build_directory=None, 31 | ): 32 | """Load a JIT compiled extension.""" 33 | # Make sure the build directory exists. 34 | if build_directory: 35 | os.makedirs(build_directory, exist_ok=True) 36 | 37 | # If the JIT build happens concurrently in multiple processes, 38 | # race conditions can occur when removing the lock file at: 39 | # https://github.com/pytorch/pytorch/blob/e3513fb2af7951ddf725d8c5b6f6d962a053c9da/torch/utils/cpp_extension.py#L1736 40 | # But it's ok so we catch this exception and ignore it. 41 | try: 42 | return load( 43 | name, 44 | sources, 45 | extra_cflags=extra_cflags, 46 | extra_cuda_cflags=extra_cuda_cflags, 47 | extra_ldflags=extra_ldflags, 48 | extra_include_paths=extra_include_paths, 49 | build_directory=build_directory, 50 | ) 51 | except OSError: 52 | # The module should be already compiled 53 | return _import_module_from_library(name, build_directory, True) 54 | 55 | 56 | def cuda_toolkit_available(): 57 | """Check if the nvcc is avaiable on the machine.""" 58 | try: 59 | call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL) 60 | return True 61 | except FileNotFoundError: 62 | return False 63 | 64 | 65 | def cuda_toolkit_version(): 66 | """Get the cuda toolkit version.""" 67 | cuda_home = os.path.join(os.path.dirname(shutil.which("nvcc")), "..") 68 | if os.path.exists(os.path.join(cuda_home, "version.txt")): 69 | with open(os.path.join(cuda_home, "version.txt")) as f: 70 | cuda_version = f.read().strip().split()[-1] 71 | elif os.path.exists(os.path.join(cuda_home, "version.json")): 72 | with open(os.path.join(cuda_home, "version.json")) as f: 73 | cuda_version = json.load(f)["cuda"]["version"] 74 | else: 75 | raise 
RuntimeError("Cannot find the cuda version.") 76 | return cuda_version 77 | 78 | 79 | _C = None 80 | 81 | try: 82 | # try to import the compiled module (via setup.py) 83 | from gsplat import csrc as _C 84 | except ImportError: 85 | # if failed, try with JIT compilation 86 | if cuda_toolkit_available(): 87 | name = "gsplat_cuda" 88 | build_dir = _get_build_directory(name, verbose=False) 89 | current_dir = os.path.dirname(os.path.abspath(__file__)) 90 | glm_path = os.path.join(current_dir, "csrc", "third_party", "glm") 91 | 92 | extra_include_paths = [os.path.join(PATH, "csrc/"), glm_path] 93 | extra_cflags = ["-O3"] 94 | if NO_FAST_MATH: 95 | extra_cuda_cflags = ["-O3"] 96 | else: 97 | extra_cuda_cflags = ["-O3", "--use_fast_math"] 98 | sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list( 99 | glob.glob(os.path.join(PATH, "csrc/*.cpp")) 100 | ) 101 | 102 | # If JIT is interrupted it might leave a lock in the build directory. 103 | # We dont want it to exist in any case. 104 | try: 105 | os.remove(os.path.join(build_dir, "lock")) 106 | except OSError: 107 | pass 108 | 109 | if os.path.exists(os.path.join(build_dir, "gsplat_cuda.so")) or os.path.exists( 110 | os.path.join(build_dir, "gsplat_cuda.lib") 111 | ): 112 | # If the build exists, we assume the extension has been built 113 | # and we can load it. 114 | _C = load_extension( 115 | name=name, 116 | sources=sources, 117 | extra_cflags=extra_cflags, 118 | extra_cuda_cflags=extra_cuda_cflags, 119 | extra_include_paths=extra_include_paths, 120 | build_directory=build_dir, 121 | ) 122 | else: 123 | # Build from scratch. Remove the build directory just to be safe: pytorch jit might stuck 124 | # if the build directory exists with a lock file in it. 
125 | shutil.rmtree(build_dir) 126 | with Console().status( 127 | f"[bold yellow]gsplat: Setting up CUDA with MAX_JOBS={os.environ['MAX_JOBS']} (This may take a few minutes the first time)", 128 | spinner="bouncingBall", 129 | ): 130 | _C = load_extension( 131 | name=name, 132 | sources=sources, 133 | extra_cflags=extra_cflags, 134 | extra_cuda_cflags=extra_cuda_cflags, 135 | extra_include_paths=extra_include_paths, 136 | build_directory=build_dir, 137 | ) 138 | 139 | else: 140 | Console().print( 141 | "[yellow]gsplat: No CUDA toolkit found. gsplat will be disabled.[/yellow]" 142 | ) 143 | 144 | if need_to_unset_max_jobs: 145 | os.environ.pop("MAX_JOBS") 146 | 147 | 148 | __all__ = ["_C"] 149 | -------------------------------------------------------------------------------- /gsplat/cuda/csrc/ext.cpp: -------------------------------------------------------------------------------- 1 | #include "bindings.h" 2 | #include 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("compute_sh_fwd", &compute_sh_fwd_tensor); 6 | m.def("compute_sh_bwd", &compute_sh_bwd_tensor); 7 | 8 | m.def("quat_scale_to_covar_preci_fwd", &quat_scale_to_covar_preci_fwd_tensor); 9 | m.def("quat_scale_to_covar_preci_bwd", &quat_scale_to_covar_preci_bwd_tensor); 10 | 11 | m.def("persp_proj_fwd", &persp_proj_fwd_tensor); 12 | m.def("persp_proj_bwd", &persp_proj_bwd_tensor); 13 | 14 | m.def("lidar_proj_fwd", &lidar_proj_fwd_tensor); 15 | m.def("lidar_proj_bwd", &lidar_proj_bwd_tensor); 16 | 17 | m.def("world_to_cam_fwd", &world_to_cam_fwd_tensor); 18 | m.def("world_to_cam_bwd", &world_to_cam_bwd_tensor); 19 | 20 | m.def("compute_pix_velocity_fwd", &compute_pix_velocity_fwd_tensor); 21 | m.def("compute_pix_velocity_bwd", &compute_pix_velocity_bwd_tensor); 22 | 23 | m.def("compute_lidar_velocity_fwd", &compute_lidar_velocity_fwd_tensor); 24 | m.def("compute_lidar_velocity_bwd", &compute_lidar_velocity_bwd_tensor); 25 | 26 | m.def("fully_fused_projection_fwd", 
&fully_fused_projection_fwd_tensor); 27 | m.def("fully_fused_projection_bwd", &fully_fused_projection_bwd_tensor); 28 | 29 | m.def("fully_fused_lidar_projection_fwd", &fully_fused_lidar_projection_fwd_tensor); 30 | m.def("fully_fused_lidar_projection_bwd", &fully_fused_lidar_projection_bwd_tensor); 31 | 32 | m.def("isect_tiles", &isect_tiles_tensor); 33 | m.def("isect_lidar_tiles", &isect_lidar_tiles_tensor); 34 | m.def("isect_offset_encode", &isect_offset_encode_tensor); 35 | 36 | m.def("map_points_to_lidar_tiles", &map_points_to_lidar_tiles_tensor); 37 | m.def("points_mapping_offset_encode", &points_mapping_offset_encode_tensor); 38 | m.def("populate_image_from_points", &populate_image_from_points_tensor); 39 | 40 | m.def("rasterize_to_pixels_fwd", &rasterize_to_pixels_fwd_tensor); 41 | m.def("rasterize_to_pixels_bwd", &rasterize_to_pixels_bwd_tensor); 42 | 43 | m.def("rasterize_to_points_fwd", &rasterize_to_points_fwd_tensor); 44 | m.def("rasterize_to_points_bwd", &rasterize_to_points_bwd_tensor); 45 | 46 | m.def("rasterize_to_indices_in_range", &rasterize_to_indices_in_range_tensor); 47 | m.def("rasterize_to_indices_in_range_lidar", &rasterize_to_indices_in_range_lidar_tensor); 48 | 49 | // packed version 50 | m.def("fully_fused_projection_packed_fwd", &fully_fused_projection_packed_fwd_tensor); 51 | m.def("fully_fused_projection_packed_bwd", &fully_fused_projection_packed_bwd_tensor); 52 | 53 | m.def("compute_relocation", &compute_relocation_tensor); 54 | } -------------------------------------------------------------------------------- /gsplat/cuda/csrc/helpers.cuh: -------------------------------------------------------------------------------- 1 | #ifndef GSPLAT_CUDA_HELPERS_H 2 | #define GSPLAT_CUDA_HELPERS_H 3 | 4 | #include "third_party/glm/glm/glm.hpp" 5 | #include "third_party/glm/glm/gtc/type_ptr.hpp" 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #define PRAGMA_UNROLL _Pragma("unroll") 12 | 13 | #define RAD_TO_DEG 57.2957795131f 14 | 15 | 
namespace cg = cooperative_groups; 16 | 17 | template 18 | inline __device__ void warpSum(T *val, WarpT &warp) { 19 | PRAGMA_UNROLL 20 | for (uint32_t i = 0; i < DIM; i++) { 21 | val[i] = cg::reduce(warp, val[i], cg::plus()); 22 | } 23 | } 24 | 25 | template inline __device__ void warpSum(float3 &val, WarpT &warp) { 26 | val.x = cg::reduce(warp, val.x, cg::plus()); 27 | val.y = cg::reduce(warp, val.y, cg::plus()); 28 | val.z = cg::reduce(warp, val.z, cg::plus()); 29 | } 30 | 31 | template inline __device__ void warpSum(float2 &val, WarpT &warp) { 32 | val.x = cg::reduce(warp, val.x, cg::plus()); 33 | val.y = cg::reduce(warp, val.y, cg::plus()); 34 | } 35 | 36 | template inline __device__ void warpSum(float &val, WarpT &warp) { 37 | val = cg::reduce(warp, val, cg::plus()); 38 | } 39 | 40 | template inline __device__ void warpSum(glm::vec4 &val, WarpT &warp) { 41 | val.x = cg::reduce(warp, val.x, cg::plus()); 42 | val.y = cg::reduce(warp, val.y, cg::plus()); 43 | val.z = cg::reduce(warp, val.z, cg::plus()); 44 | val.w = cg::reduce(warp, val.w, cg::plus()); 45 | } 46 | 47 | template inline __device__ void warpSum(glm::vec3 &val, WarpT &warp) { 48 | val.x = cg::reduce(warp, val.x, cg::plus()); 49 | val.y = cg::reduce(warp, val.y, cg::plus()); 50 | val.z = cg::reduce(warp, val.z, cg::plus()); 51 | } 52 | 53 | template inline __device__ void warpSum(glm::vec2 &val, WarpT &warp) { 54 | val.x = cg::reduce(warp, val.x, cg::plus()); 55 | val.y = cg::reduce(warp, val.y, cg::plus()); 56 | } 57 | 58 | template inline __device__ void warpSum(glm::mat4 &val, WarpT &warp) { 59 | warpSum(val[0], warp); 60 | warpSum(val[1], warp); 61 | warpSum(val[2], warp); 62 | warpSum(val[3], warp); 63 | } 64 | 65 | template inline __device__ void warpSum(glm::mat3 &val, WarpT &warp) { 66 | warpSum(val[0], warp); 67 | warpSum(val[1], warp); 68 | warpSum(val[2], warp); 69 | } 70 | 71 | template inline __device__ void warpSum(glm::mat2 &val, WarpT &warp) { 72 | warpSum(val[0], warp); 73 | 
warpSum(val[1], warp); 74 | } 75 | 76 | template inline __device__ void warpMax(float &val, WarpT &warp) { 77 | val = cg::reduce(warp, val, cg::greater()); 78 | } 79 | 80 | inline __device__ void compute_pix_velocity( 81 | const glm::vec3 p_view, 82 | const glm::vec3 lin_vel, 83 | const glm::vec3 ang_vel, 84 | const glm::vec3 vel_view, 85 | const float fx, 86 | const float fy, 87 | const float cx, 88 | const float cy, 89 | const uint32_t width, 90 | const uint32_t height, 91 | glm::vec2 &total_vel_pix 92 | ) { 93 | 94 | float x = p_view[0], y = p_view[1], z = p_view[2]; 95 | 96 | float tan_fovx = 0.5f * width / fx; 97 | float tan_fovy = 0.5f * height / fy; 98 | float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; 99 | float lim_x_neg = cx / fx + 0.3f * tan_fovx; 100 | float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; 101 | float lim_y_neg = cy / fy + 0.3f * tan_fovy; 102 | 103 | float rz = 1.f / z; 104 | float rz2 = rz * rz; 105 | float tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); 106 | float ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); 107 | 108 | // mat3x2 is 3 columns x 2 rows. 
    glm::mat3x2 J = glm::mat3x2(fx * rz, 0.f, // 1st column
                                0.f, fy * rz, // 2nd column
                                -fx * tx * rz2, -fy * ty * rz2 // 3rd column
    );

    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;
    // negative sign: move points to the opposite direction as the camera
    total_vel_pix = -J * total_vel;
}

// Project the velocity of a point (given in the sensor frame) into lidar
// measurement space. J is in/out: if the caller passes a zero matrix (first
// column has zero length) the spherical-projection Jacobian is computed here
// and returned through J; otherwise the caller-supplied J is reused.
// NOTE(review): the RAD_TO_DEG scaling suggests the first two output
// components are angular rates in degrees (azimuth/elevation) and the third
// is the range rate — confirm against callers.
inline __device__ void compute_lidar_velocity(
    const glm::vec3 p_view,
    const glm::vec3 lin_vel,
    const glm::vec3 ang_vel,
    const glm::vec3 vel_view,
    glm::mat3 &J,
    glm::vec3 &total_vel_pix
) {
    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;

    if (glm::length(J[0]) == 0) {
        const float x2 = p_view.x * p_view.x;
        const float y2 = p_view.y * p_view.y;
        const float z2 = p_view.z * p_view.z;
        const float r2 = x2 + y2 + z2;
        const float rinv = rsqrtf(r2);
        const float sqrtx2y2 = hypotf(p_view.x, p_view.y);
        const float sqrtx2y2_inv = rhypotf(p_view.x, p_view.y);
        const float xz = p_view.x * p_view.z;
        const float yz = p_view.y * p_view.z;
        const float r2sqrtx2y2_inv = 1.f / (r2) * sqrtx2y2_inv;

        // column major; J here is a full glm::mat3 (3 columns x 3 rows),
        // unlike the mat3x2 used for the pinhole projection above.
144 | J = glm::mat3( 145 | -p_view.y / (x2 + y2) * RAD_TO_DEG, -xz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.x * rinv, // 1st column 146 | p_view.x / (x2 + y2) * RAD_TO_DEG, -yz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.y * rinv,// 2nd column 147 | 0.f , sqrtx2y2 / r2 * RAD_TO_DEG, p_view.z * rinv // 3rd column 148 | ); 149 | } 150 | 151 | // negative sign: move points to the opposite direction as the camera 152 | total_vel_pix = -J * total_vel; 153 | } 154 | 155 | inline __device__ void compute_and_sum_pix_velocity_vjp( 156 | const glm::vec3 p_view, 157 | const glm::vec3 lin_vel, 158 | const glm::vec3 ang_vel, 159 | const glm::vec3 vel_view, 160 | const float fx, 161 | const float fy, 162 | const float cx, 163 | const float cy, 164 | const uint32_t width, 165 | const uint32_t height, 166 | const glm::vec2 v_pix_velocity, 167 | glm::vec3 &v_p_view_accumulator, 168 | glm::vec3 &v_vel_view) 169 | { 170 | float x = p_view[0], y = p_view[1], z = p_view[2]; 171 | 172 | float tan_fovx = 0.5f * width / fx; 173 | float tan_fovy = 0.5f * height / fy; 174 | float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; 175 | float lim_x_neg = cx / fx + 0.3f * tan_fovx; 176 | float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; 177 | float lim_y_neg = cy / fy + 0.3f * tan_fovy; 178 | 179 | float rz = 1.f / z; 180 | float rz2 = rz * rz; 181 | float tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); 182 | float ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); 183 | 184 | // mat3x2 is 3 columns x 2 rows. 
185 | glm::mat3x2 J = glm::mat3x2(fx * rz, 0.f, // 1st column 186 | 0.f, fy * rz, // 2nd column 187 | -fx * tx * rz2, -fy * ty * rz2 // 3rd column 188 | ); 189 | 190 | glm::vec3 rot_part = glm::cross(ang_vel, p_view); 191 | glm::vec3 total_vel = lin_vel + rot_part - vel_view; 192 | 193 | glm::mat3x2 dJ_dz = glm::mat3x2( 194 | -fx * rz2, 195 | 0.f, 196 | 0.f, 197 | -fy * rz2, 198 | 2.f * fx * tx * rz2 * rz, 199 | 2.f * fy * ty * rz2 * rz 200 | ); 201 | 202 | if (x * rz <= lim_x_pos && x * rz >= -lim_x_neg) { 203 | v_p_view_accumulator.x += v_pix_velocity.x * fx * rz2 * total_vel.z; //-glm::dot(v_pix_velocity, dJ_dx * total_vel); 204 | } else { 205 | v_p_view_accumulator.z += v_pix_velocity.x * fx * rz2 * rz * tx * total_vel.z; //-glm::dot(v_pix_velocity, dJ_dx * rz * tx * total_vel); 206 | } 207 | if (y * rz <= lim_y_pos && y * rz >= -lim_y_neg) { 208 | v_p_view_accumulator.y += v_pix_velocity.y * fy * rz2 * total_vel.z; //glm::dot(v_pix_velocity, dJ_dy * total_vel); 209 | } else { 210 | v_p_view_accumulator.z += v_pix_velocity.y * fy * rz2 * rz * ty * total_vel.z; // glm::dot(v_pix_velocity, dJ_dy * rz * ty * total_vel); 211 | } 212 | v_p_view_accumulator.z -= glm::dot(v_pix_velocity, dJ_dz * total_vel); 213 | 214 | glm::vec3 v_rot_part = -glm::transpose(J) * v_pix_velocity; // = v_total_vel 215 | 216 | // (v_rot_part^T * cross_prod_matrix(ang_vel))^T 217 | // = cross_prod_matrix(ang_vel)^T * v_rot_part // ... 
skew-symmetry
    // = -cross_prod_matrix(ang_vel) * v_rot_part
    // = -cross(ang_vel, v_rot_part)
    glm::vec3 v_p_view_rot = -glm::cross(ang_vel, v_rot_part);

    v_p_view_accumulator.x += v_p_view_rot[0];
    v_p_view_accumulator.y += v_p_view_rot[1];
    v_p_view_accumulator.z += v_p_view_rot[2];

    v_vel_view -= v_rot_part;
}

// Backward pass (vector-Jacobian product) of compute_lidar_velocity:
// accumulates the gradient w.r.t. the point position into
// v_p_view_accumulator and the gradient w.r.t. the sensor velocity into
// v_vel_view (both are accumulated, not overwritten).
inline __device__ void compute_and_sum_lidar_velocity_vjp(
    const glm::vec3 p_view,
    const glm::vec3 lin_vel,
    const glm::vec3 ang_vel,
    const glm::vec3 vel_view,
    const glm::vec3 v_pix_velocity,
    glm::vec3 &v_p_view_accumulator,
    glm::vec3 &v_vel_view)
{
    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;

    const float x = p_view.x;
    const float y = p_view.y;
    const float z = p_view.z;
    const float x2 = x * x;
    const float y2 = y * y;
    const float z2 = z * z;
    const float x4 = x2 * x2;
    const float y4 = y2 * y2;
    const float x2plusy2 = x2 + y2;
    const float sqrtx2y2 = hypot(x, y);
    const float sqrtx2y2_inv = rhypot(x, y);
    const float x2plusy2squared = x2plusy2 * x2plusy2;
    const float x2plusy2pow3by2 = sqrtx2y2 * x2plusy2;
    const float r2 = x2 + y2 + z2;
    const float r4 = r2 * r2;
    const float rinv = rsqrtf(r2);
    const float r3_inv = 1 / r2 * rinv;
    const float xz = x * z;
    const float xy = x * y;
    const float yz = y * z;
    const float r2sqrtx2y2_inv = 1.f / (r2) * sqrtx2y2_inv;
    const float xyz = x * y * z;
    const float denom1 = 1.f / (x2plusy2pow3by2 * r4);
    const float denom2 = 1.f / r4 * sqrtx2y2_inv;

    // column major; J is a full glm::mat3 (3 columns x 3 rows) — the same
    // spherical-projection Jacobian as in compute_lidar_velocity.
267 | glm::mat3 J = glm::mat3( 268 | -p_view.y / (x2 + y2) * RAD_TO_DEG, -xz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.x * rinv, // 1st column 269 | p_view.x / (x2 + y2) * RAD_TO_DEG, -yz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.y * rinv, // 2nd column 270 | 0.f, sqrtx2y2 / r2 * RAD_TO_DEG, p_view.z * rinv // 3rd column 271 | ); 272 | 273 | 274 | glm::mat3 dJ_dx = glm::mat3( 275 | 2.f * xy / x2plusy2squared * RAD_TO_DEG , z * (2.f * x4 + x2 * y2 - y2 * (y2 + z2)) * denom1 * RAD_TO_DEG, (y2 + z2) * r3_inv, 276 | (y2 - x2) / x2plusy2squared * RAD_TO_DEG, xyz * (3.f * x2 + 3.f * y2 + z2) * denom1 * RAD_TO_DEG , - xy * r3_inv, 277 | 0.f , -x * (x2 + y2 - z2) * denom2 * RAD_TO_DEG , - xz * r3_inv 278 | ); 279 | 280 | glm::mat3 dJ_dy = glm::mat3( 281 | (y2 - x2) / x2plusy2squared * RAD_TO_DEG, xyz * (3.f * x2 + 3.f * y2 + z2) * denom1 * RAD_TO_DEG , - xy * r3_inv, 282 | -2.f * xy / x2plusy2squared * RAD_TO_DEG, -z * (x4 + x2 * (z2 - y2) - 2.f * y4) * denom1 * RAD_TO_DEG, (x2 + z2) * r3_inv, 283 | 0.f , -y * (x2 + y2 - z2) * denom2 * RAD_TO_DEG , - yz * r3_inv 284 | ); 285 | 286 | glm::mat3 dJ_dz = glm::mat3( 287 | 0.f , -x * (x2 + y2 - z2) * denom2 * RAD_TO_DEG, - xz * r3_inv, 288 | 0.f , -y * (x2 + y2 - z2) * denom2 * RAD_TO_DEG, - yz * r3_inv, 289 | 0.f , -2.f * z * sqrtx2y2 / r4 * RAD_TO_DEG , (x2 + y2) * r3_inv 290 | ); 291 | 292 | v_p_view_accumulator.x -= glm::dot(v_pix_velocity, dJ_dx * total_vel); 293 | v_p_view_accumulator.y -= glm::dot(v_pix_velocity, dJ_dy * total_vel); 294 | v_p_view_accumulator.z -= glm::dot(v_pix_velocity, dJ_dz * total_vel); 295 | 296 | glm::vec3 v_rot_part = -glm::transpose(J) * v_pix_velocity; // = v_total_vel 297 | 298 | // (v_rot_part^T * cross_prod_matrix(ang_vel))^T 299 | // = cross_prod_matrix(ang_vel)^T * v_rot_part // ... 
skew-symmetry
    // = -cross_prod_matrix(ang_vel) * v_rot_part
    // = -cross(ang_vel, v_rot_part)
    glm::vec3 v_p_view_rot = -glm::cross(ang_vel, v_rot_part);

    v_p_view_accumulator.x += v_p_view_rot[0];
    v_p_view_accumulator.y += v_p_view_rot[1];
    v_p_view_accumulator.z += v_p_view_rot[2];

    v_vel_view -= v_rot_part;
}

#endif // GSPLAT_CUDA_HELPERS_H
-------------------------------------------------------------------------------- /gsplat/cuda/csrc/relocation.cu: --------------------------------------------------------------------------------
#include "bindings.h"

// Equation (9) in "3D Gaussian Splatting as Markov Chain Monte Carlo"
// One thread per Gaussian: given the number of replacements ratios[idx] for a
// Gaussian, compute the opacity and scale the replacements should take.
// NOTE(review): assumes ratios[idx] >= 1 (n_idx == 0 would divide by zero
// below) and ratios[idx] <= n_max (binoms is indexed as (i-1)*n_max + k) —
// confirm that callers clamp accordingly.
__global__ void compute_relocation_kernel(int N, float *opacities, float *scales,
                                          int *ratios, float *binoms, int n_max,
                                          float *new_opacities, float *new_scales) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= N)
        return;

    int n_idx = ratios[idx];
    float denom_sum = 0.0f;

    // compute new opacity: 1 - (1 - o)^(1/n), i.e. n copies with the new
    // opacity compose back to the original opacity o
    new_opacities[idx] = 1.0f - powf(1.0f - opacities[idx], 1.0f / n_idx);

    // compute new scale
    for (int i = 1; i <= n_idx; ++i) {
        for (int k = 0; k <= (i - 1); ++k) {
            float bin_coeff = binoms[(i - 1) * n_max + k];
            // alternating-sign binomial series term
            float term = (pow(-1.0f, k) / sqrt(static_cast(k + 1))) *
                         pow(new_opacities[idx], k + 1);
            denom_sum += (bin_coeff * term);
        }
    }
    float coeff = (opacities[idx] / denom_sum);
    // isotropic rescale of all three scale components
    for (int i = 0; i < 3; ++i)
        new_scales[idx * 3 + i] = coeff * scales[idx * 3 + i];
}

// Python-facing wrapper: validates inputs (CUDA + contiguous), allocates the
// outputs and launches compute_relocation_kernel on the current CUDA stream.
// Returns (new_opacities, new_scales) with the same shapes as the inputs.
std::tuple
compute_relocation_tensor(torch::Tensor &opacities, torch::Tensor &scales,
                          torch::Tensor &ratios, torch::Tensor &binoms,
                          const int n_max) {
    DEVICE_GUARD(opacities);
    CHECK_INPUT(opacities);
    CHECK_INPUT(scales);
    CHECK_INPUT(ratios);
    CHECK_INPUT(binoms);
    torch::Tensor new_opacities = torch::empty_like(opacities);
    torch::Tensor new_scales = torch::empty_like(scales);

    uint32_t N = opacities.size(0);
    if (N) { // skip the kernel launch entirely for empty input
        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
        compute_relocation_kernel<<<(N + N_THREADS - 1) / N_THREADS, N_THREADS, 0,
                                    stream>>>(
            N, opacities.data_ptr(), scales.data_ptr(),
            ratios.data_ptr(), binoms.data_ptr(), n_max,
            new_opacities.data_ptr(), new_scales.data_ptr());
    }
    return std::make_tuple(new_opacities, new_scales);
}
-------------------------------------------------------------------------------- /gsplat/cuda_legacy/__init__.py: --------------------------------------------------------------------------------
# from typing import Callable


# def _make_lazy_cuda_func(name: str) -> Callable:
#     def call_cuda(*args, **kwargs):
#         # pylint: disable=import-outside-toplevel
#         from ._backend import _C

#         return getattr(_C, name)(*args, **kwargs)

#     return call_cuda


# nd_rasterize_forward = _make_lazy_cuda_func("nd_rasterize_forward")
# nd_rasterize_backward = _make_lazy_cuda_func("nd_rasterize_backward")
# rasterize_forward = _make_lazy_cuda_func("rasterize_forward")
# rasterize_backward = _make_lazy_cuda_func("rasterize_backward")
# compute_cov2d_bounds = _make_lazy_cuda_func("compute_cov2d_bounds")
# project_gaussians_forward = _make_lazy_cuda_func("project_gaussians_forward")
# project_gaussians_backward = _make_lazy_cuda_func("project_gaussians_backward")
# compute_sh_forward = _make_lazy_cuda_func("compute_sh_forward")
# compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
# map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
# get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
# rasterize_forward = _make_lazy_cuda_func("rasterize_forward")
# nd_rasterize_forward = _make_lazy_cuda_func("nd_rasterize_forward")
-------------------------------------------------------------------------------- /gsplat/cuda_legacy/_backend.py: --------------------------------------------------------------------------------
import glob
import json
import os
import shutil
from subprocess import DEVNULL, call

from rich.console import Console
from torch.utils.cpp_extension import _get_build_directory, load

# Directory containing this file; used below to locate the CUDA sources.
PATH = os.path.dirname(os.path.abspath(__file__))


def cuda_toolkit_available() -> bool:
    """Check if the nvcc is available on the machine."""
    try:
        call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
        return True
    except FileNotFoundError:
        return False


def cuda_toolkit_version() -> str:
    """Get the cuda toolkit version.

    Reads ``version.txt`` or ``version.json`` from the toolkit root
    (the parent of the directory containing ``nvcc``).

    Raises:
        RuntimeError: If neither version file can be found.
    """
    cuda_home = os.path.join(os.path.dirname(shutil.which("nvcc")), "..")
    if os.path.exists(os.path.join(cuda_home, "version.txt")):
        with open(os.path.join(cuda_home, "version.txt")) as f:
            cuda_version = f.read().strip().split()[-1]
    elif os.path.exists(os.path.join(cuda_home, "version.json")):
        with open(os.path.join(cuda_home, "version.json")) as f:
            cuda_version = json.load(f)["cuda"]["version"]
    else:
        raise RuntimeError("Cannot find the cuda version.")
    return cuda_version


# Handle to the compiled extension module; stays None when no CUDA toolkit
# is available.
_C = None

try:
    # try to import the compiled module (via setup.py)
    from gsplat import csrc_legacy as _C
except ImportError:
    # if failed, try with JIT compilation
    if cuda_toolkit_available():
        name = "gsplat_cuda_legacy"
        build_dir = _get_build_directory(name, verbose=False)
        extra_include_paths = [os.path.join(PATH, "..", "cuda/", "csrc/")]
        extra_cflags = ["-O3"]
        extra_cuda_cflags = ["-O3"]
        sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list(
            glob.glob(os.path.join(PATH, "csrc/*.cpp"))
        )

        # If JIT is interrupted it might leave a lock in the build directory.
        # We don't want it to exist in any case.
        try:
            os.remove(os.path.join(build_dir, "lock"))
        except OSError:
            pass

        if os.path.exists(
            os.path.join(build_dir, "gsplat_cuda_legacy.so")
        ) or os.path.exists(os.path.join(build_dir, "gsplat_cuda_legacy.lib")):
            # If the build exists, we assume the extension has been built
            # and we can load it.

            _C = load(
                name=name,
                sources=sources,
                extra_cflags=extra_cflags,
                extra_cuda_cflags=extra_cuda_cflags,
                extra_include_paths=extra_include_paths,
            )
        else:
            # Build from scratch. Remove the build directory just to be safe: pytorch jit might get stuck
            # if the build directory exists with a lock file in it.
            shutil.rmtree(build_dir)
            with Console().status(
                "[bold yellow]gsplat (legacy): Setting up CUDA (This may take a few minutes the first time)",
                spinner="bouncingBall",
            ):
                _C = load(
                    name=name,
                    sources=sources,
                    extra_cflags=extra_cflags,
                    extra_cuda_cflags=extra_cuda_cflags,
                    extra_include_paths=extra_include_paths,
                )
    else:
        Console().print(
            "[yellow]gsplat (legacy): No CUDA toolkit found. gsplat will be disabled.[/yellow]"
        )


__all__ = ["_C"]
-------------------------------------------------------------------------------- /gsplat/cuda_legacy/csrc/CMakeLists.txt: --------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.12) # You can adjust the minimum required version
set(CMAKE_CUDA_ARCHITECTURES 70 75 89) # Ti 2080 uses 75. V100 uses 70. RTX 4090 uses 89.

project(gsplat CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CUDA_STANDARD 17)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")

# our library
add_library(gsplat forward.cu backward.cu helpers.cuh)
target_link_libraries(gsplat PUBLIC cuda)
target_include_directories(gsplat PRIVATE
    ${PROJECT_SOURCE_DIR}/third_party/glm
    ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
)
# NOTE(review): this per-target list says 86 while CMAKE_CUDA_ARCHITECTURES
# above says 89 (RTX 4090); the per-target property wins for this target —
# confirm which architecture set is intended.
set_target_properties(gsplat PROPERTIES CUDA_ARCHITECTURES "70;75;86")

# # To add an executable that uses the gsplat library,
# # follow example in the comments for a script `run_forward.cpp`
# # Add the executable
# add_executable(run_forward run_forward.cpp)

# # Link against CUDA runtime library
# target_link_libraries(run_forward PUBLIC cuda gsplat)

# # Include directories for the header-only library
# target_include_directories(run_forward PRIVATE
#     ${PROJECT_SOURCE_DIR}/third_party/glm
# )
-------------------------------------------------------------------------------- /gsplat/cuda_legacy/csrc/backward.cuh: --------------------------------------------------------------------------------
#include
#include
#include

// for f : R(n) -> R(m), J in R(m, n),
// v is cotangent in R(m), e.g. dL/df in R(m),
// compute vjp i.e.
vT J -> R(n) 8 | __global__ void project_gaussians_backward_kernel( 9 | const int num_points, 10 | const float3* __restrict__ means3d, 11 | const float3* __restrict__ scales, 12 | const float glob_scale, 13 | const float4* __restrict__ quats, 14 | const float* __restrict__ viewmat, 15 | const float4 intrins, 16 | const dim3 img_size, 17 | const float* __restrict__ cov3d, 18 | const int* __restrict__ radii, 19 | const float3* __restrict__ conics, 20 | const float* __restrict__ compensation, 21 | const float2* __restrict__ v_xy, 22 | const float* __restrict__ v_depth, 23 | const float3* __restrict__ v_conic, 24 | const float* __restrict__ v_compensation, 25 | float3* __restrict__ v_cov2d, 26 | float* __restrict__ v_cov3d, 27 | float3* __restrict__ v_mean3d, 28 | float3* __restrict__ v_scale, 29 | float4* __restrict__ v_quat 30 | ); 31 | 32 | // compute jacobians of output image wrt binned and sorted gaussians 33 | __global__ void nd_rasterize_backward_kernel( 34 | const dim3 tile_bounds, 35 | const dim3 img_size, 36 | const unsigned channels, 37 | const int32_t* __restrict__ gaussians_ids_sorted, 38 | const int2* __restrict__ tile_bins, 39 | const float2* __restrict__ xys, 40 | const float3* __restrict__ conics, 41 | const float* __restrict__ rgbs, 42 | const float* __restrict__ opacities, 43 | const float* __restrict__ background, 44 | const float* __restrict__ final_Ts, 45 | const int* __restrict__ final_index, 46 | const float* __restrict__ v_output, 47 | const float* __restrict__ v_output_alpha, 48 | float2* __restrict__ v_xy, 49 | float2* __restrict__ v_xy_abs, 50 | float3* __restrict__ v_conic, 51 | float* __restrict__ v_rgb, 52 | float* __restrict__ v_opacity 53 | ); 54 | 55 | __global__ void rasterize_backward_kernel( 56 | const dim3 tile_bounds, 57 | const dim3 img_size, 58 | const int32_t* __restrict__ gaussian_ids_sorted, 59 | const int2* __restrict__ tile_bins, 60 | const float2* __restrict__ xys, 61 | const float3* __restrict__ conics, 62 | const float3* 
__restrict__ rgbs, 63 | const float* __restrict__ opacities, 64 | const float3& __restrict__ background, 65 | const float* __restrict__ final_Ts, 66 | const int* __restrict__ final_index, 67 | const float3* __restrict__ v_output, 68 | const float* __restrict__ v_output_alpha, 69 | float2* __restrict__ v_xy, 70 | float2* __restrict__ v_xy_abs, 71 | float3* __restrict__ v_conic, 72 | float3* __restrict__ v_rgb, 73 | float* __restrict__ v_opacity 74 | ); 75 | 76 | __device__ void project_cov3d_ewa_vjp( 77 | const float3 &mean3d, 78 | const float *cov3d, 79 | const float *viewmat, 80 | const float fx, 81 | const float fy, 82 | const float3 &v_cov2d, 83 | float3 &v_mean3d, 84 | float *v_cov3d 85 | ); 86 | 87 | __device__ void scale_rot_to_cov3d_vjp( 88 | const float3 scale, 89 | const float glob_scale, 90 | const float4 quat, 91 | const float *v_cov3d, 92 | float3 &v_scale, 93 | float4 &v_quat 94 | ); 95 | -------------------------------------------------------------------------------- /gsplat/cuda_legacy/csrc/bindings.h: -------------------------------------------------------------------------------- 1 | #include "cuda_runtime.h" 2 | #include "forward.cuh" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | #define DEVICE_GUARD(_ten) \ 17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); 18 | 19 | std::tuple< 20 | torch::Tensor, // output conics 21 | torch::Tensor> // output radii 22 | compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &A); 23 | 24 | torch::Tensor compute_sh_forward_tensor( 25 | const std::string &method, 26 | unsigned num_points, 27 | unsigned degree, 28 | unsigned degrees_to_use, 29 | torch::Tensor &viewdirs, 30 | torch::Tensor &coeffs 31 
// ---- gsplat/cuda_legacy/csrc/bindings.h (tail) ----
// Host-side declarations for the legacy CUDA rasterizer's Torch bindings.
// NOTE(review): the text extraction stripped angle-bracketed content in this
// file (bare "#include", "std::tuple" without template arguments). The tokens
// below are reproduced as-is; confirm the original template argument lists
// before compiling.

// Backward pass of spherical-harmonics evaluation: gradient of the loss
// w.r.t. the SH coefficients, given d(loss)/d(colors).
torch::Tensor compute_sh_backward_tensor(
    const std::string &method,
    unsigned num_points,
    unsigned degree,
    unsigned degrees_to_use,
    torch::Tensor &viewdirs,
    torch::Tensor &v_colors
);

// Project 3D Gaussians to screen space. Returns (per the forward kernel's
// outputs): cov3d, xys, depths, radii, conics, compensation, num_tiles_hit.
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor>
project_gaussians_forward_tensor(
    const int num_points,
    torch::Tensor &means3d,
    torch::Tensor &scales,
    const float glob_scale,
    torch::Tensor &quats,
    torch::Tensor &viewmat,
    const float fx,
    const float fy,
    const float cx,
    const float cy,
    const unsigned img_height,
    const unsigned img_width,
    const unsigned block_width,
    const float clip_thresh
);

// Backward pass of the projection: gradients w.r.t. cov3d, mean3d, scale,
// quat (and friends), given screen-space gradients v_xy/v_depth/v_conic.
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor>
project_gaussians_backward_tensor(
    const int num_points,
    torch::Tensor &means3d,
    torch::Tensor &scales,
    const float glob_scale,
    torch::Tensor &quats,
    torch::Tensor &viewmat,
    const float fx,
    const float fy,
    const float cx,
    const float cy,
    const unsigned img_height,
    const unsigned img_width,
    torch::Tensor &cov3d,
    torch::Tensor &radii,
    torch::Tensor &conics,
    torch::Tensor &compensation,
    torch::Tensor &v_xy,
    torch::Tensor &v_depth,
    torch::Tensor &v_conic,
    torch::Tensor &v_compensation
);

// Expand each Gaussian into one (tile, depth)-keyed intersection entry per
// tile it overlaps, for later sorting.
// NOTE(review): return/parameter tuple arguments were stripped by extraction.
std::tuple map_gaussian_to_intersects_tensor(
    const int num_points,
    const int num_intersects,
    const torch::Tensor &xys,
    const torch::Tensor &depths,
    const torch::Tensor &radii,
    const torch::Tensor &cum_tiles_hit,
    const std::tuple tile_bounds,
    const unsigned block_width
);

// Compute, per tile, the [start, end) range into the sorted intersection list.
torch::Tensor get_tile_bin_edges_tensor(
    int num_intersects,
    const torch::Tensor &isect_ids_sorted,
    const std::tuple tile_bounds
);

// Alpha-composite sorted Gaussians into an RGB image.
// Returns (out_img, final_Ts, final_idx).
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor
> rasterize_forward_tensor(
    const std::tuple tile_bounds,
    const std::tuple block,
    const std::tuple img_size,
    const torch::Tensor &gaussian_ids_sorted,
    const torch::Tensor &tile_bins,
    const torch::Tensor &xys,
    const torch::Tensor &conics,
    const torch::Tensor &colors,
    const torch::Tensor &opacities,
    const torch::Tensor &background
);

// Same as rasterize_forward_tensor but for an arbitrary number of channels.
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor
> nd_rasterize_forward_tensor(
    const std::tuple tile_bounds,
    const std::tuple block,
    const std::tuple img_size,
    const torch::Tensor &gaussian_ids_sorted,
    const torch::Tensor &tile_bins,
    const torch::Tensor &xys,
    const torch::Tensor &conics,
    const torch::Tensor &colors,
    const torch::Tensor &opacities,
    const torch::Tensor &background
);

// Backward pass of the N-channel rasterizer.
std::
    tuple<
        torch::Tensor, // dL_dxy
        torch::Tensor, // dL_dxy_abs
        torch::Tensor, // dL_dconic
        torch::Tensor, // dL_dcolors
        torch::Tensor  // dL_dopacity
        >
    nd_rasterize_backward_tensor(
        const unsigned img_height,
        const unsigned img_width,
        const unsigned block_width,
        const torch::Tensor &gaussians_ids_sorted,
        const torch::Tensor &tile_bins,
        const torch::Tensor &xys,
        const torch::Tensor &conics,
        const torch::Tensor &colors,
        const torch::Tensor &opacities,
        const torch::Tensor &background,
        const torch::Tensor &final_Ts,
        const torch::Tensor &final_idx,
        const torch::Tensor &v_output, // dL_dout_color
        const torch::Tensor &v_output_alpha
    );

// Backward pass of the 3-channel rasterizer.
std::
    tuple<
        torch::Tensor, // dL_dxy
        torch::Tensor, // dL_dxy_abs
        torch::Tensor, // dL_dconic
        torch::Tensor, // dL_dcolors
        torch::Tensor  // dL_dopacity
        >
    rasterize_backward_tensor(
        const unsigned img_height,
        const unsigned img_width,
        const unsigned block_width,
        const torch::Tensor &gaussians_ids_sorted,
        const torch::Tensor &tile_bins,
        const torch::Tensor &xys,
        const torch::Tensor &conics,
        const torch::Tensor &colors,
        const torch::Tensor &opacities,
        const torch::Tensor &background,
        const torch::Tensor &final_Ts,
        const torch::Tensor &final_idx,
        const torch::Tensor &v_output, // dL_dout_color
        const torch::Tensor &v_output_alpha
    );

// ---- gsplat/cuda_legacy/csrc/config.h ----
// Compile-time tuning knobs shared by the legacy kernels.
#define MAX_BLOCK_SIZE ( 16 * 16 )      // max threads per rasterization tile
#define N_THREADS 256                   // default 1-D launch block size

#define MAX_REGISTER_CHANNELS 3         // channels kept in registers; more spill to the ND path

// Abort the process on any CUDA runtime error, printing file/line context.
// NOTE(review): this reports cudaGetErrorString(cudaGetLastError()) rather
// than the status of (x) itself; if another error intervened the message may
// describe a different failure than the one that tripped the check.
#define CUDA_CALL(x)                                                           \
    do {                                                                       \
        if ((x) != cudaSuccess) {                                              \
            printf(                                                            \
                "Error at %s:%d - %s\n",                                       \
                __FILE__,                                                      \
                __LINE__,                                                      \
                cudaGetErrorString(cudaGetLastError())                         \
            );                                                                 \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    } while (0)

// ---- gsplat/cuda_legacy/csrc/ext.cpp ----
// NOTE(review): the second #include target (presumably torch/extension.h)
// was stripped by extraction.
#include "bindings.h"
#include

// Python module definition: expose the host binding functions above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // auto diff functions
    m.def("nd_rasterize_forward", &nd_rasterize_forward_tensor);
    m.def("nd_rasterize_backward", &nd_rasterize_backward_tensor);
    m.def("rasterize_forward", &rasterize_forward_tensor);
    m.def("rasterize_backward", &rasterize_backward_tensor);
    m.def("project_gaussians_forward", &project_gaussians_forward_tensor);
    m.def("project_gaussians_backward", &project_gaussians_backward_tensor);
    m.def("compute_sh_forward", &compute_sh_forward_tensor);
    m.def("compute_sh_backward", &compute_sh_backward_tensor);
    // utils
    m.def("compute_cov2d_bounds", &compute_cov2d_bounds_tensor);
    m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
    m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor);
}
// ---- gsplat/cuda_legacy/csrc/forward.cuh ----
// Kernel declarations for the legacy forward rasterization pass.
// FIX: the original header declared rasterize_forward and
// nd_rasterize_forward twice, verbatim (lines 28-41 / 94-107 and 44-58 /
// 109-123 of the original file). The redundant redeclarations are removed;
// repeating an identical declaration is legal C++ but pure noise.
// NOTE(review): #include targets were stripped by the text extraction.
#include
#include
#include

// Project each 3D Gaussian (mean/scale/quat) into screen space, producing
// the 2D mean, depth, pixel radius, conic, anti-aliasing compensation and
// the number of tiles the Gaussian touches.
__global__ void project_gaussians_forward_kernel(
    const int num_points,
    const float3* __restrict__ means3d,
    const float3* __restrict__ scales,
    const float glob_scale,
    const float4* __restrict__ quats,
    const float* __restrict__ viewmat,
    const float4 intrins,
    const dim3 img_size,
    const dim3 tile_bounds,
    const unsigned block_width,
    const float clip_thresh,
    float* __restrict__ covs3d,
    float2* __restrict__ xys,
    float* __restrict__ depths,
    int* __restrict__ radii,
    float3* __restrict__ conics,
    float* __restrict__ compensation,
    int32_t* __restrict__ num_tiles_hit
);

// compute output color image from binned and sorted gaussians (RGB path)
__global__ void rasterize_forward(
    const dim3 tile_bounds,
    const dim3 img_size,
    const int32_t* __restrict__ gaussian_ids_sorted,
    const int2* __restrict__ tile_bins,
    const float2* __restrict__ xys,
    const float3* __restrict__ conics,
    const float3* __restrict__ colors,
    const float* __restrict__ opacities,
    float* __restrict__ final_Ts,
    int* __restrict__ final_index,
    float3* __restrict__ out_img,
    const float3& __restrict__ background
);

// compute output color image from binned and sorted gaussians
// (N-channel path; channel count is a runtime parameter)
__global__ void nd_rasterize_forward(
    const dim3 tile_bounds,
    const dim3 img_size,
    const unsigned channels,
    const int32_t* __restrict__ gaussian_ids_sorted,
    const int2* __restrict__ tile_bins,
    const float2* __restrict__ xys,
    const float3* __restrict__ conics,
    const float* __restrict__ colors,
    const float* __restrict__ opacities,
    float* __restrict__ final_Ts,
    int* __restrict__ final_index,
    float* __restrict__ out_img,
    const float* __restrict__ background
);

// device helper to approximate projected 2d cov from 3d mean and cov (EWA)
__device__ void project_cov3d_ewa(
    const float3 &mean3d,
    const float *cov3d,
    const float *viewmat,
    const float fx,
    const float fy,
    const float tan_fovx,
    const float tan_fovy,
    float3 &cov2d,
    float &comp
);

// device helper to get 3D covariance from scale and quat parameters
__device__ void scale_rot_to_cov3d(
    const float3 scale, const float glob_scale, const float4 quat, float *cov3d
);

// Emit one (tile_id | depth)-keyed intersection entry per tile each
// Gaussian overlaps; cum_tiles_hit gives each Gaussian's output offset.
__global__ void map_gaussian_to_intersects(
    const int num_points,
    const float2* __restrict__ xys,
    const float* __restrict__ depths,
    const int* __restrict__ radii,
    const int32_t* __restrict__ cum_tiles_hit,
    const dim3 tile_bounds,
    const unsigned block_width,
    int64_t* __restrict__ isect_ids,
    int32_t* __restrict__ gaussian_ids
);

// Per tile, find the [start, end) range into the sorted intersection list.
__global__ void get_tile_bin_edges(
    const int num_intersects, const int64_t* __restrict__ isect_ids_sorted, int2* __restrict__ tile_bins
);

// ---- gsplat/cuda_legacy/csrc/helpers.cuh (head) ----
// NOTE(review): two #include targets (glm is explicit) were stripped.
#include "config.h"
#include "third_party/glm/glm/glm.hpp"
#include "third_party/glm/glm/gtc/type_ptr.hpp"
#include
#include

// Clamp an axis-aligned box (center +/- dims) to [0, img_size) per axis.
// Results are inclusive-min / exclusive-max, in the caller's units
// (pixels or tiles, depending on what is passed in).
inline __device__ void get_bbox(const float2 center, const float2 dims,
                                const dim3 img_size, uint2 &bb_min, uint2 &bb_max) {
    // get bounding box with center and dims, within bounds
    // bounding box coords returned in tile coords, inclusive min, exclusive max
    // clamp between 0 and tile bounds
    bb_min.x = min(max(0, (int)(center.x - dims.x)), img_size.x);
    bb_max.x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x);
    bb_min.y = min(max(0, (int)(center.y - dims.y)), img_size.y);
    bb_max.y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y);
}

// Convert a Gaussian's pixel-space center/radius to the tile-space range it
// spans (image divided into block_size x block_size tiles).
inline __device__ void get_tile_bbox(const float2 pix_center, const float pix_radius,
                                     const dim3 tile_bounds, uint2 &tile_min,
                                     uint2 &tile_max, const int block_size) {
    float2 tile_center = {pix_center.x / (float)block_size,
                          pix_center.y / (float)block_size};
    float2 tile_radius = {pix_radius / (float)block_size,
                          pix_radius / (float)block_size};
    get_bbox(tile_center, tile_radius, tile_bounds, tile_min, tile_max);
}
// ---- gsplat/cuda_legacy/csrc/helpers.cuh (math helpers) ----

// Invert the 2D covariance (upper-triangular float3 {xx, xy, yy}) into the
// conic, and compute a 3-sigma pixel radius from its largest eigenvalue.
// Returns false (degenerate) when the determinant is exactly zero.
inline __device__ bool compute_cov2d_bounds(const float3 cov2d, float3 &conic,
                                            float &radius) {
    // find eigenvalues of 2d covariance matrix
    // expects upper triangular values of cov matrix as float3
    // then compute the radius and conic dimensions
    // the conic is the inverse cov2d matrix, represented here with upper
    // triangular values.
    float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
    if (det == 0.f)
        return false;
    float inv_det = 1.f / det;

    // inverse of 2x2 cov2d matrix
    conic.x = cov2d.z * inv_det;
    conic.y = -cov2d.y * inv_det;
    conic.z = cov2d.x * inv_det;

    float b = 0.5f * (cov2d.x + cov2d.z);
    // eigenvalues b +/- sqrt(b^2 - det); the 0.1f floor keeps the sqrt
    // argument positive under float round-off (deliberate, not a typo)
    float v1 = b + sqrt(max(0.1f, b * b - det));
    float v2 = b - sqrt(max(0.1f, b * b - det));
    // take 3 sigma of covariance
    radius = ceil(3.f * sqrt(max(v1, v2)));
    return true;
}

// compute vjp from df/d_conic to df/c_cov2d
inline __device__ void cov2d_to_conic_vjp(const float3 &conic, const float3 &v_conic,
                                          float3 &v_cov2d) {
    // conic = inverse cov2d
    // df/d_cov2d = -conic * df/d_conic * conic
    // (v_conic.y is halved because the symmetric off-diagonal is stored once)
    glm::mat2 X = glm::mat2(conic.x, conic.y, conic.y, conic.z);
    glm::mat2 G = glm::mat2(v_conic.x, v_conic.y / 2.f, v_conic.y / 2.f, v_conic.z);
    glm::mat2 v_Sigma = -X * G * X;
    v_cov2d.x = v_Sigma[0][0];
    v_cov2d.y = v_Sigma[1][0] + v_Sigma[0][1];
    v_cov2d.z = v_Sigma[1][1];
}

// Accumulate into v_cov2d the gradient flowing through the anti-aliasing
// compensation factor. Note the ACCUMULATING `+=` (unlike cov2d_to_conic_vjp
// which overwrites) — callers rely on this.
inline __device__ void cov2d_to_compensation_vjp(const float compensation,
                                                 const float3 &conic,
                                                 const float v_compensation,
                                                 float3 &v_cov2d) {
    // comp = sqrt(det(cov2d - 0.3 I) / det(cov2d))
    // conic = inverse(cov2d)
    // df / d_cov2d = df / d comp * 0.5 / comp * [ d comp^2 / d cov2d ]
    // d comp^2 / d cov2d = (1 - comp^2) * conic - 0.3 I * det(conic)
    // det(conic) == 1/det(cov2d), hence the name inv_det
    float inv_det = conic.x * conic.z - conic.y * conic.y;
    float one_minus_sqr_comp = 1 - compensation * compensation;
    // 1e-6 guards division when compensation underflows to ~0
    float v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6);
    v_cov2d.x += v_sqr_comp * (one_minus_sqr_comp * conic.x - 0.3 * inv_det);
    v_cov2d.y += 2 * v_sqr_comp * (one_minus_sqr_comp * conic.y);
    v_cov2d.z += v_sqr_comp * (one_minus_sqr_comp * conic.z - 0.3 * inv_det);
}

// helper for applying R^T * p for a ROW MAJOR 4x3 matrix [R, t], ignoring t
inline __device__ float3 transform_4x3_rot_only_transposed(const float *mat,
                                                           const float3 p) {
    float3 out = {
        mat[0] * p.x + mat[4] * p.y + mat[8] * p.z,
        mat[1] * p.x + mat[5] * p.y + mat[9] * p.z,
        mat[2] * p.x + mat[6] * p.y + mat[10] * p.z,
    };
    return out;
}

// helper for applying R * p + T, expect mat to be ROW MAJOR
inline __device__ float3 transform_4x3(const float *mat, const float3 p) {
    float3 out = {
        mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3],
        mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7],
        mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11],
    };
    return out;
}

// helper to apply 4x4 transform to 3d vector, return homo coords
// expects mat to be ROW MAJOR
inline __device__ float4 transform_4x4(const float *mat, const float3 p) {
    float4 out = {
        mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3],
        mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7],
        mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11],
        mat[12] * p.x + mat[13] * p.y + mat[14] * p.z + mat[15],
    };
    return out;
}

// Perspective-project a camera-space point to pixel coords using focal
// lengths fxfy and principal point pp. 1e-6 avoids divide-by-zero at z=0.
inline __device__ float2 project_pix(const float2 fxfy, const float3 p_view,
                                     const float2 pp) {
    float rw = 1.f / (p_view.z + 1e-6f);
    float2 p_proj = {p_view.x * rw, p_view.y * rw};
    float2 p_pix = {p_proj.x * fxfy.x + pp.x, p_proj.y * fxfy.y + pp.y};
    return p_pix;
}

// given v_xy_pix, get v_xyz: VJP of project_pix w.r.t. the camera-space point
inline __device__ float3 project_pix_vjp(const float2 fxfy, const float3 p_view,
                                         const float2 v_xy) {
    float rw = 1.f / (p_view.z + 1e-6f);
    float2 v_proj = {fxfy.x * v_xy.x, fxfy.y * v_xy.y};
    float3 v_view = {v_proj.x * rw, v_proj.y * rw,
                     -(v_proj.x * p_view.x + v_proj.y * p_view.y) * rw * rw};
    return v_view;
}

// Quaternion (stored wxyz: quat.x is the SCALAR part) to rotation matrix.
inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) {
    // quat to rotation matrix
    float w = quat.x;
    float x = quat.y;
    float y = quat.z;
    float z = quat.w;

    // glm matrices are column-major
    return glm::mat3(
        1.f - 2.f * (y * y + z * z), 2.f * (x * y + w * z), 2.f * (x * z - w * y),
        2.f * (x * y - w * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z + w * x),
        2.f * (x * z + w * y), 2.f * (y * z - w * x), 1.f - 2.f * (x * x + y * y)
    );
}

// VJP of quat_to_rotmat. The commented-out lines document which quaternion
// component each slot carries (same wxyz packing as quat_to_rotmat).
inline __device__ float4 quat_to_rotmat_vjp(const float4 quat, const glm::mat3 v_R) {
    float w = quat.x;
    float x = quat.y;
    float y = quat.z;
    float z = quat.w;

    float4 v_quat;
    // v_R is COLUMN MAJOR
    // w element stored in x field
    v_quat.x = 2.f * (
        // v_quat.w = 2.f * (
        x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) +
        z * (v_R[0][1] - v_R[1][0]));
    // x element in y field
    v_quat.y =
        2.f * (
            // v_quat.x = 2.f * (
            -2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) +
            z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]));
    // y element in z field
    v_quat.z =
        2.f * (
            // v_quat.y = 2.f * (
            x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) +
            z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]));
    // z element in w field
    v_quat.w =
        2.f * (
            // v_quat.z = 2.f * (
            x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) -
            2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]));
    return v_quat;
}
// ---- gsplat/cuda_legacy/csrc/helpers.cuh (tail) ----

// Build a diagonal scale matrix glob_scale * diag(scale).
inline __device__ glm::mat3 scale_to_mat(const float3 scale, const float glob_scale) {
    glm::mat3 S = glm::mat3(1.f);
    S[0][0] = glob_scale * scale.x;
    S[1][1] = glob_scale * scale.y;
    S[2][2] = glob_scale * scale.z;
    return S;
}

// device helper for culling near points: returns true (cull) when the
// camera-space depth is at or behind the clip threshold. Always writes
// p_view as a side effect.
inline __device__ bool clip_near_plane(const float3 p, const float *viewmat,
                                       float3 &p_view, float thresh) {
    p_view = transform_4x3(viewmat, p);
    if (p_view.z <= thresh) {
        return true;
    }
    return false;
}

// ---- gsplat/cuda_legacy/csrc/sh.cuh (head) ----
// NOTE(review): #include targets were stripped by the text extraction.
#include
#include
#define CHANNELS 3
namespace cg = cooperative_groups;

// Which SH evaluation routine to run (classic polynomial vs Sloan's fast path).
enum class SHType {
    Poly,
    Fast,
};

// Real SH basis constants for bands 0-4.
// FIX: SH_C4[1] was written as -1.7701307697799304 (no `f` suffix) — a
// double literal silently narrowed into this float array, inconsistent with
// every other entry. The suffix is added.
__device__ __constant__ float SH_C0 = 0.28209479177387814f;
__device__ __constant__ float SH_C1 = 0.4886025119029199f;
__device__ __constant__ float SH_C2[] = {
    1.0925484305920792f,
    -1.0925484305920792f,
    0.31539156525252005f,
    -1.0925484305920792f,
    0.5462742152960396f};
__device__ __constant__ float SH_C3[] = {
    -0.5900435899266435f,
    2.890611442640554f,
    -0.4570457994644658f,
    0.3731763325901154f,
    -0.4570457994644658f,
    1.445305721320277f,
    -0.5900435899266435f};
__device__ __constant__ float SH_C4[] = {
    2.5033429417967046f,
    -1.7701307697799304f,
    0.9461746957575601f,
    -0.6690465435572892f,
    0.10578554691520431f,
    -0.6690465435572892f,
    0.47308734787878004f,
    -1.7701307697799304f,
    0.6258357354491761f};

// This function is used in both host and device code.
// Number of SH basis functions: (degree+1)^2, capped at degree 4 (25 bases).
__host__ __device__ unsigned num_sh_bases(const unsigned degree) {
    if (degree == 0)
        return 1;
    if (degree == 1)
        return 4;
    if (degree == 2)
        return 9;
    if (degree == 3)
        return 16;
    return 25;
}

// Evaluate spherical harmonics bases at unit direction for high orders using approach described by
// Efficient Spherical Harmonic Evaluation, Peter-Pike Sloan, JCGT 2013
// See https://jcgt.org/published/0002/02/06/ for reference implementation
__device__ void sh_coeffs_to_color_fast(
    const unsigned degree,
    const float3 &viewdir,
    const float *coeffs,  // CHANNELS * num_sh_bases(degree) values
    float *colors         // out: CHANNELS values
) {
    // band 0 (constant term)
    for (int c = 0; c < CHANNELS; ++c) {
        colors[c] = 0.2820947917738781f * coeffs[c];
    }
    if (degree < 1) {
        return;
    }

    // normalize the view direction; bands below are evaluated at (x, y, z)
    float norm = sqrt(
        viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
    );
    float x = viewdir.x / norm;
    float y = viewdir.y / norm;
    float z = viewdir.z / norm;

    // band 1
    for (int c = 0; c < CHANNELS; ++c) {
        float fTmp0A = 0.48860251190292f;
        colors[c] += fTmp0A *
                     (-y * coeffs[1 * CHANNELS + c] +
                      z * coeffs[2 * CHANNELS + c] -
                      x * coeffs[3 * CHANNELS + c]);
    }
    if (degree < 2) {
        return;
    }
    float z2 = z * z;

    // band 2 (pSH* follow Sloan's recurrence naming)
    float fTmp0B = -1.092548430592079f * z;
    float fTmp1A = 0.5462742152960395f;
    float fC1 = x * x - y * y;
    float fS1 = 2.f * x * y;
    float pSH6 = (0.9461746957575601f * z2 - 0.3153915652525201f);
    float pSH7 = fTmp0B * x;
    float pSH5 = fTmp0B * y;
    float pSH8 = fTmp1A * fC1;
    float pSH4 = fTmp1A * fS1;
    for (int c = 0; c < CHANNELS; ++c) {
        colors[c] +=
            pSH4 * coeffs[4 * CHANNELS + c] + pSH5 * coeffs[5 * CHANNELS + c] +
            pSH6 * coeffs[6 * CHANNELS + c] + pSH7 * coeffs[7 * CHANNELS + c] +
            pSH8 * coeffs[8 * CHANNELS + c];
    }
    if (degree < 3) {
        return;
    }

    // band 3
    float fTmp0C = -2.285228997322329f * z2 + 0.4570457994644658f;
    float fTmp1B = 1.445305721320277f * z;
    float fTmp2A = -0.5900435899266435f;
    float fC2 = x * fC1 - y * fS1;
    float fS2 = x * fS1 + y * fC1;
    float pSH12 = z * (1.865881662950577f * z2 - 1.119528997770346f);
    float pSH13 = fTmp0C * x;
    float pSH11 = fTmp0C * y;
    float pSH14 = fTmp1B * fC1;
    float pSH10 = fTmp1B * fS1;
    float pSH15 = fTmp2A * fC2;
    float pSH9 = fTmp2A * fS2;
    for (int c = 0; c < CHANNELS; ++c) {
        colors[c] += pSH9 * coeffs[9 * CHANNELS + c] +
                     pSH10 * coeffs[10 * CHANNELS + c] +
                     pSH11 * coeffs[11 * CHANNELS + c] +
                     pSH12 * coeffs[12 * CHANNELS + c] +
                     pSH13 * coeffs[13 * CHANNELS + c] +
                     pSH14 * coeffs[14 * CHANNELS + c] +
                     pSH15 * coeffs[15 * CHANNELS + c];
    }
    if (degree < 4) {
        return;
    }

    // band 4
    float fTmp0D = z * (-4.683325804901025f * z2 + 2.007139630671868f);
    float fTmp1C = 3.31161143515146f * z2 - 0.47308734787878f;
    float fTmp2B = -1.770130769779931f * z;
    float fTmp3A = 0.6258357354491763f;
    float fC3 = x * fC2 - y * fS2;
    float fS3 = x * fS2 + y * fC2;
    float pSH20 = (1.984313483298443f * z * pSH12 - 1.006230589874905f * pSH6);
    float pSH21 = fTmp0D * x;
    float pSH19 = fTmp0D * y;
    float pSH22 = fTmp1C * fC1;
    float pSH18 = fTmp1C * fS1;
    float pSH23 = fTmp2B * fC2;
    float pSH17 = fTmp2B * fS2;
    float pSH24 = fTmp3A * fC3;
    float pSH16 = fTmp3A * fS3;
    for (int c = 0; c < CHANNELS; ++c) {
        colors[c] += pSH16 * coeffs[16 * CHANNELS + c] +
                     pSH17 * coeffs[17 * CHANNELS + c] +
                     pSH18 * coeffs[18 * CHANNELS + c] +
                     pSH19 * coeffs[19 * CHANNELS + c] +
                     pSH20 * coeffs[20 * CHANNELS + c] +
                     pSH21 * coeffs[21 * CHANNELS + c] +
                     pSH22 * coeffs[22 * CHANNELS + c] +
                     pSH23 * coeffs[23 * CHANNELS + c] +
                     pSH24 * coeffs[24 * CHANNELS + c];
    }
}
// VJP of sh_coeffs_to_color_fast: scatter d(loss)/d(colors) back to the SH
// coefficients. Since colors are linear in the coefficients, each gradient
// entry is simply basis_value * v_colors[c].
__device__ void sh_coeffs_to_color_fast_vjp(
    const unsigned degree,
    const float3 &viewdir,
    const float *v_colors,
    float *v_coeffs
) {
    // Expects v_colors to be len CHANNELS
    // and v_coeffs to be num_bases * CHANNELS
    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        v_coeffs[c] = 0.2820947917738781f * v_colors[c];
    }
    if (degree < 1) {
        return;
    }
    float norm = sqrt(
        viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
    );
    float x = viewdir.x / norm;
    float y = viewdir.y / norm;
    float z = viewdir.z / norm;

    // band 1
    float fTmp0A = 0.48860251190292f;
    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        v_coeffs[1 * CHANNELS + c] = -fTmp0A * y * v_colors[c];
        v_coeffs[2 * CHANNELS + c] = fTmp0A * z * v_colors[c];
        v_coeffs[3 * CHANNELS + c] = -fTmp0A * x * v_colors[c];
    }
    if (degree < 2) {
        return;
    }

    // band 2 — same basis recurrence as the forward pass
    float z2 = z * z;
    float fTmp0B = -1.092548430592079f * z;
    float fTmp1A = 0.5462742152960395f;
    float fC1 = x * x - y * y;
    float fS1 = 2.f * x * y;
    float pSH6 = (0.9461746957575601f * z2 - 0.3153915652525201f);
    float pSH7 = fTmp0B * x;
    float pSH5 = fTmp0B * y;
    float pSH8 = fTmp1A * fC1;
    float pSH4 = fTmp1A * fS1;
    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        v_coeffs[4 * CHANNELS + c] = pSH4 * v_colors[c];
        v_coeffs[5 * CHANNELS + c] = pSH5 * v_colors[c];
        v_coeffs[6 * CHANNELS + c] = pSH6 * v_colors[c];
        v_coeffs[7 * CHANNELS + c] = pSH7 * v_colors[c];
        v_coeffs[8 * CHANNELS + c] = pSH8 * v_colors[c];
    }
    if (degree < 3) {
        return;
    }

    // band 3
    float fTmp0C = -2.285228997322329f * z2 + 0.4570457994644658f;
    float fTmp1B = 1.445305721320277f * z;
    float fTmp2A = -0.5900435899266435f;
    float fC2 = x * fC1 - y * fS1;
    float fS2 = x * fS1 + y * fC1;
    float pSH12 = z * (1.865881662950577f * z2 - 1.119528997770346f);
    float pSH13 = fTmp0C * x;
    float pSH11 = fTmp0C * y;
    float pSH14 = fTmp1B * fC1;
    float pSH10 = fTmp1B * fS1;
    float pSH15 = fTmp2A * fC2;
    float pSH9 = fTmp2A * fS2;
    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        v_coeffs[9 * CHANNELS + c] = pSH9 * v_colors[c];
        v_coeffs[10 * CHANNELS + c] = pSH10 * v_colors[c];
        v_coeffs[11 * CHANNELS + c] = pSH11 * v_colors[c];
        v_coeffs[12 * CHANNELS + c] = pSH12 * v_colors[c];
        v_coeffs[13 * CHANNELS + c] = pSH13 * v_colors[c];
        v_coeffs[14 * CHANNELS + c] = pSH14 * v_colors[c];
        v_coeffs[15 * CHANNELS + c] = pSH15 * v_colors[c];
    }
    if (degree < 4) {
        return;
    }

    // band 4
    float fTmp0D = z * (-4.683325804901025f * z2 + 2.007139630671868f);
    float fTmp1C = 3.31161143515146f * z2 - 0.47308734787878f;
    float fTmp2B = -1.770130769779931f * z;
    float fTmp3A = 0.6258357354491763f;
    float fC3 = x * fC2 - y * fS2;
    float fS3 = x * fS2 + y * fC2;
    // `+ -` below is `-` in the forward pass; numerically identical
    float pSH20 = (1.984313483298443f * z * pSH12 + -1.006230589874905f * pSH6);
    float pSH21 = fTmp0D * x;
    float pSH19 = fTmp0D * y;
    float pSH22 = fTmp1C * fC1;
    float pSH18 = fTmp1C * fS1;
    float pSH23 = fTmp2B * fC2;
    float pSH17 = fTmp2B * fS2;
    float pSH24 = fTmp3A * fC3;
    float pSH16 = fTmp3A * fS3;
    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        v_coeffs[16 * CHANNELS + c] = pSH16 * v_colors[c];
        v_coeffs[17 * CHANNELS + c] = pSH17 * v_colors[c];
        v_coeffs[18 * CHANNELS + c] = pSH18 * v_colors[c];
        v_coeffs[19 * CHANNELS + c] = pSH19 * v_colors[c];
        v_coeffs[20 * CHANNELS + c] = pSH20 * v_colors[c];
        v_coeffs[21 * CHANNELS + c] = pSH21 * v_colors[c];
        v_coeffs[22 * CHANNELS + c] = pSH22 * v_colors[c];
        v_coeffs[23 * CHANNELS + c] = pSH23 * v_colors[c];
        v_coeffs[24 * CHANNELS + c] = pSH24 * v_colors[c];
    }
}

// Classic polynomial SH evaluation (SHType::Poly), using the SH_C* constant
// tables; per-channel loop with early `continue` once `degree` is exhausted.
__device__ void sh_coeffs_to_color(
    const unsigned degree,
    const float3 &viewdir,
    const float *coeffs,
    float *colors
) {
    // Expects v_colors to be len CHANNELS
    // and v_coeffs to be num_bases * CHANNELS
    for (int c = 0; c < CHANNELS; ++c) {
        colors[c] = SH_C0 * coeffs[c];
    }
    if (degree < 1) {
        return;
    }

    float norm = sqrt(
        viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
    );
    float x = viewdir.x / norm;
    float y = viewdir.y / norm;
    float z = viewdir.z / norm;

    float xx = x * x;
    float xy = x * y;
    float xz = x * z;
    float yy = y * y;
    float yz = y * z;
    float zz = z * z;
    // expects CHANNELS * num_bases coefficients
    // supports up to num_bases = 25
    for (int c = 0; c < CHANNELS; ++c) {
        colors[c] += SH_C1 * (-y * coeffs[1 * CHANNELS + c] +
                              z * coeffs[2 * CHANNELS + c] -
                              x * coeffs[3 * CHANNELS + c]);
        if (degree < 2) {
            continue;
        }
        colors[c] +=
            (SH_C2[0] * xy * coeffs[4 * CHANNELS + c] +
             SH_C2[1] * yz * coeffs[5 * CHANNELS + c] +
             SH_C2[2] * (2.f * zz - xx - yy) * coeffs[6 * CHANNELS + c] +
             SH_C2[3] * xz * coeffs[7 * CHANNELS + c] +
             SH_C2[4] * (xx - yy) * coeffs[8 * CHANNELS + c]);
        if (degree < 3) {
            continue;
        }
        colors[c] +=
            (SH_C3[0] * y * (3.f * xx - yy) * coeffs[9 * CHANNELS + c] +
             SH_C3[1] * xy * z * coeffs[10 * CHANNELS + c] +
             SH_C3[2] * y * (4.f * zz - xx - yy) * coeffs[11 * CHANNELS + c] +
             SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy) *
                 coeffs[12 * CHANNELS + c] +
             SH_C3[4] * x * (4.f * zz - xx - yy) * coeffs[13 * CHANNELS + c] +
             SH_C3[5] * z * (xx - yy) * coeffs[14 * CHANNELS + c] +
             SH_C3[6] * x * (xx - 3.f * yy) * coeffs[15 * CHANNELS + c]);
        if (degree < 4) {
            continue;
        }
        colors[c] +=
            (SH_C4[0] * xy * (xx - yy) * coeffs[16 * CHANNELS + c] +
             SH_C4[1] * yz * (3.f * xx - yy) * coeffs[17 * CHANNELS + c] +
             SH_C4[2] * xy * (7.f * zz - 1.f) * coeffs[18 * CHANNELS + c] +
             SH_C4[3] * yz * (7.f * zz - 3.f) * coeffs[19 * CHANNELS + c] +
             SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f) *
                 coeffs[20 * CHANNELS + c] +
             SH_C4[5] * xz * (7.f * zz - 3.f) * coeffs[21 * CHANNELS + c] +
             SH_C4[6] * (xx - yy) * (7.f * zz - 1.f) *
                 coeffs[22 * CHANNELS + c] +
             SH_C4[7] * xz * (xx - 3.f * yy) * coeffs[23 * CHANNELS + c] +
             SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy)) *
                 coeffs[24 * CHANNELS + c]);
    }
}
// VJP of sh_coeffs_to_color (Poly path): each coefficient's gradient is its
// basis value times the incoming color gradient.
__device__ void sh_coeffs_to_color_vjp(
    const unsigned degree,
    const float3 &viewdir,
    const float *v_colors,
    float *v_coeffs
) {
    // Expects v_colors to be len CHANNELS
    // and v_coeffs to be num_bases * CHANNELS
    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        v_coeffs[c] = SH_C0 * v_colors[c];
    }
    if (degree < 1) {
        return;
    }

    float norm = sqrt(
        viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
    );
    float x = viewdir.x / norm;
    float y = viewdir.y / norm;
    float z = viewdir.z / norm;

    float xx = x * x;
    float xy = x * y;
    float xz = x * z;
    float yy = y * y;
    float yz = y * z;
    float zz = z * z;

    #pragma unroll
    for (int c = 0; c < CHANNELS; ++c) {
        // band 1
        float v1 = -SH_C1 * y;
        float v2 = SH_C1 * z;
        float v3 = -SH_C1 * x;
        v_coeffs[1 * CHANNELS + c] = v1 * v_colors[c];
        v_coeffs[2 * CHANNELS + c] = v2 * v_colors[c];
        v_coeffs[3 * CHANNELS + c] = v3 * v_colors[c];
        if (degree < 2) {
            continue;
        }
        // band 2
        float v4 = SH_C2[0] * xy;
        float v5 = SH_C2[1] * yz;
        float v6 = SH_C2[2] * (2.f * zz - xx - yy);
        float v7 = SH_C2[3] * xz;
        float v8 = SH_C2[4] * (xx - yy);
        v_coeffs[4 * CHANNELS + c] = v4 * v_colors[c];
        v_coeffs[5 * CHANNELS + c] = v5 * v_colors[c];
        v_coeffs[6 * CHANNELS + c] = v6 * v_colors[c];
        v_coeffs[7 * CHANNELS + c] = v7 * v_colors[c];
        v_coeffs[8 * CHANNELS + c] = v8 * v_colors[c];
        if (degree < 3) {
            continue;
        }
        // band 3
        float v9 = SH_C3[0] * y * (3.f * xx - yy);
        float v10 = SH_C3[1] * xy * z;
        float v11 = SH_C3[2] * y * (4.f * zz - xx - yy);
        float v12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
        float v13 = SH_C3[4] * x * (4.f * zz - xx - yy);
        float v14 = SH_C3[5] * z * (xx - yy);
        float v15 = SH_C3[6] * x * (xx - 3.f * yy);
        v_coeffs[9 * CHANNELS + c] = v9 * v_colors[c];
        v_coeffs[10 * CHANNELS + c] = v10 * v_colors[c];
        v_coeffs[11 * CHANNELS + c] = v11 * v_colors[c];
        v_coeffs[12 * CHANNELS + c] = v12 * v_colors[c];
        v_coeffs[13 * CHANNELS + c] = v13 * v_colors[c];
        v_coeffs[14 * CHANNELS + c] = v14 * v_colors[c];
        v_coeffs[15 * CHANNELS + c] = v15 * v_colors[c];
        if (degree < 4) {
            continue;
        }
        // band 4
        float v16 = SH_C4[0] * xy * (xx - yy);
        float v17 = SH_C4[1] * yz * (3.f * xx - yy);
        float v18 = SH_C4[2] * xy * (7.f * zz - 1.f);
        float v19 = SH_C4[3] * yz * (7.f * zz - 3.f);
        float v20 = SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f);
        float v21 = SH_C4[5] * xz * (7.f * zz - 3.f);
        float v22 = SH_C4[6] * (xx - yy) * (7.f * zz - 1.f);
        float v23 = SH_C4[7] * xz * (xx - 3.f * yy);
        float v24 = SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy));
        v_coeffs[16 * CHANNELS + c] = v16 * v_colors[c];
        v_coeffs[17 * CHANNELS + c] = v17 * v_colors[c];
        v_coeffs[18 * CHANNELS + c] = v18 * v_colors[c];
        v_coeffs[19 * CHANNELS + c] = v19 * v_colors[c];
        v_coeffs[20 * CHANNELS + c] = v20 * v_colors[c];
        v_coeffs[21 * CHANNELS + c] = v21 * v_colors[c];
        v_coeffs[22 * CHANNELS + c] = v22 * v_colors[c];
        v_coeffs[23 * CHANNELS + c] = v23 * v_colors[c];
        v_coeffs[24 * CHANNELS + c] = v24 * v_colors[c];
    }
}

// One thread per point: dispatch SH evaluation (Poly or Fast) for the point's
// coefficient block, writing CHANNELS color values.
// NOTE(review): the template parameter lists were stripped by extraction
// (presumably `template <SHType SH_TYPE>`); confirm against the original.
template
__global__ void compute_sh_forward_kernel(
    const unsigned num_points,
    const unsigned degree,           // degree the coeff buffer is sized for
    const unsigned degrees_to_use,   // degree actually evaluated (<= degree)
    const float3* __restrict__ viewdirs,
    const float* __restrict__ coeffs,
    float* __restrict__ colors
) {
    unsigned idx = cg::this_grid().thread_rank();
    if (idx >= num_points) {
        return;
    }
    const unsigned num_channels = 3;
    unsigned num_bases = num_sh_bases(degree);
    unsigned idx_sh = num_bases * num_channels * idx;  // offset into coeffs
    unsigned idx_col = num_channels * idx;             // offset into colors

    switch (SH_TYPE)
    {
    case SHType::Poly:
        sh_coeffs_to_color(
            degrees_to_use, viewdirs[idx], &(coeffs[idx_sh]), &(colors[idx_col])
        );
        break;
    case SHType::Fast:
        sh_coeffs_to_color_fast(
            degrees_to_use, viewdirs[idx], &(coeffs[idx_sh]), &(colors[idx_col])
        );
        break;
    }
}

// One thread per point: dispatch the matching VJP, scattering color gradients
// back into the per-point coefficient gradient block.
template
__global__ void compute_sh_backward_kernel(
    const unsigned num_points,
    const unsigned degree,
    const unsigned degrees_to_use,
    const float3* __restrict__ viewdirs,
    const float* __restrict__ v_colors,
    float* __restrict__ v_coeffs
) {
    unsigned idx = cg::this_grid().thread_rank();
    if (idx >= num_points) {
        return;
    }
    const unsigned num_channels = 3;
    unsigned num_bases = num_sh_bases(degree);
    unsigned idx_sh = num_bases * num_channels * idx;
    unsigned idx_col = num_channels * idx;

    switch (SH_TYPE)
    {
    case SHType::Poly:
        sh_coeffs_to_color_vjp(
            degrees_to_use, viewdirs[idx], &(v_colors[idx_col]), &(v_coeffs[idx_sh])
        );
        break;
    case SHType::Fast:
        sh_coeffs_to_color_fast_vjp(
            degrees_to_use, viewdirs[idx], &(v_colors[idx_col]), &(v_coeffs[idx_sh])
        );
        break;
    }
}
29 | """ 30 | if world_size == 1: 31 | return [value] 32 | 33 | # move to CUDA 34 | if isinstance(value, int): 35 | assert device is not None, "device is required for scalar input" 36 | value_tensor = torch.tensor(value, dtype=torch.int, device=device) 37 | else: 38 | value_tensor = value 39 | assert value_tensor.is_cuda, "value should be on CUDA" 40 | 41 | # gather 42 | collected = torch.empty( 43 | world_size, dtype=value_tensor.dtype, device=value_tensor.device 44 | ) 45 | dist.all_gather_into_tensor(collected, value_tensor) 46 | 47 | if isinstance(value, int): 48 | # return as list of integers on CPU 49 | return collected.tolist() 50 | else: 51 | # return as list of single-element tensors 52 | return collected.unbind() 53 | 54 | 55 | def all_to_all_int32( 56 | world_size: int, 57 | values: List[Union[int, Tensor]], 58 | device: Optional[torch.device] = None, 59 | ) -> List[int]: 60 | """Exchange 32-bit integers between all ranks in a many-to-many fashion. 61 | 62 | .. note:: 63 | This function is not differentiable to the input tensors. 64 | 65 | Args: 66 | world_size: The total number of ranks. 67 | values: A list of integers to exchange. Could be a list of scalars or tensors. 68 | Should have the same length as `world_size`. 69 | device: Only required if `values` contains scalars. The device to put the tensors on. 70 | 71 | Returns: 72 | A list of integers. Could be a list of scalars or tensors based on the input `values`. 73 | Have the same length as `world_size`. 
74 | """ 75 | if world_size == 1: 76 | return values 77 | 78 | assert ( 79 | len(values) == world_size 80 | ), "The length of values should be equal to world_size" 81 | 82 | if any(isinstance(v, int) for v in values): 83 | assert device is not None, "device is required for scalar input" 84 | 85 | # move to CUDA 86 | values_tensor = [ 87 | (torch.tensor(v, dtype=torch.int, device=device) if isinstance(v, int) else v) 88 | for v in values 89 | ] 90 | 91 | # all_to_all 92 | collected = [torch.empty_like(v) for v in values_tensor] 93 | dist.all_to_all(collected, values_tensor) 94 | 95 | # return as a list of integers or tensors, based on the input 96 | return [ 97 | v.item() if isinstance(tensor, int) else v 98 | for v, tensor in zip(collected, values) 99 | ] 100 | 101 | 102 | def all_gather_tensor_list(world_size: int, tensor_list: List[Tensor]) -> List[Tensor]: 103 | """Gather a list of tensors from all ranks. 104 | 105 | .. note:: 106 | This function expects the tensors in the `tensor_list` to have the same shape 107 | and data type across all ranks. 108 | 109 | .. note:: 110 | This function is differentiable to the tensors in `tensor_list`. 111 | 112 | .. note:: 113 | For efficiency, this function internally concatenates the tensors in `tensor_list` 114 | and performs a single gather operation. Thus it requires all tensors in the list 115 | to have the same first-dimension size. 116 | 117 | Args: 118 | world_size: The total number of ranks. 119 | tensor_list: A list of tensors to gather. The size of the first dimension of all 120 | the tensors in this list should be the same. The rest dimensions can be 121 | arbitrary. Shape: [(N, *), (N, *), ...] 122 | 123 | Returns: 124 | A list of tensors gathered from all ranks, where the i-th element is corresponding 125 | to the i-th tensor in `tensor_list`. The returned tensors have the shape 126 | [(N * world_size, *), (N * world_size, *), ...] 127 | 128 | Examples: 129 | 130 | .. 
code-block:: python 131 | 132 | >>> # on rank 0 133 | >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] 134 | >>> # on rank 1 135 | >>> # tensor_list = [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])] 136 | >>> collected = all_gather_tensor_list(world_rank, world_size, tensor_list) 137 | >>> # on both ranks 138 | >>> # [torch.tensor([1, 2, 3, 7, 8, 9]), torch.tensor([4, 5, 6, 10, 11, 12])] 139 | 140 | """ 141 | if world_size == 1: 142 | return tensor_list 143 | 144 | N = len(tensor_list[0]) 145 | for tensor in tensor_list: 146 | assert len(tensor) == N, "All tensors should have the same first dimension size" 147 | 148 | # concatenate tensors and record their sizes 149 | data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1) 150 | sizes = [t.numel() // N for t in tensor_list] 151 | 152 | if data.requires_grad: 153 | # differentiable gather 154 | collected = distF.all_gather(data) 155 | else: 156 | # non-differentiable gather 157 | collected = [torch.empty_like(data) for _ in range(world_size)] 158 | torch.distributed.all_gather(collected, data) 159 | collected = torch.cat(collected, dim=0) 160 | 161 | # split the collected tensor and reshape to the original shape 162 | out_tensor_tuple = torch.split(collected, sizes, dim=-1) 163 | out_tensor_list = [] 164 | for out_tensor, tensor in zip(out_tensor_tuple, tensor_list): 165 | out_tensor = out_tensor.view(-1, *tensor.shape[1:]) # [N * world_size, *] 166 | out_tensor_list.append(out_tensor) 167 | return out_tensor_list 168 | 169 | 170 | def all_to_all_tensor_list( 171 | world_size: int, 172 | tensor_list: List[Tensor], 173 | splits: List[Union[int, Tensor]], 174 | output_splits: Optional[List[Union[int, Tensor]]] = None, 175 | ) -> List[Tensor]: 176 | """Split and exchange tensors between all ranks in a many-to-many fashion. 177 | 178 | Args: 179 | world_size: The total number of ranks. 180 | tensor_list: A list of tensors to split and exchange. 
The size of the first 181 | dimension of all the tensors in this list should be the same. The rest 182 | dimensions can be arbitrary. Shape: [(N, *), (N, *), ...] 183 | splits: A list of integers representing the number of elements to send to each 184 | rank. It will be used to split the tensor in the `tensor_list`. 185 | The sum of the elements in this list should be equal to N. The size of this 186 | list should be equal to the `world_size`. 187 | output_splits: Splits of the output tensors. Could be pre-calculated via 188 | `all_to_all_int32(world_size, splits)`. If not provided, it will 189 | be calculated internally. 190 | 191 | Returns: 192 | A list of tensors exchanged between all ranks, where the i-th element is 193 | corresponding to the i-th tensor in `tensor_list`. Note the shape of the 194 | returned tensors might be different from the input tensors, depending on the 195 | splits. 196 | 197 | Examples: 198 | 199 | .. code-block:: python 200 | 201 | >>> # on rank 0 202 | >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] 203 | >>> # splits = [2, 1] 204 | 205 | >>> # on rank 1 206 | >>> # tensor_list = [torch.tensor([7, 8]), torch.tensor([9, 10])] 207 | >>> # splits = [1, 1] 208 | 209 | >>> collected = all_to_all_tensor_list(world_rank, world_size, tensor_list, splits) 210 | 211 | >>> # on rank 0 212 | >>> # [torch.tensor([1, 2, 7]), torch.tensor([4, 5, 9])] 213 | >>> # on rank 1 214 | >>> # [torch.tensor([3, 8]), torch.tensor([6, 10])] 215 | 216 | """ 217 | if world_size == 1: 218 | return tensor_list 219 | 220 | N = len(tensor_list[0]) 221 | for tensor in tensor_list: 222 | assert len(tensor) == N, "All tensors should have the same first dimension size" 223 | 224 | assert ( 225 | len(splits) == world_size 226 | ), "The length of splits should be equal to world_size" 227 | 228 | # concatenate tensors and record their sizes 229 | data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1) 230 | sizes = [t.numel() // N for t in 
tensor_list] 231 | 232 | # all_to_all 233 | if output_splits is not None: 234 | collected_splits = output_splits 235 | else: 236 | collected_splits = all_to_all_int32(world_size, splits, device=data.device) 237 | collected = [ 238 | torch.empty((l, *data.shape[1:]), dtype=data.dtype, device=data.device) 239 | for l in collected_splits 240 | ] 241 | # torch.split requires tuple of integers 242 | splits = [s.item() if isinstance(s, Tensor) else s for s in splits] 243 | if data.requires_grad: 244 | # differentiable all_to_all 245 | distF.all_to_all(collected, data.split(splits, dim=0)) 246 | else: 247 | # non-differentiable all_to_all 248 | torch.distributed.all_to_all(collected, list(data.split(splits, dim=0))) 249 | collected = torch.cat(collected, dim=0) 250 | 251 | # split the collected tensor and reshape to the original shape 252 | out_tensor_tuple = torch.split(collected, sizes, dim=-1) 253 | out_tensor_list = [] 254 | for out_tensor, tensor in zip(out_tensor_tuple, tensor_list): 255 | out_tensor = out_tensor.view(-1, *tensor.shape[1:]) 256 | out_tensor_list.append(out_tensor) 257 | return out_tensor_list 258 | 259 | 260 | def _find_free_port(): 261 | import socket 262 | 263 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 264 | # Binding to port 0 will cause the OS to find an available port for us 265 | sock.bind(("", 0)) 266 | port = sock.getsockname()[1] 267 | sock.close() 268 | # NOTE: there is still a chance the port could be taken by other processes. 
269 | return port 270 | 271 | 272 | def _distributed_worker( 273 | world_rank: int, 274 | world_size: int, 275 | fn: Callable, 276 | args: Any, 277 | local_rank: Optional[int] = None, 278 | verbose: bool = False, 279 | ) -> bool: 280 | if local_rank is None: # single Node 281 | local_rank = world_rank 282 | if verbose: 283 | print("Distributed worker: %d / %d" % (world_rank + 1, world_size)) 284 | distributed = world_size > 1 285 | if distributed: 286 | torch.cuda.set_device(local_rank) 287 | torch.distributed.init_process_group( 288 | backend="nccl", world_size=world_size, rank=world_rank 289 | ) 290 | # Dump collection that participates all ranks. 291 | # This initializes the communicator required by `batch_isend_irecv`. 292 | # See: https://github.com/pytorch/pytorch/pull/74701 293 | _ = [None for _ in range(world_size)] 294 | torch.distributed.all_gather_object(_, 0) 295 | fn(local_rank, world_rank, world_size, args) 296 | if distributed: 297 | torch.distributed.barrier() 298 | torch.distributed.destroy_process_group() 299 | if verbose: 300 | print("Job Done for worker: %d / %d" % (world_rank + 1, world_size)) 301 | return True 302 | 303 | 304 | def cli(fn: Callable, args: Any, verbose: bool = False) -> bool: 305 | """Wrapper to run a function in a distributed environment. 306 | 307 | The function `fn` should have the following signature: 308 | 309 | ```python 310 | def fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None: 311 | pass 312 | ``` 313 | 314 | Usage: 315 | 316 | ```python 317 | # Launch with "CUDA_VISIBLE_DEVICES=0,1,2,3 python my_script.py" 318 | if __name__ == "__main__": 319 | cli(fn, None, verbose=True) 320 | ``` 321 | """ 322 | assert torch.cuda.is_available(), "CUDA device is required!" 
323 | if "OMPI_COMM_WORLD_SIZE" in os.environ: # multi-node 324 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 325 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) # dist.get_world_size() 326 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) # dist.get_rank() 327 | return _distributed_worker( 328 | world_rank, world_size, fn, args, local_rank, verbose 329 | ) 330 | 331 | world_size = torch.cuda.device_count() 332 | distributed = world_size > 1 333 | 334 | if distributed: 335 | os.environ["MASTER_ADDR"] = "localhost" 336 | os.environ["MASTER_PORT"] = str(_find_free_port()) 337 | process_context = torch.multiprocessing.spawn( 338 | _distributed_worker, 339 | args=(world_size, fn, args, None, verbose), 340 | nprocs=world_size, 341 | join=False, 342 | ) 343 | try: 344 | process_context.join() 345 | except KeyboardInterrupt: 346 | # this is important. 347 | # if we do not explicitly terminate all launched subprocesses, 348 | # they would continue living even after this main process ends, 349 | # eventually making the OD machine unusable! 350 | for i, process in enumerate(process_context.processes): 351 | if process.is_alive(): 352 | if verbose: 353 | print("terminating process " + str(i) + "...") 354 | process.terminate() 355 | process.join() 356 | if verbose: 357 | print("process " + str(i) + " finished") 358 | return True 359 | else: 360 | return _distributed_worker(0, 1, fn=fn, args=args) 361 | -------------------------------------------------------------------------------- /gsplat/profile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import wraps 4 | from typing import Callable, Optional 5 | 6 | import torch 7 | 8 | profiler = {} 9 | 10 | 11 | class timeit(object): 12 | """Profiler that is controled by the TIMEIT environment variable. 13 | 14 | If TIMEIT is set to 1, the profiler will measure the time taken by the decorated function. 
15 | 16 | Usage: 17 | 18 | ```python 19 | @timeit() 20 | def my_function(): 21 | pass 22 | 23 | # Or 24 | 25 | with timeit(name="stage1"): 26 | my_function() 27 | 28 | print(profiler) 29 | ``` 30 | """ 31 | 32 | def __init__(self, name: str = "unnamed"): 33 | self.name = name 34 | self.start_time: Optional[float] = None 35 | self.enabled = os.environ.get("TIMEIT", "0") == "1" 36 | 37 | def __enter__(self): 38 | if self.enabled: 39 | torch.cuda.synchronize() 40 | self.start_time = time.perf_counter() 41 | 42 | def __exit__(self, exc_type, exc_val, exc_tb): 43 | if self.enabled: 44 | torch.cuda.synchronize() 45 | end_time = time.perf_counter() 46 | total_time = end_time - self.start_time 47 | if self.name not in profiler: 48 | profiler[self.name] = total_time 49 | else: 50 | profiler[self.name] += total_time 51 | 52 | def __call__(self, f: Callable) -> Callable: 53 | @wraps(f) 54 | def decorated(*args, **kwargs): 55 | with self: 56 | self.name = f.__name__ 57 | return f(*args, **kwargs) 58 | 59 | return decorated 60 | -------------------------------------------------------------------------------- /gsplat/relocation.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .cuda._wrapper import _make_lazy_cuda_func 8 | 9 | 10 | def compute_relocation( 11 | opacities: Tensor, # [N] 12 | scales: Tensor, # [N, 3] 13 | ratios: Tensor, # [N] 14 | binoms: Tensor, # [n_max, n_max] 15 | ) -> Tuple[Tensor, Tensor]: 16 | """Compute new Gaussians from a set of old Gaussians. 17 | 18 | This function interprets the Gaussians as samples from a likelihood distribution. 19 | It uses the old opacities and scales to compute the new opacities and scales. 20 | This is an implementation of the paper 21 | `3D Gaussian Splatting as Markov Chain Monte Carlo `_, 22 | 23 | Args: 24 | opacities: The opacities of the Gaussians. 
[N] 25 | scales: The scales of the Gaussians. [N, 3] 26 | ratios: The relative frequencies for each of the Gaussians. [N] 27 | binoms: Precomputed lookup table for binomial coefficients used in 28 | Equation 9 in the paper. [n_max, n_max] 29 | 30 | Returns: 31 | A tuple: 32 | 33 | **new_opacities**: The opacities of the new Gaussians. [N] 34 | **new_scales**: The scales of the Gaussians. [N, 3] 35 | """ 36 | 37 | N = opacities.shape[0] 38 | n_max, _ = binoms.shape 39 | assert scales.shape == (N, 3), scales.shape 40 | assert ratios.shape == (N,), ratios.shape 41 | opacities = opacities.contiguous() 42 | scales = scales.contiguous() 43 | ratios.clamp_(min=1, max=n_max) 44 | ratios = ratios.int().contiguous() 45 | 46 | new_opacities, new_scales = _make_lazy_cuda_func("compute_relocation")( 47 | opacities, scales, ratios, binoms, n_max 48 | ) 49 | return new_opacities, new_scales 50 | -------------------------------------------------------------------------------- /gsplat/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Strategy 2 | from .default import DefaultStrategy 3 | from .mcmc import MCMCStrategy 4 | -------------------------------------------------------------------------------- /gsplat/strategy/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Union 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class Strategy: 9 | """Base class for the GS densification strategy. 10 | 11 | This class is an base class that defines the interface for the GS 12 | densification strategy. 
13 | """ 14 | 15 | def check_sanity( 16 | self, 17 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 18 | optimizers: Dict[str, torch.optim.Optimizer], 19 | ): 20 | """Sanity check for the parameters and optimizers.""" 21 | assert set(params.keys()) == set(optimizers.keys()), ( 22 | "params and optimizers must have the same keys, " 23 | f"but got {params.keys()} and {optimizers.keys()}" 24 | ) 25 | 26 | for optimizer in optimizers.values(): 27 | assert len(optimizer.param_groups) == 1, ( 28 | "Each optimizer must have exactly one param_group, " 29 | "that cooresponds to each parameter, " 30 | f"but got {len(optimizer.param_groups)}" 31 | ) 32 | 33 | def step_pre_backward( 34 | self, 35 | *args, 36 | **kwargs, 37 | ): 38 | """Callback function to be executed before the `loss.backward()` call.""" 39 | pass 40 | 41 | def step_post_backward( 42 | self, 43 | *args, 44 | **kwargs, 45 | ): 46 | """Callback function to be executed after the `loss.backward()` call.""" 47 | pass 48 | -------------------------------------------------------------------------------- /gsplat/strategy/default.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Tuple, Union 3 | 4 | import torch 5 | from typing_extensions import Literal 6 | 7 | from .base import Strategy 8 | from .ops import duplicate, remove, reset_opa, split 9 | 10 | 11 | @dataclass 12 | class DefaultStrategy(Strategy): 13 | """A default strategy that follows the original 3DGS paper: 14 | 15 | `3D Gaussian Splatting for Real-Time Radiance Field Rendering `_ 16 | 17 | The strategy will: 18 | 19 | - Periodically duplicate GSs with high image plane gradients and small scales. 20 | - Periodically split GSs with high image plane gradients and large scales. 21 | - Periodically prune GSs with low opacity. 22 | - Periodically reset GSs to a lower opacity. 
23 | 24 | If `absgrad=True`, it will use the absolute gradients instead of average gradients 25 | for GS duplicating & splitting, following the AbsGS paper: 26 | 27 | `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_ 28 | 29 | Which typically leads to better results but requires to set the `grow_grad2d` to a 30 | higher value, e.g., 0.0008. Also, the :func:`rasterization` function should be called 31 | with `absgrad=True` as well so that the absolute gradients are computed. 32 | 33 | Args: 34 | prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005. 35 | grow_grad2d (float): GSs with image plane gradient above this value will be 36 | split/duplicated. Default is 0.0002. 37 | grow_scale3d (float): GSs with 3d scale (normalized by scene_scale) below this 38 | value will be duplicated. Above will be split. Default is 0.01. 39 | grow_scale2d (float): GSs with 2d scale (normalized by image resolution) above 40 | this value will be split. Default is 0.05. 41 | prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this 42 | value will be pruned. Default is 0.1. 43 | prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above 44 | this value will be pruned. Default is 0.15. 45 | refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this 46 | iteration. Default is 0. Set to a positive value to enable this feature. 47 | refine_start_iter (int): Start refining GSs after this iteration. Default is 500. 48 | refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. 49 | reset_every (int): Reset opacities every this steps. Default is 3000. 50 | refine_every (int): Refine GSs every this steps. Default is 100. 51 | pause_refine_after_reset (int): Pause refining GSs until this number of steps after 52 | reset, Default is 0 (no pause at all) and one might want to set this number to the 53 | number of images in training set. 
54 | absgrad (bool): Use absolute gradients for GS splitting. Default is False. 55 | revised_opacity (bool): Whether to use revised opacity heuristic from 56 | arXiv:2404.06109 (experimental). Default is False. 57 | verbose (bool): Whether to print verbose information. Default is False. 58 | key_for_gradient (str): Which variable uses for densification strategy. 59 | 3DGS uses "means2d" gradient and 2DGS uses a similar gradient which stores 60 | in variable "gradient_2dgs". 61 | 62 | Examples: 63 | 64 | >>> from gsplat import DefaultStrategy, rasterization 65 | >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... 66 | >>> optimizers: Dict[str, torch.optim.Optimizer] = ... 67 | >>> strategy = DefaultStrategy() 68 | >>> strategy.check_sanity(params, optimizers) 69 | >>> strategy_state = strategy.initialize_state() 70 | >>> for step in range(1000): 71 | ... render_image, render_alpha, info = rasterization(...) 72 | ... strategy.step_pre_backward(params, optimizers, strategy_state, step, info) 73 | ... loss = ... 74 | ... loss.backward() 75 | ... strategy.step_post_backward(params, optimizers, strategy_state, step, info) 76 | 77 | """ 78 | 79 | prune_opa: float = 0.005 80 | grow_grad2d: float = 0.0002 81 | grow_scale3d: float = 0.01 82 | grow_scale2d: float = 0.05 83 | prune_scale3d: float = 0.1 84 | prune_scale2d: float = 0.15 85 | refine_scale2d_stop_iter: int = 0 86 | refine_start_iter: int = 500 87 | refine_stop_iter: int = 15_000 88 | reset_every: int = 3000 89 | refine_every: int = 100 90 | pause_refine_after_reset: int = 0 91 | absgrad: bool = False 92 | revised_opacity: bool = False 93 | verbose: bool = False 94 | key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d" 95 | 96 | def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: 97 | """Initialize and return the running state for this strategy. 98 | 99 | The returned state should be passed to the `step_pre_backward()` and 100 | `step_post_backward()` functions. 
101 | """ 102 | # Postpone the initialization of the state to the first step so that we can 103 | # put them on the correct device. 104 | # - grad2d: running accum of the norm of the image plane gradients for each GS. 105 | # - count: running accum of how many time each GS is visible. 106 | # - radii: the radii of the GSs (normalized by the image resolution). 107 | state = {"grad2d": None, "count": None, "scene_scale": scene_scale} 108 | if self.refine_scale2d_stop_iter > 0: 109 | state["radii"] = None 110 | return state 111 | 112 | def check_sanity( 113 | self, 114 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 115 | optimizers: Dict[str, torch.optim.Optimizer], 116 | ): 117 | """Sanity check for the parameters and optimizers. 118 | 119 | Check if: 120 | * `params` and `optimizers` have the same keys. 121 | * Each optimizer has exactly one param_group, corresponding to each parameter. 122 | * The following keys are present: {"means", "scales", "quats", "opacities"}. 123 | 124 | Raises: 125 | AssertionError: If any of the above conditions is not met. 126 | 127 | .. note:: 128 | It is not required but highly recommended for the user to call this function 129 | after initializing the strategy to ensure the convention of the parameters 130 | and optimizers is as expected. 131 | """ 132 | 133 | super().check_sanity(params, optimizers) 134 | # The following keys are required for this strategy. 135 | for key in ["means", "scales", "quats", "opacities"]: 136 | assert key in params, f"{key} is required in params but missing." 
137 | 138 | def step_pre_backward( 139 | self, 140 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 141 | optimizers: Dict[str, torch.optim.Optimizer], 142 | state: Dict[str, Any], 143 | step: int, 144 | info: Dict[str, Any], 145 | ): 146 | """Callback function to be executed before the `loss.backward()` call.""" 147 | assert ( 148 | self.key_for_gradient in info 149 | ), "The 2D means of the Gaussians is required but missing." 150 | info[self.key_for_gradient].retain_grad() 151 | 152 | def step_post_backward( 153 | self, 154 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 155 | optimizers: Dict[str, torch.optim.Optimizer], 156 | state: Dict[str, Any], 157 | step: int, 158 | info: Dict[str, Any], 159 | packed: bool = False, 160 | ): 161 | """Callback function to be executed after the `loss.backward()` call.""" 162 | if step >= self.refine_stop_iter: 163 | return 164 | 165 | self._update_state(params, state, info, packed=packed) 166 | 167 | if ( 168 | step > self.refine_start_iter 169 | and step % self.refine_every == 0 170 | and step % self.reset_every >= self.pause_refine_after_reset 171 | ): 172 | # grow GSs 173 | n_dupli, n_split = self._grow_gs(params, optimizers, state, step) 174 | if self.verbose: 175 | print( 176 | f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. " 177 | f"Now having {len(params['means'])} GSs." 178 | ) 179 | 180 | # prune GSs 181 | n_prune = self._prune_gs(params, optimizers, state, step) 182 | if self.verbose: 183 | print( 184 | f"Step {step}: {n_prune} GSs pruned. " 185 | f"Now having {len(params['means'])} GSs." 
186 | ) 187 | 188 | # reset running stats 189 | state["grad2d"].zero_() 190 | state["count"].zero_() 191 | if self.refine_scale2d_stop_iter > 0: 192 | state["radii"].zero_() 193 | torch.cuda.empty_cache() 194 | 195 | if step % self.reset_every == 0: 196 | reset_opa( 197 | params=params, 198 | optimizers=optimizers, 199 | state=state, 200 | value=self.prune_opa * 2.0, 201 | ) 202 | 203 | def _update_state( 204 | self, 205 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 206 | state: Dict[str, Any], 207 | info: Dict[str, Any], 208 | packed: bool = False, 209 | ): 210 | for key in [ 211 | "width", 212 | "height", 213 | "n_cameras", 214 | "radii", 215 | "gaussian_ids", 216 | self.key_for_gradient, 217 | ]: 218 | assert key in info, f"{key} is required but missing." 219 | 220 | # normalize grads to [-1, 1] screen space 221 | if self.absgrad: 222 | grads = info[self.key_for_gradient].absgrad.clone() 223 | else: 224 | grads = info[self.key_for_gradient].grad.clone() 225 | grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"] 226 | grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"] 227 | 228 | # initialize state on the first run 229 | n_gaussian = len(list(params.values())[0]) 230 | 231 | if state["grad2d"] is None: 232 | state["grad2d"] = torch.zeros(n_gaussian, device=grads.device) 233 | if state["count"] is None: 234 | state["count"] = torch.zeros(n_gaussian, device=grads.device) 235 | if self.refine_scale2d_stop_iter > 0 and state["radii"] is None: 236 | assert "radii" in info, "radii is required but missing." 
237 | state["radii"] = torch.zeros(n_gaussian, device=grads.device) 238 | 239 | # update the running state 240 | if packed: 241 | # grads is [nnz, 2] 242 | gs_ids = info["gaussian_ids"] # [nnz] 243 | radii = info["radii"].max(dim=-1).values # [nnz] 244 | else: 245 | # grads is [C, N, 2] 246 | sel = info["radii"][..., 0] > 0.0 # [C, N] 247 | gs_ids = torch.where(sel)[1] # [nnz] 248 | grads = grads[sel] # [nnz, 2] 249 | radii = info["radii"][sel].max(dim=-1).values # [nnz] 250 | 251 | state["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1)) 252 | state["count"].index_add_(0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32)) 253 | if self.refine_scale2d_stop_iter > 0: 254 | # Should be ideally using scatter max 255 | state["radii"][gs_ids] = torch.maximum( 256 | state["radii"][gs_ids], 257 | # normalize radii to [0, 1] screen space 258 | radii / float(max(info["width"], info["height"])), 259 | ) 260 | 261 | @torch.no_grad() 262 | def _grow_gs( 263 | self, 264 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 265 | optimizers: Dict[str, torch.optim.Optimizer], 266 | state: Dict[str, Any], 267 | step: int, 268 | ) -> Tuple[int, int]: 269 | count = state["count"] 270 | grads = state["grad2d"] / count.clamp_min(1) 271 | device = grads.device 272 | 273 | is_grad_high = grads > self.grow_grad2d 274 | is_small = ( 275 | torch.exp(params["scales"]).max(dim=-1).values 276 | <= self.grow_scale3d * state["scene_scale"] 277 | ) 278 | is_dupli = is_grad_high & is_small 279 | n_dupli = is_dupli.sum().item() 280 | 281 | is_large = ~is_small 282 | is_split = is_large 283 | if step < self.refine_scale2d_stop_iter: 284 | is_split |= state["radii"] > self.grow_scale2d 285 | is_split = is_grad_high & is_split 286 | n_split = is_split.sum().item() 287 | 288 | # first duplicate 289 | if n_dupli > 0: 290 | duplicate(params=params, optimizers=optimizers, state=state, mask=is_dupli) 291 | 292 | # new GSs added by duplication will not be split 293 | is_split = 
torch.cat( 294 | [ 295 | is_split, 296 | torch.zeros(n_dupli, dtype=torch.bool, device=device), 297 | ] 298 | ) 299 | 300 | # then split 301 | if n_split > 0: 302 | split( 303 | params=params, 304 | optimizers=optimizers, 305 | state=state, 306 | mask=is_split, 307 | revised_opacity=self.revised_opacity, 308 | ) 309 | return n_dupli, n_split 310 | 311 | @torch.no_grad() 312 | def _prune_gs( 313 | self, 314 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 315 | optimizers: Dict[str, torch.optim.Optimizer], 316 | state: Dict[str, Any], 317 | step: int, 318 | ) -> int: 319 | is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa 320 | if step > self.reset_every: 321 | is_too_big = ( 322 | torch.exp(params["scales"]).max(dim=-1).values 323 | > self.prune_scale3d * state["scene_scale"] 324 | ) 325 | # The official code also implements sreen-size pruning but 326 | # it's actually not being used due to a bug: 327 | # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 328 | # We implement it here for completeness but set `refine_scale2d_stop_iter` 329 | # to 0 by default to disable it. 
330 | if step < self.refine_scale2d_stop_iter: 331 | is_too_big |= state["radii"] > self.prune_scale2d 332 | 333 | is_prune = is_prune | is_too_big 334 | 335 | n_prune = is_prune.sum().item() 336 | if n_prune > 0: 337 | remove(params=params, optimizers=optimizers, state=state, mask=is_prune) 338 | 339 | return n_prune 340 | -------------------------------------------------------------------------------- /gsplat/strategy/mcmc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Union 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from .base import Strategy 9 | from .ops import inject_noise_to_position, relocate, sample_add 10 | 11 | 12 | @dataclass 13 | class MCMCStrategy(Strategy): 14 | """Strategy that follows the paper: 15 | 16 | `3D Gaussian Splatting as Markov Chain Monte Carlo `_ 17 | 18 | This strategy will: 19 | 20 | - Periodically teleport GSs with low opacity to a place that has high opacity. 21 | - Periodically introduce new GSs sampled based on the opacity distribution. 22 | - Periodically perturb the GSs locations. 23 | 24 | Args: 25 | cap_max (int): Maximum number of GSs. Default to 1_000_000. 26 | noise_lr (float): MCMC samping noise learning rate. Default to 5e5. 27 | refine_start_iter (int): Start refining GSs after this iteration. Default to 500. 28 | refine_stop_iter (int): Stop refining GSs after this iteration. Default to 25_000. 29 | refine_every (int): Refine GSs every this steps. Default to 100. 30 | min_opacity (float): GSs with opacity below this value will be pruned. Default to 0.005. 31 | verbose (bool): Whether to print verbose information. Default to False. 32 | 33 | Examples: 34 | 35 | >>> from gsplat import MCMCStrategy, rasterization 36 | >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... 37 | >>> optimizers: Dict[str, torch.optim.Optimizer] = ... 
38 | >>> strategy = MCMCStrategy() 39 | >>> strategy.check_sanity(params, optimizers) 40 | >>> strategy_state = strategy.initialize_state() 41 | >>> for step in range(1000): 42 | ... render_image, render_alpha, info = rasterization(...) 43 | ... loss = ... 44 | ... loss.backward() 45 | ... strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3) 46 | 47 | """ 48 | 49 | cap_max: int = 1_000_000 50 | noise_lr: float = 5e5 51 | refine_start_iter: int = 500 52 | refine_stop_iter: int = 25_000 53 | refine_every: int = 100 54 | min_opacity: float = 0.005 55 | verbose: bool = False 56 | 57 | def initialize_state(self) -> Dict[str, Any]: 58 | """Initialize and return the running state for this strategy.""" 59 | n_max = 51 60 | binoms = torch.zeros((n_max, n_max)) 61 | for n in range(n_max): 62 | for k in range(n + 1): 63 | binoms[n, k] = math.comb(n, k) 64 | return {"binoms": binoms} 65 | 66 | def check_sanity( 67 | self, 68 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 69 | optimizers: Dict[str, torch.optim.Optimizer], 70 | ): 71 | """Sanity check for the parameters and optimizers. 72 | 73 | Check if: 74 | * `params` and `optimizers` have the same keys. 75 | * Each optimizer has exactly one param_group, corresponding to each parameter. 76 | * The following keys are present: {"means", "scales", "quats", "opacities"}. 77 | 78 | Raises: 79 | AssertionError: If any of the above conditions is not met. 80 | 81 | .. note:: 82 | It is not required but highly recommended for the user to call this function 83 | after initializing the strategy to ensure the convention of the parameters 84 | and optimizers is as expected. 85 | """ 86 | 87 | super().check_sanity(params, optimizers) 88 | # The following keys are required for this strategy. 89 | for key in ["means", "scales", "quats", "opacities"]: 90 | assert key in params, f"{key} is required in params but missing." 
91 | 92 | # def step_pre_backward( 93 | # self, 94 | # params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 95 | # optimizers: Dict[str, torch.optim.Optimizer], 96 | # # state: Dict[str, Any], 97 | # step: int, 98 | # info: Dict[str, Any], 99 | # ): 100 | # """Callback function to be executed before the `loss.backward()` call.""" 101 | # pass 102 | 103 | def step_post_backward( 104 | self, 105 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 106 | optimizers: Dict[str, torch.optim.Optimizer], 107 | state: Dict[str, Any], 108 | step: int, 109 | info: Dict[str, Any], 110 | lr: float, 111 | ): 112 | """Callback function to be executed after the `loss.backward()` call. 113 | 114 | Args: 115 | lr (float): Learning rate for "means" attribute of the GS. 116 | """ 117 | # move to the correct device 118 | state["binoms"] = state["binoms"].to(params["means"].device) 119 | 120 | binoms = state["binoms"] 121 | 122 | if ( 123 | step < self.refine_stop_iter 124 | and step > self.refine_start_iter 125 | and step % self.refine_every == 0 126 | ): 127 | # teleport GSs 128 | n_relocated_gs = self._relocate_gs(params, optimizers, binoms) 129 | if self.verbose: 130 | print(f"Step {step}: Relocated {n_relocated_gs} GSs.") 131 | 132 | # add new GSs 133 | n_new_gs = self._add_new_gs(params, optimizers, binoms) 134 | if self.verbose: 135 | print( 136 | f"Step {step}: Added {n_new_gs} GSs. " 137 | f"Now having {len(params['means'])} GSs." 
138 | ) 139 | 140 | torch.cuda.empty_cache() 141 | 142 | # add noise to GSs 143 | inject_noise_to_position( 144 | params=params, optimizers=optimizers, state={}, scaler=lr * self.noise_lr 145 | ) 146 | 147 | @torch.no_grad() 148 | def _relocate_gs( 149 | self, 150 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 151 | optimizers: Dict[str, torch.optim.Optimizer], 152 | binoms: Tensor, 153 | ) -> int: 154 | opacities = torch.sigmoid(params["opacities"].flatten()) 155 | dead_mask = opacities <= self.min_opacity 156 | n_gs = dead_mask.sum().item() 157 | if n_gs > 0: 158 | relocate( 159 | params=params, 160 | optimizers=optimizers, 161 | state={}, 162 | mask=dead_mask, 163 | binoms=binoms, 164 | min_opacity=self.min_opacity, 165 | ) 166 | return n_gs 167 | 168 | @torch.no_grad() 169 | def _add_new_gs( 170 | self, 171 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 172 | optimizers: Dict[str, torch.optim.Optimizer], 173 | binoms: Tensor, 174 | ) -> int: 175 | current_n_points = len(params["means"]) 176 | n_target = min(self.cap_max, int(1.05 * current_n_points)) 177 | n_gs = max(0, n_target - current_n_points) 178 | if n_gs > 0: 179 | sample_add( 180 | params=params, 181 | optimizers=optimizers, 182 | state={}, 183 | n=n_gs, 184 | binoms=binoms, 185 | min_opacity=self.min_opacity, 186 | ) 187 | return n_gs 188 | -------------------------------------------------------------------------------- /gsplat/strategy/ops.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from gsplat import quat_scale_to_covar_preci 9 | from gsplat.relocation import compute_relocation 10 | from gsplat.utils import normalized_quat_to_rotmat 11 | 12 | 13 | @torch.no_grad() 14 | def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> 
Tensor: 15 | """Sample from a distribution using torch.multinomial or numpy.random.choice. 16 | 17 | This function adaptively chooses between `torch.multinomial` and `numpy.random.choice` 18 | based on the number of elements in `weights`. If the number of elements exceeds 19 | the torch.multinomial limit (2^24), it falls back to using `numpy.random.choice`. 20 | 21 | Args: 22 | weights (Tensor): A 1D tensor of weights for each element. 23 | n (int): The number of samples to draw. 24 | replacement (bool): Whether to sample with replacement. Default is True. 25 | 26 | Returns: 27 | Tensor: A 1D tensor of sampled indices. 28 | """ 29 | num_elements = weights.size(0) 30 | 31 | if num_elements <= 2**24: 32 | # Use torch.multinomial for elements within the limit 33 | return torch.multinomial(weights, n, replacement=replacement) 34 | else: 35 | # Fallback to numpy.random.choice for larger element spaces 36 | weights = weights / weights.sum() 37 | weights_np = weights.detach().cpu().numpy() 38 | sampled_idxs_np = np.random.choice( 39 | num_elements, size=n, p=weights_np, replace=replacement 40 | ) 41 | sampled_idxs = torch.from_numpy(sampled_idxs_np) 42 | 43 | # Return the sampled indices on the original device 44 | return sampled_idxs.to(weights.device) 45 | 46 | 47 | @torch.no_grad() 48 | def _update_param_with_optimizer( 49 | param_fn: Callable[[str, Tensor], Tensor], 50 | optimizer_fn: Callable[[str, Tensor], Tensor], 51 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 52 | optimizers: Dict[str, torch.optim.Optimizer], 53 | names: Union[List[str], None] = None, 54 | ): 55 | """Update the parameters and the state in the optimizers with defined functions. 56 | 57 | Args: 58 | param_fn: A function that takes the name of the parameter and the parameter itself, 59 | and returns the new parameter. 60 | optimizer_fn: A function that takes the key of the optimizer state and the state value, 61 | and returns the new state value. 
62 | params: A dictionary of parameters. 63 | optimizers: A dictionary of optimizers, each corresponding to a parameter. 64 | names: A list of key names to update. If None, update all. Default: None. 65 | """ 66 | if names is None: 67 | # If names is not provided, update all parameters 68 | names = list(params.keys()) 69 | 70 | # Split the names into optimized and un-optimized 71 | un_optimized_names = set(names) - set(optimizers.keys()) 72 | optimized_names = set(names) - un_optimized_names 73 | 74 | for name in optimized_names: 75 | optimizer = optimizers[name] 76 | for i, param_group in enumerate(optimizer.param_groups): 77 | p = param_group["params"][0] 78 | p_state = optimizer.state[p] 79 | del optimizer.state[p] 80 | for key in p_state.keys(): 81 | if key != "step": 82 | v = p_state[key] 83 | p_state[key] = optimizer_fn(key, v) 84 | p_new = param_fn(name, p) 85 | optimizer.param_groups[i]["params"] = [p_new] 86 | optimizer.state[p_new] = p_state 87 | params[name] = p_new 88 | 89 | for name in un_optimized_names: 90 | params[name] = param_fn(name, params[name]) 91 | 92 | 93 | @torch.no_grad() 94 | def duplicate( 95 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 96 | optimizers: Dict[str, torch.optim.Optimizer], 97 | state: Dict[str, Tensor], 98 | mask: Tensor, 99 | ): 100 | """Inplace duplicate the Gaussian with the given mask. 101 | 102 | Args: 103 | params: A dictionary of parameters. 104 | optimizers: A dictionary of optimizers, each corresponding to a parameter. 105 | mask: A boolean mask to duplicate the Gaussians. 
106 | """ 107 | device = mask.device 108 | sel = torch.where(mask)[0] 109 | 110 | def param_fn(name: str, p: Tensor) -> Tensor: 111 | return torch.nn.Parameter(torch.cat([p, p[sel]])) 112 | 113 | def optimizer_fn(key: str, v: Tensor) -> Tensor: 114 | return torch.cat([v, torch.zeros((len(sel), *v.shape[1:]), device=device)]) 115 | 116 | # update the parameters and the state in the optimizers 117 | _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) 118 | # update the extra running state 119 | for k, v in state.items(): 120 | if isinstance(v, torch.Tensor): 121 | state[k] = torch.cat((v, v[sel])) 122 | 123 | 124 | @torch.no_grad() 125 | def split( 126 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 127 | optimizers: Dict[str, torch.optim.Optimizer], 128 | state: Dict[str, Tensor], 129 | mask: Tensor, 130 | revised_opacity: bool = False, 131 | ): 132 | """Inplace split the Gaussian with the given mask. 133 | 134 | Args: 135 | params: A dictionary of parameters. 136 | optimizers: A dictionary of optimizers, each corresponding to a parameter. 137 | mask: A boolean mask to split the Gaussians. 138 | revised_opacity: Whether to use revised opacity formulation 139 | from arXiv:2404.06109. Default: False. 
140 | """ 141 | device = mask.device 142 | sel = torch.where(mask)[0] 143 | rest = torch.where(~mask)[0] 144 | 145 | scales = torch.exp(params["scales"][sel]) 146 | quats = F.normalize(params["quats"][sel], dim=-1) 147 | rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] 148 | samples = torch.einsum( 149 | "nij,nj,bnj->bni", 150 | rotmats, 151 | scales, 152 | torch.randn(2, len(scales), 3, device=device), 153 | ) # [2, N, 3] 154 | 155 | def param_fn(name: str, p: Tensor) -> Tensor: 156 | repeats = [2] + [1] * (p.dim() - 1) 157 | if name == "means": 158 | p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3] 159 | elif name == "scales": 160 | p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3] 161 | elif name == "opacities" and revised_opacity: 162 | new_opacities = 1.0 - torch.sqrt(1.0 - torch.sigmoid(p[sel])) 163 | p_split = torch.logit(new_opacities).repeat(repeats) # [2N] 164 | else: 165 | p_split = p[sel].repeat(repeats) 166 | p_new = torch.cat([p[rest], p_split]) 167 | p_new = torch.nn.Parameter(p_new) 168 | return p_new 169 | 170 | def optimizer_fn(key: str, v: Tensor) -> Tensor: 171 | v_split = torch.zeros((2 * len(sel), *v.shape[1:]), device=device) 172 | return torch.cat([v[rest], v_split]) 173 | 174 | # update the parameters and the state in the optimizers 175 | _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) 176 | # update the extra running state 177 | for k, v in state.items(): 178 | if isinstance(v, torch.Tensor): 179 | repeats = [2] + [1] * (v.dim() - 1) 180 | v_new = v[sel].repeat(repeats) 181 | state[k] = torch.cat((v[rest], v_new)) 182 | 183 | 184 | @torch.no_grad() 185 | def remove( 186 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 187 | optimizers: Dict[str, torch.optim.Optimizer], 188 | state: Dict[str, Tensor], 189 | mask: Tensor, 190 | ): 191 | """Inplace remove the Gaussian with the given mask. 192 | 193 | Args: 194 | params: A dictionary of parameters. 
195 | optimizers: A dictionary of optimizers, each corresponding to a parameter. 196 | mask: A boolean mask to remove the Gaussians. 197 | """ 198 | sel = torch.where(~mask)[0] 199 | 200 | def param_fn(name: str, p: Tensor) -> Tensor: 201 | return torch.nn.Parameter(p[sel]) 202 | 203 | def optimizer_fn(key: str, v: Tensor) -> Tensor: 204 | return v[sel] 205 | 206 | # update the parameters and the state in the optimizers 207 | _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) 208 | # update the extra running state 209 | for k, v in state.items(): 210 | if isinstance(v, torch.Tensor): 211 | state[k] = v[sel] 212 | 213 | 214 | @torch.no_grad() 215 | def reset_opa( 216 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 217 | optimizers: Dict[str, torch.optim.Optimizer], 218 | state: Dict[str, Tensor], 219 | value: float, 220 | ): 221 | """Inplace reset the opacities to the given post-sigmoid value. 222 | 223 | Args: 224 | params: A dictionary of parameters. 225 | optimizers: A dictionary of optimizers, each corresponding to a parameter. 
226 | value: The value to reset the opacities 227 | """ 228 | 229 | def param_fn(name: str, p: Tensor) -> Tensor: 230 | if name == "opacities": 231 | opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item()) 232 | return torch.nn.Parameter(opacities) 233 | else: 234 | raise ValueError(f"Unexpected parameter name: {name}") 235 | 236 | def optimizer_fn(key: str, v: Tensor) -> Tensor: 237 | return torch.zeros_like(v) 238 | 239 | # update the parameters and the state in the optimizers 240 | _update_param_with_optimizer( 241 | param_fn, optimizer_fn, params, optimizers, names=["opacities"] 242 | ) 243 | 244 | 245 | @torch.no_grad() 246 | def relocate( 247 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 248 | optimizers: Dict[str, torch.optim.Optimizer], 249 | state: Dict[str, Tensor], 250 | mask: Tensor, 251 | binoms: Tensor, 252 | min_opacity: float = 0.005, 253 | ): 254 | """Inplace relocate some dead Gaussians to the live ones. 255 | 256 | Args: 257 | params: A dictionary of parameters. 258 | optimizers: A dictionary of optimizers, each corresponding to a parameter. 259 | mask: A boolean mask to indicate which Gaussians are dead. 
260 | """ 261 | # support "opacities" with shape [N,] or [N, 1] 262 | opacities = torch.sigmoid(params["opacities"]) 263 | 264 | dead_indices = mask.nonzero(as_tuple=True)[0] 265 | alive_indices = (~mask).nonzero(as_tuple=True)[0] 266 | n = len(dead_indices) 267 | 268 | # Sample for new GSs 269 | eps = torch.finfo(torch.float32).eps 270 | probs = opacities[alive_indices].flatten() # ensure its shape is [N,] 271 | sampled_idxs = _multinomial_sample(probs, n, replacement=True) 272 | sampled_idxs = alive_indices[sampled_idxs] 273 | new_opacities, new_scales = compute_relocation( 274 | opacities=opacities[sampled_idxs], 275 | scales=torch.exp(params["scales"])[sampled_idxs], 276 | ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, 277 | binoms=binoms, 278 | ) 279 | new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity) 280 | 281 | def param_fn(name: str, p: Tensor) -> Tensor: 282 | if name == "opacities": 283 | p[sampled_idxs] = torch.logit(new_opacities) 284 | elif name == "scales": 285 | p[sampled_idxs] = torch.log(new_scales) 286 | p[dead_indices] = p[sampled_idxs] 287 | return torch.nn.Parameter(p) 288 | 289 | def optimizer_fn(key: str, v: Tensor) -> Tensor: 290 | v[sampled_idxs] = 0 291 | return v 292 | 293 | # update the parameters and the state in the optimizers 294 | _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) 295 | # update the extra running state 296 | for k, v in state.items(): 297 | if isinstance(v, torch.Tensor): 298 | v[sampled_idxs] = 0 299 | 300 | 301 | @torch.no_grad() 302 | def sample_add( 303 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 304 | optimizers: Dict[str, torch.optim.Optimizer], 305 | state: Dict[str, Tensor], 306 | n: int, 307 | binoms: Tensor, 308 | min_opacity: float = 0.005, 309 | ): 310 | opacities = torch.sigmoid(params["opacities"]) 311 | 312 | eps = torch.finfo(torch.float32).eps 313 | probs = opacities.flatten() 314 | sampled_idxs = 
_multinomial_sample(probs, n, replacement=True) 315 | new_opacities, new_scales = compute_relocation( 316 | opacities=opacities[sampled_idxs], 317 | scales=torch.exp(params["scales"])[sampled_idxs], 318 | ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, 319 | binoms=binoms, 320 | ) 321 | new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity) 322 | 323 | def param_fn(name: str, p: Tensor) -> Tensor: 324 | if name == "opacities": 325 | p[sampled_idxs] = torch.logit(new_opacities) 326 | elif name == "scales": 327 | p[sampled_idxs] = torch.log(new_scales) 328 | p = torch.cat([p, p[sampled_idxs]]) 329 | return torch.nn.Parameter(p) 330 | 331 | def optimizer_fn(key: str, v: Tensor) -> Tensor: 332 | v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device) 333 | return torch.cat([v, v_new]) 334 | 335 | # update the parameters and the state in the optimizers 336 | _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) 337 | # update the extra running state 338 | for k, v in state.items(): 339 | v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device) 340 | if isinstance(v, torch.Tensor): 341 | state[k] = torch.cat((v, v_new)) 342 | 343 | 344 | @torch.no_grad() 345 | def inject_noise_to_position( 346 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], 347 | optimizers: Dict[str, torch.optim.Optimizer], 348 | state: Dict[str, Tensor], 349 | scaler: float, 350 | ): 351 | opacities = torch.sigmoid(params["opacities"].flatten()) 352 | scales = torch.exp(params["scales"]) 353 | covars, _ = quat_scale_to_covar_preci( 354 | params["quats"], 355 | scales, 356 | compute_covar=True, 357 | compute_preci=False, 358 | triu=False, 359 | ) 360 | 361 | def op_sigmoid(x, k=100, x0=0.995): 362 | return 1 / (1 + torch.exp(-k * (x - x0))) 363 | 364 | noise = ( 365 | torch.randn_like(params["means"]) 366 | * (op_sigmoid(1 - opacities)).unsqueeze(-1) 367 | * scaler 368 | ) 369 | noise = 
torch.einsum("bij,bj->bi", covars, noise) 370 | params["means"].add_(noise) 371 | -------------------------------------------------------------------------------- /gsplat/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | 8 | def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: 9 | """Convert normalized quaternion to rotation matrix. 10 | 11 | Args: 12 | quat: Normalized quaternion in wxyz convention. (..., 4) 13 | 14 | Returns: 15 | Rotation matrix (..., 3, 3) 16 | """ 17 | assert quat.shape[-1] == 4, quat.shape 18 | w, x, y, z = torch.unbind(quat, dim=-1) 19 | mat = torch.stack( 20 | [ 21 | 1 - 2 * (y**2 + z**2), 22 | 2 * (x * y - w * z), 23 | 2 * (x * z + w * y), 24 | 2 * (x * y + w * z), 25 | 1 - 2 * (x**2 + z**2), 26 | 2 * (y * z - w * x), 27 | 2 * (x * z - w * y), 28 | 2 * (y * z + w * x), 29 | 1 - 2 * (x**2 + y**2), 30 | ], 31 | dim=-1, 32 | ) 33 | return mat.reshape(quat.shape[:-1] + (3, 3)) 34 | 35 | 36 | def log_transform(x): 37 | return torch.sign(x) * torch.log1p(torch.abs(x)) 38 | 39 | 40 | def inverse_log_transform(y): 41 | return torch.sign(y) * (torch.expm1(torch.abs(y))) 42 | 43 | 44 | def depth_to_points( 45 | depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True 46 | ) -> Tensor: 47 | """Convert depth maps to 3D points 48 | 49 | Args: 50 | depths: Depth maps [..., H, W, 1] 51 | camtoworlds: Camera-to-world transformation matrices [..., 4, 4] 52 | Ks: Camera intrinsics [..., 3, 3] 53 | z_depth: Whether the depth is in z-depth (True) or ray depth (False) 54 | 55 | Returns: 56 | points: 3D points in the world coordinate system [..., H, W, 3] 57 | """ 58 | assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}" 59 | assert camtoworlds.shape[-2:] == ( 60 | 4, 61 | 4, 62 | ), f"Invalid viewmats shape: {camtoworlds.shape}" 63 | assert Ks.shape[-2:] == (3, 3), f"Invalid Ks 
shape: {Ks.shape}" 64 | assert ( 65 | depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2] 66 | ), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}" 67 | 68 | device = depths.device 69 | height, width = depths.shape[-3:-1] 70 | 71 | x, y = torch.meshgrid( 72 | torch.arange(width, device=device), 73 | torch.arange(height, device=device), 74 | indexing="xy", 75 | ) # [H, W] 76 | 77 | fx = Ks[..., 0, 0] # [...] 78 | fy = Ks[..., 1, 1] # [...] 79 | cx = Ks[..., 0, 2] # [...] 80 | cy = Ks[..., 1, 2] # [...] 81 | 82 | # camera directions in camera coordinates 83 | camera_dirs = F.pad( 84 | torch.stack( 85 | [ 86 | (x - cx[..., None, None] + 0.5) / fx[..., None, None], 87 | (y - cy[..., None, None] + 0.5) / fy[..., None, None], 88 | ], 89 | dim=-1, 90 | ), 91 | (0, 1), 92 | value=1.0, 93 | ) # [..., H, W, 3] 94 | 95 | # ray directions in world coordinates 96 | directions = torch.einsum( 97 | "...ij,...hwj->...hwi", camtoworlds[..., :3, :3], camera_dirs 98 | ) # [..., H, W, 3] 99 | origins = camtoworlds[..., :3, -1] # [..., 3] 100 | 101 | if not z_depth: 102 | directions = F.normalize(directions, dim=-1) 103 | 104 | points = origins[..., None, None, :] + depths * directions 105 | return points 106 | 107 | 108 | def depth_to_normal( 109 | depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True 110 | ) -> Tensor: 111 | """Convert depth maps to surface normals 112 | 113 | Args: 114 | depths: Depth maps [..., H, W, 1] 115 | camtoworlds: Camera-to-world transformation matrices [..., 4, 4] 116 | Ks: Camera intrinsics [..., 3, 3] 117 | z_depth: Whether the depth is in z-depth (True) or ray depth (False) 118 | 119 | Returns: 120 | normals: Surface normals in the world coordinate system [..., H, W, 3] 121 | """ 122 | points = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth) # [..., H, W, 3] 123 | dx = torch.cat( 124 | [points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :]], dim=-3 125 | ) # [..., H-2, W-2, 3] 126 
| dy = torch.cat( 127 | [points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :]], dim=-2 128 | ) # [..., H-2, W-2, 3] 129 | normals = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1) # [..., H-2, W-2, 3] 130 | normals = F.pad(normals, (0, 0, 1, 1, 1, 1), value=0.0) # [..., H, W, 3] 131 | return normals 132 | 133 | 134 | def get_projection_matrix(znear, zfar, fovX, fovY, device="cuda"): 135 | """Create OpenGL-style projection matrix""" 136 | tanHalfFovY = math.tan((fovY / 2)) 137 | tanHalfFovX = math.tan((fovX / 2)) 138 | 139 | top = tanHalfFovY * znear 140 | bottom = -top 141 | right = tanHalfFovX * znear 142 | left = -right 143 | 144 | P = torch.zeros(4, 4, device=device) 145 | 146 | z_sign = 1.0 147 | 148 | P[0, 0] = 2.0 * znear / (right - left) 149 | P[1, 1] = 2.0 * znear / (top - bottom) 150 | P[0, 2] = (right + left) / (right - left) 151 | P[1, 2] = (top + bottom) / (top - bottom) 152 | P[3, 2] = z_sign 153 | P[2, 2] = z_sign * zfar / (zfar - znear) 154 | P[2, 3] = -(zfar * znear) / (zfar - znear) 155 | return P 156 | 157 | 158 | # def depth_to_normal( 159 | # depths: Tensor, camtoworlds: Tensor, Ks: Tensor, near_plane: float, far_plane: float 160 | # ) -> Tensor: 161 | # """ 162 | # Convert depth to surface normal 163 | 164 | # Args: 165 | # depths: Z-depth of the Gaussians. 166 | # camtoworlds: camera to world transformation matrix. 167 | # Ks: camera intrinsics. 168 | # near_plane: Near plane distance. 169 | # far_plane: Far plane distance. 170 | 171 | # Returns: 172 | # Surface normals. 
173 | # """ 174 | # height, width = depths.shape[1:3] 175 | # viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4] 176 | 177 | # normals = [] 178 | # for cid, depth in enumerate(depths): 179 | # FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) 180 | # FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) 181 | # world_view_transform = viewmats[cid].transpose(0, 1) 182 | # projection_matrix = _get_projection_matrix( 183 | # znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device 184 | # ).transpose(0, 1) 185 | # full_proj_transform = ( 186 | # world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) 187 | # ).squeeze(0) 188 | # normal = _depth_to_normal( 189 | # depth, 190 | # world_view_transform, 191 | # full_proj_transform, 192 | # Ks[cid, 0, 0], 193 | # Ks[cid, 1, 1], 194 | # ) 195 | # normals.append(normal) 196 | # normals = torch.stack(normals, dim=0) 197 | # return normals 198 | 199 | 200 | # # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/61c7b417393d5e0c58b742ad5e2e5f9e9f240cc6/utils/point_utils.py#L26 201 | # def _depths_to_points( 202 | # depthmap, world_view_transform, full_proj_transform, fx, fy 203 | # ) -> Tensor: 204 | # c2w = (world_view_transform.T).inverse() 205 | # H, W = depthmap.shape[:2] 206 | 207 | # intrins = ( 208 | # torch.tensor([[fx, 0.0, W / 2.0], [0.0, fy, H / 2.0], [0.0, 0.0, 1.0]]) 209 | # .float() 210 | # .cuda() 211 | # ) 212 | 213 | # grid_x, grid_y = torch.meshgrid( 214 | # torch.arange(W, device="cuda").float(), 215 | # torch.arange(H, device="cuda").float(), 216 | # indexing="xy", 217 | # ) 218 | # points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( 219 | # -1, 3 220 | # ) 221 | # rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T 222 | # rays_o = c2w[:3, 3] 223 | # points = depthmap.reshape(-1, 1) * rays_d + rays_o 224 | # return points 225 | 226 | 227 | # def _depth_to_normal( 228 | # depth, world_view_transform, full_proj_transform, fx, 
fy 229 | # ) -> Tensor: 230 | # points = _depths_to_points( 231 | # depth, 232 | # world_view_transform, 233 | # full_proj_transform, 234 | # fx, 235 | # fy, 236 | # ).reshape(*depth.shape[:2], 3) 237 | # output = torch.zeros_like(points) 238 | # dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 239 | # dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 240 | # normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 241 | # output[1:-1, 1:-1, :] = normal_map 242 | # return output 243 | 244 | 245 | # def _get_projection_matrix(znear, zfar, fovX, fovY, device="cuda") -> Tensor: 246 | # tanHalfFovY = math.tan((fovY / 2)) 247 | # tanHalfFovX = math.tan((fovX / 2)) 248 | 249 | # top = tanHalfFovY * znear 250 | # bottom = -top 251 | # right = tanHalfFovX * znear 252 | # left = -right 253 | 254 | # P = torch.zeros(4, 4, device=device) 255 | 256 | # z_sign = 1.0 257 | 258 | # P[0, 0] = 2.0 * znear / (right - left) 259 | # P[1, 1] = 2.0 * znear / (top - bottom) 260 | # P[0, 2] = (right + left) / (right - left) 261 | # P[1, 2] = (top + bottom) / (top - bottom) 262 | # P[3, 2] = z_sign 263 | # P[2, 2] = z_sign * zfar / (zfar - znear) 264 | # P[2, 3] = -(zfar * znear) / (zfar - znear) 265 | # return P 266 | -------------------------------------------------------------------------------- /gsplat/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | import platform 5 | import sys 6 | 7 | from setuptools import find_packages, setup 8 | 9 | __version__ = None 10 | exec(open("gsplat/version.py", "r").read()) 11 | 12 | URL = "https://github.com/nerfstudio-project/gsplat" 13 | 14 | BUILD_NO_CUDA = os.getenv("BUILD_NO_CUDA", "0") == "1" 15 | 
WITH_SYMBOLS = os.getenv("WITH_SYMBOLS", "0") == "1" 16 | LINE_INFO = os.getenv("LINE_INFO", "0") == "1" 17 | 18 | 19 | def get_ext(): 20 | from torch.utils.cpp_extension import BuildExtension 21 | 22 | return BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False) 23 | 24 | 25 | def get_extensions(): 26 | import torch 27 | from torch.__config__ import parallel_info 28 | from torch.utils.cpp_extension import CUDAExtension 29 | 30 | extensions_dir_v1 = osp.join("gsplat", "cuda_legacy", "csrc") 31 | sources_v1 = glob.glob(osp.join(extensions_dir_v1, "*.cu")) + glob.glob( 32 | osp.join(extensions_dir_v1, "*.cpp") 33 | ) 34 | sources_v1 = [path for path in sources_v1 if "hip" not in path] 35 | 36 | extensions_dir_v2 = osp.join("gsplat", "cuda", "csrc") 37 | sources_v2 = glob.glob(osp.join(extensions_dir_v2, "*.cu")) + glob.glob( 38 | osp.join(extensions_dir_v2, "*.cpp") 39 | ) 40 | sources_v2 = [path for path in sources_v2 if "hip" not in path] 41 | 42 | undef_macros = [] 43 | define_macros = [] 44 | 45 | if sys.platform == "win32": 46 | define_macros += [("gsplat_EXPORTS", None)] 47 | 48 | extra_compile_args = {"cxx": ["-O3"]} 49 | if not os.name == "nt": # Not on Windows: 50 | extra_compile_args["cxx"] += ["-Wno-sign-compare"] 51 | extra_link_args = [] if WITH_SYMBOLS else ["-s"] 52 | 53 | info = parallel_info() 54 | if ( 55 | "backend: OpenMP" in info 56 | and "OpenMP not found" not in info 57 | and sys.platform != "darwin" 58 | ): 59 | extra_compile_args["cxx"] += ["-DAT_PARALLEL_OPENMP"] 60 | if sys.platform == "win32": 61 | extra_compile_args["cxx"] += ["/openmp"] 62 | else: 63 | extra_compile_args["cxx"] += ["-fopenmp"] 64 | else: 65 | print("Compiling without OpenMP...") 66 | 67 | # Compile for mac arm64 68 | if sys.platform == "darwin" and platform.machine() == "arm64": 69 | extra_compile_args["cxx"] += ["-arch", "arm64"] 70 | extra_link_args += ["-arch", "arm64"] 71 | 72 | nvcc_flags = os.getenv("NVCC_FLAGS", "") 73 | nvcc_flags = [] if 
nvcc_flags == "" else nvcc_flags.split(" ")  # NOTE(review): statement fragment — its beginning lies outside this chunk
    nvcc_flags += ["-O3", "--use_fast_math"]
    if LINE_INFO:
        # Embed source line info in the CUDA binary for profiling/debugging.
        nvcc_flags += ["-lineinfo"]
    if torch.version.hip:
        # USE_ROCM was added to later versions of PyTorch.
        # Define here to support older PyTorch versions as well:
        define_macros += [("USE_ROCM", None)]
        undef_macros += ["__HIP_NO_HALF_CONVERSIONS__"]
    else:
        nvcc_flags += ["--expt-relaxed-constexpr"]
    extra_compile_args["nvcc"] = nvcc_flags
    if sys.platform == "win32":
        # Trim windows.h to avoid macro collisions on MSVC builds.
        extra_compile_args["nvcc"] += ["-DWIN32_LEAN_AND_MEAN"]

    # Legacy (v1) kernels, published as gsplat.csrc_legacy.
    # NOTE(review): the f-string prefixes on the extension names have no
    # placeholders and are inert — plain strings would be equivalent.
    extension_v1 = CUDAExtension(
        f"gsplat.csrc_legacy",
        sources_v1,
        include_dirs=[extensions_dir_v2],  # glm lives in v2.
        define_macros=define_macros,
        undef_macros=undef_macros,
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    )
    # Current (v2) kernels, published as gsplat.csrc.
    extension_v2 = CUDAExtension(
        f"gsplat.csrc",
        sources_v2,
        include_dirs=[extensions_dir_v2],  # glm lives in v2.
        define_macros=define_macros,
        undef_macros=undef_macros,
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    )

    return [extension_v1, extension_v2]


setup(
    name="gsplat",
    version=__version__,
    description=" Python package for differentiable rasterization of gaussians",
    keywords="gaussian, splatting, cuda",
    url=URL,
    download_url=f"{URL}/archive/gsplat-{__version__}.tar.gz",
    python_requires=">=3.7",
    install_requires=[
        "ninja",
        "numpy",
        "jaxtyping",
        "rich>=12",
        "torch",
        # typing_extensions backports typing features for Python 3.7.
        "typing_extensions; python_version<'3.8'",
    ],
    extras_require={
        # dev dependencies. Install them by `pip install gsplat[dev]`
        "dev": [
            "black[jupyter]==22.3.0",
            "isort==5.10.1",
            "pylint==2.13.4",
            "pytest==7.1.2",
            "pytest-xdist==2.5.0",
            "typeguard>=2.13.3",
            "pyyaml==6.0",
            "build",
            "twine",
        ],
    },
    # Skip CUDA extension compilation entirely when BUILD_NO_CUDA is set.
    ext_modules=get_extensions() if not BUILD_NO_CUDA else [],
    cmdclass={"build_ext": get_ext()} if not BUILD_NO_CUDA else {},
    packages=find_packages(),
    # https://github.com/pypa/setuptools/issues/1461#issuecomment-954725244
    include_package_data=True,
)
# --------------------------------------------------------------------------------
# /tests/_test_distributed.py:
# --------------------------------------------------------------------------------
import pytest
import torch

from gsplat.distributed import (
    all_gather_int32,
    all_gather_tensor_list,
    all_to_all_int32,
    all_to_all_tensor_list,
    cli,
)


def _main_all_gather_int32(local_rank: int, world_rank: int, world_size: int, _):
    """Worker body: each rank gathers every rank's id; slot i must hold i."""
    device = torch.device("cuda", local_rank)

    # Plain Python-int input path.
    value = world_rank
    collected = all_gather_int32(world_size, value, device=device)
    for i in range(world_size):
        assert collected[i] == i

    # Tensor input path; same expected contents.
    value = torch.tensor(world_rank, device=device, dtype=torch.int)
    collected = all_gather_int32(world_size, value, device=device)
    for i in range(world_size):
        assert collected[i] == torch.tensor(i, device=device, dtype=torch.int)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_gather_int32():
    # cli(...) spawns one worker per device and runs the given main function.
    cli(_main_all_gather_int32, None, verbose=True)


def _main_all_to_all_int32(local_rank: int, world_rank: int, world_size: int, _):
    """Worker body: rank r sends value i to rank i, so every received slot equals r."""
    device = torch.device("cuda", local_rank)

    values = list(range(world_size))
    collected = all_to_all_int32(world_size, values, device=device)
    for i in range(world_size):
        assert collected[i] == world_rank

    values = torch.arange(world_size, device=device, dtype=torch.int)
    collected = all_to_all_int32(world_size, values, device=device)
    for i in range(world_size):
        assert collected[i] == torch.tensor(world_rank, device=device, dtype=torch.int)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_to_all_int32():
    cli(_main_all_to_all_int32, None, verbose=True)


def _main_all_gather_tensor_list(local_rank: int, world_rank: int, world_size: int, _):
    """Worker body: gathered tensors must equal the rank-ordered concatenation of all inputs."""
    device = torch.device("cuda", local_rank)
    N = 10

    # Each rank contributes constant-valued tensors filled with its own rank id.
    tensor_list = [
        torch.full((N, 2), world_rank, device=device),
        torch.full((N, 3, 3), world_rank, device=device),
    ]

    target_list = [
        torch.cat([torch.full((N, 2), i, device=device) for i in range(world_size)]),
        torch.cat([torch.full((N, 3, 3), i, device=device) for i in range(world_size)]),
    ]

    collected = all_gather_tensor_list(world_size, tensor_list)
    for tensor, target in zip(collected, target_list):
        assert torch.equal(tensor, target)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_gather_tensor_list():
    cli(_main_all_gather_tensor_list, None, verbose=True)


def _main_all_to_all_tensor_list(local_rank: int, world_rank: int, world_size: int, _):
    """Worker body: variable-split all-to-all; rank r receives splits[r] rows from each peer."""
    device = torch.device("cuda", local_rank)
    # splits[i] rows go to rank i; total rows this rank sends is sum(splits).
    splits = torch.arange(0, world_size, device=device)
    N = splits.sum().item()

    tensor_list = [
        torch.full((N, 2), world_rank, device=device),
        torch.full((N, 3, 3), world_rank, device=device),
    ]

    # Expected: splits[world_rank] rows from every peer i, in rank order.
    target_list = [
        torch.cat(
            [
                torch.full((splits[world_rank], 2), i, device=device)
                for i in range(world_size)
            ]
        ),
        torch.cat(
            [
                torch.full((splits[world_rank], 3, 3), i, device=device)
                for i in range(world_size)
            ]
        ),
    ]

    collected = all_to_all_tensor_list(world_size, tensor_list, splits)
    for tensor, target in zip(collected, target_list):
        assert torch.equal(tensor, target)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_to_all_tensor_list():
    cli(_main_all_to_all_tensor_list, None, verbose=True)


if __name__ == "__main__":
    test_all_gather_int32()
    test_all_to_all_int32()
    test_all_gather_tensor_list()
    test_all_to_all_tensor_list()
# --------------------------------------------------------------------------------
# /tests/test_compression.py:
# --------------------------------------------------------------------------------
"""Tests for the functions in the CUDA extension.

Usage:
```bash
pytest -s
```
"""

import pytest
import torch

device = torch.device("cuda:0")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_png_compression():
    """Round-trip smoke test: compress random splats to disk, then decompress them."""
    from gsplat.compression import PngCompression

    torch.manual_seed(42)

    # Prepare Gaussians
    N = 100000
    splats = torch.nn.ParameterDict(
        {
            "means": torch.randn(N, 3),
            "scales": torch.randn(N, 3),
            "quats": torch.randn(N, 4),
            "opacities": torch.randn(N),
            "sh0": torch.randn(N, 1, 3),
            "shN": torch.randn(N, 24, 3),
            "features": torch.randn(N, 128),
        }
    ).to(device)
    compress_dir = "/tmp/gsplat/compression"

    compression_method = PngCompression()
    # run compression and save the compressed files to compress_dir
    compression_method.compress(compress_dir, splats)
    # decompress the compressed files
    # NOTE(review): splats_c is never compared against splats, so this test only
    # verifies the round trip runs without raising — consider asserting keys/shapes.
    splats_c = compression_method.decompress(compress_dir)


if __name__ == "__main__":
    test_png_compression()
# --------------------------------------------------------------------------------
# /tests/test_rasterization.py:
# --------------------------------------------------------------------------------
"""Tests for the functions in the CUDA extension.

Usage:
```bash
pytest -s
```
"""

import math
from typing import Optional

import pytest
import torch

device = torch.device("cuda:0")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
@pytest.mark.parametrize("per_view_color", [True, False])
@pytest.mark.parametrize("sh_degree", [None, 3])
@pytest.mark.parametrize("render_mode", ["RGB", "RGB+D", "D"])
@pytest.mark.parametrize("packed", [True, False])
def test_rasterization(per_view_color: bool, sh_degree: Optional[int], render_mode: str, packed: bool):
    """Shape-level smoke test of camera rasterization across color/SH/mode/packing variants."""
    from gsplat.rendering import rasterization

    torch.manual_seed(42)

    # C cameras, N Gaussians with random geometry and per-Gaussian velocities.
    C, N = 2, 10_000
    means = torch.rand(N, 3, device=device)
    velocities = torch.randn(N, 3, device=device) * 0.01
    quats = torch.randn(N, 4, device=device)
    scales = torch.rand(N, 3, device=device)
    opacities = torch.rand(N, device=device)
    # Color layout depends on whether colors are per-view and whether SH is used:
    # plain RGB -> [..., N, 3]; SH -> [..., N, (sh_degree+1)**2, 3].
    if per_view_color:
        if sh_degree is None:
            colors = torch.rand(C, N, 3, device=device)
        else:
            colors = torch.rand(C, N, (sh_degree + 1) ** 2, 3, device=device)
    else:
        if sh_degree is None:
            colors = torch.rand(N, 3, device=device)
        else:
            colors = torch.rand(N, (sh_degree + 1) ** 2, 3, device=device)

    # Simple pinhole intrinsics with the principal point at the image center,
    # shared (expanded, not copied) across all C cameras; identity extrinsics.
    width, height = 300, 200
    focal = 300.0
    Ks = torch.tensor(
        [[focal, 0.0, width / 2.0], [0.0, focal, height / 2.0], [0.0, 0.0, 1.0]],
        device=device,
    ).expand(C, -1, -1)
    viewmats = torch.eye(4, device=device).expand(C, -1, -1)

    # Small per-camera ego motion plus a rolling-shutter readout time.
    linear_velocity = torch.randn(C, 3, device=device) * 0.01
    angular_velocity = torch.randn(C, 3, device=device) * 0.01
    rolling_shutter_time = torch.rand(C, device=device) * 0.1

    colors_out, _, _ = rasterization(
        means=means,
        quats=quats,
        scales=scales,
        opacities=opacities,
        colors=colors,
        velocities=velocities,
        viewmats=viewmats,
        Ks=Ks,
        width=width,
        height=height,
        linear_velocity=linear_velocity,
        angular_velocity=angular_velocity,
        rolling_shutter_time=rolling_shutter_time,
        sh_degree=sh_degree,
        render_mode=render_mode,
        packed=packed,
    )

    # Output channel count follows the render mode: D=1, RGB=3, RGB+D=4.
    if render_mode == "D":
        assert colors_out.shape == (C, height, width, 1)
    elif render_mode == "RGB":
        assert colors_out.shape == (C, height, width, 3)
    elif render_mode == "RGB+D":
        assert colors_out.shape == (C, height, width, 4)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
@pytest.mark.parametrize("channels", [3, 32, 128])
def test_lidar_rasterization(channels: int):
    """Shape-level smoke test of lidar rasterization over a spherical azimuth/elevation grid."""
    from gsplat.rendering import lidar_rasterization

    torch.manual_seed(42)

    C, N = 2, 10_000
    means = torch.rand(N, 3, device=device)
    quats = torch.randn(N, 4, device=device)
    scales = torch.rand(N, 3, device=device)
    opacities = torch.rand(N, device=device)
    velocities = torch.randn(N, 3, device=device) * 0.01

    # Full 360-degree azimuth sweep, 90-degree elevation fan, in degrees.
    min_azimuth = -180
    max_azimuth = 180
    min_elevation = -45
    max_elevation = 45
    n_elevation_channels = 32
    azimuth_resolution = 0.2

    # Tiling of the range image; elevation boundaries split the fan into tile rows.
    tile_width = 64
    tile_height = 4
    tile_elevation_boundaries = torch.linspace(
        min_elevation, max_elevation, n_elevation_channels // tile_height + 1, device=means.device
    )

    viewmats = torch.eye(4, device=device).expand(C, -1, -1)
    lidar_features = torch.randn(C, len(means), channels, device=device)

    # Raster sample centers: one point per pixel, offset by half a pixel from the edges.
    image_width = math.ceil((max_azimuth - min_azimuth) / azimuth_resolution)
    raster_pts_azim = torch.linspace(
        min_azimuth + azimuth_resolution / 2, max_azimuth - azimuth_resolution / 2, image_width, device=means.device
    )
    raster_pts_elev = torch.linspace(
        min_elevation + (max_elevation - min_elevation) / n_elevation_channels / 2,
        max_elevation - (max_elevation - min_elevation) / n_elevation_channels / 2,
        n_elevation_channels,
        device=means.device,
    )
    # NOTE(review): torch.meshgrid without indexing= is deprecated and relies on
    # the legacy "ij" default — confirm and pass indexing="ij" explicitly.
    # The [..., [1, 0]] swap orders each point as (azimuth, elevation).
    raster_pts = torch.stack(torch.meshgrid(raster_pts_elev, raster_pts_azim), dim=-1)[..., [1, 0]]
    ranges = torch.rand(n_elevation_channels, image_width, 1, device=device) * 10
    keep_range_mask = torch.rand(n_elevation_channels, image_width, device=device) > 0.1
    raster_pts = torch.cat([raster_pts, ranges], dim=-1)
    raster_pts = raster_pts.unsqueeze(0).repeat(C, 1, 1, 1)
    # add randomness (only where keep_range_mask is True; masked pixels stay exact)
    raster_pts += torch.randn_like(raster_pts) * 0.01 * keep_range_mask[None, ..., None]

    linear_velocity = torch.randn(C, 3, device=device) * 0.01
    angular_velocity = torch.randn(C, 3, device=device) * 0.01
    rolling_shutter_time = torch.rand(C, device=device) * 0.1

    # add timestamps (fourth component per raster point, scaled by the shutter time)
    raster_pts = torch.cat(
        [raster_pts, rolling_shutter_time.max() * torch.randn(raster_pts[..., 0:1].shape, device=raster_pts.device)],
        dim=-1,
    )

    render_lidar_features, _, _, _ = lidar_rasterization(
        means=means,
        quats=quats,
        scales=scales,
        opacities=opacities,
        lidar_features=lidar_features,
        velocities=velocities,
        linear_velocity=linear_velocity,
        angular_velocity=angular_velocity,
        rolling_shutter_time=rolling_shutter_time,
        viewmats=viewmats,
        min_azimuth=min_azimuth,
        max_azimuth=max_azimuth,
        min_elevation=min_elevation,
        max_elevation=max_elevation,
        n_elevation_channels=n_elevation_channels,
        azimuth_resolution=azimuth_resolution,
        raster_pts=raster_pts,
        tile_width=tile_width,
        tile_height=tile_height,
        tile_elevation_boundaries=tile_elevation_boundaries,
    )

    # Output carries the input feature channels plus one extra channel.
    n_azimuth_pixels = math.ceil((max_azimuth - min_azimuth) / azimuth_resolution)
    assert render_lidar_features.shape == (C, n_elevation_channels, n_azimuth_pixels, channels + 1)
# --------------------------------------------------------------------------------
# /tests/test_strategy.py:
# --------------------------------------------------------------------------------
"""Tests for the functions in the CUDA extension.

Usage:
```bash
pytest -s
```
"""

import pytest
import torch

device = torch.device("cuda:0")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_strategy():
    """Smoke test: run one pre/post-backward cycle of both densification strategies."""
    from gsplat.rendering import rasterization
    from gsplat.strategy import DefaultStrategy, MCMCStrategy

    torch.manual_seed(42)

    # Prepare Gaussians
    N = 100
    params = torch.nn.ParameterDict(
        {
            "means": torch.randn(N, 3),
            "scales": torch.rand(N, 3),
            "quats": torch.randn(N, 4),
            "opacities": torch.rand(N),
            "colors": torch.rand(N, 3),
        }
    ).to(device)
    # One optimizer per parameter group, as the strategies expect.
    optimizers = {k: torch.optim.Adam([v], lr=1e-3) for k, v in params.items()}

    # A dummy rendering call
    render_colors, render_alphas, info = rasterization(
        means=params["means"],
        quats=params["quats"],  # F.normalize is fused into the kernel
        scales=torch.exp(params["scales"]),
        opacities=torch.sigmoid(params["opacities"]),
        colors=params["colors"],
        velocities=None,
        viewmats=torch.eye(4).unsqueeze(0).to(device),
        Ks=torch.eye(3).unsqueeze(0).to(device),
        width=10,
        height=10,
        packed=False,
    )

    # Test DefaultStrategy
    strategy = DefaultStrategy(verbose=True)
    strategy.check_sanity(params, optimizers)
    state = strategy.initialize_state()
    # step=600 is past the typical warm-up so densification logic actually runs.
    strategy.step_pre_backward(params, optimizers, state, step=600, info=info)
    # retain_graph=True so the same graph can be backpropagated again below.
    render_colors.mean().backward(retain_graph=True)
    strategy.step_post_backward(params, optimizers, state, step=600, info=info)

    # Test MCMCStrategy (has no pre-backward hook to call)
    strategy = MCMCStrategy(verbose=True)
    strategy.check_sanity(params, optimizers)
    state = strategy.initialize_state()
    render_colors.mean().backward(retain_graph=True)
    strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)


if __name__ == "__main__":
    test_strategy()
# --------------------------------------------------------------------------------